diff --git a/chainforge/react-server/package-lock.json b/chainforge/react-server/package-lock.json index b0aca99..1e01d0e 100644 --- a/chainforge/react-server/package-lock.json +++ b/chainforge/react-server/package-lock.json @@ -45,6 +45,7 @@ "emoji-mart": "^5.5.2", "emoji-picker-react": "^4.4.9", "google-auth-library": "^8.8.0", + "lz-string": "^1.5.0", "mantine-contextmenu": "^1.2.15", "mantine-react-table": "^1.0.0-beta.8", "openai": "^3.3.0", diff --git a/chainforge/react-server/package.json b/chainforge/react-server/package.json index 748d900..e16c3a8 100644 --- a/chainforge/react-server/package.json +++ b/chainforge/react-server/package.json @@ -40,6 +40,7 @@ "emoji-mart": "^5.5.2", "emoji-picker-react": "^4.4.9", "google-auth-library": "^8.8.0", + "lz-string": "^1.5.0", "mantine-contextmenu": "^1.2.15", "mantine-react-table": "^1.0.0-beta.8", "openai": "^3.3.0", diff --git a/chainforge/react-server/src/backend/__test__/backend.test.ts b/chainforge/react-server/src/backend/__test__/backend.test.ts new file mode 100644 index 0000000..f56864f --- /dev/null +++ b/chainforge/react-server/src/backend/__test__/backend.test.ts @@ -0,0 +1,28 @@ +/* +* @jest-environment node +*/ +import { LLM } from '../models'; +import { expect, test } from '@jest/globals'; +import { queryLLM } from '../backend'; +import { StandardizedLLMResponse } from '../typing'; + +test('call three LLMs with a single prompt', async () => { + // Setup params to call + const prompt = 'What is one major difference between French and English languages? Be brief.' + const llms = [LLM.OpenAI_ChatGPT, LLM.Claude_v1, LLM.PaLM2_Chat_Bison]; + const n = 1; + const progress_listener = (progress: {[key: symbol]: any}) => { + console.log(JSON.stringify(progress)); + }; + + // Call all three LLMs with the same prompt, n=1, and listen to progress + const {responses, errors} = await queryLLM('testid', llms, n, prompt, {}, undefined, progress_listener); + + // Check responses + expect(responses).toHaveLength(3); + responses.forEach((resp_obj: StandardizedLLMResponse) => { + expect(resp_obj.prompt).toBe(prompt); + expect(resp_obj.responses).toHaveLength(1); // since n = 1 + expect(Object.keys(resp_obj.vars)).toHaveLength(0); + }); +}, 40000); \ No newline at end of file diff --git a/chainforge/react-server/src/backend/__test__/cache.test.ts b/chainforge/react-server/src/backend/__test__/cache.test.ts new file mode 100644 index 0000000..bb43221 --- /dev/null +++ b/chainforge/react-server/src/backend/__test__/cache.test.ts @@ -0,0 +1,32 @@ +import { expect, test } from '@jest/globals'; +import StorageCache from '../cache'; + +test('saving and loading cache data from localStorage', () => { + // Store Unicode and numeric data into StorageCache + StorageCache.store('hello', {'a': '土', 'b': 'ہوا', 'c': '火'}); + StorageCache.store('world', 42); + + // Verify stored data: + let d = StorageCache.get('hello'); + expect(d).toHaveProperty('a'); + expect(d?.a).toBe('土'); + + // Save to localStorage + StorageCache.saveToLocalStorage('test'); + + // Remove all data in the cache + StorageCache.clear(); + + // Double-check there's no data: + d = StorageCache.get('hello'); + expect(d).toBeUndefined(); + + // Load cache from localStorage + StorageCache.loadFromLocalStorage('test'); + + // Verify stored data: + d = StorageCache.get('hello'); + expect(d).toHaveProperty('c'); + expect(d?.c).toBe('火'); + expect(StorageCache.get('world')).toBe(42); +}); \ No newline at end of file diff --git a/chainforge/react-server/src/backend/backend.ts b/chainforge/react-server/src/backend/backend.ts index 2e8bf62..caebfaa 100644 --- a/chainforge/react-server/src/backend/backend.ts +++ b/chainforge/react-server/src/backend/backend.ts @@ -9,7 +9,7 @@ // from chainforge.promptengine.template import PromptTemplate, PromptPermutationGenerator // from chainforge.promptengine.utils import LLM, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists, set_api_keys -import { Dict, LLMResponseError, LLMResponseObject } from "./typing"; +import { Dict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse } from "./typing"; import { LLM } from "./models"; import { set_api_keys } from "./utils"; import StorageCache from "./cache"; @@ -117,16 +117,6 @@ Object.entries(LLM).forEach(([key, value]) => { // md_ast_parser = mistune.create_markdown(renderer='ast') // return md_ast_parser(self.text) -interface StandardizedLLMResponse { - llm: string | Dict, - prompt: string, - responses: Array, - vars: Dict, - metavars: Dict, - tokens: Dict, - eval_res?: Dict, -} - function to_standard_format(r: LLMResponseObject | Dict): StandardizedLLMResponse { let resp_obj = { vars: r['info'], @@ -555,6 +545,7 @@ export async function queryLLM(id: string, prompt: string, vars: Dict, api_keys?: Dict, + progress_listener?: (progress: {[key: symbol]: any}) => void, no_cache?: boolean): Promise { // Verify the integrity of the params if (typeof id !== 'string' || id.trim().length === 0) @@ -582,7 +573,8 @@ export async function queryLLM(id: string, // Get the storage keys of any cache files for specific models + settings const llms = llm; - let cache = StorageCache.get(id); // returns {} if 'id' is not in the storage cache yet + let cache: Dict = StorageCache.get(id) || {}; // returns {} if 'id' is not in the storage cache yet + let llm_to_cache_filename = {}; let past_cache_files = {}; if (typeof cache === 'object' && cache.cache_files !== undefined) { @@ -608,7 +600,7 @@ export async function queryLLM(id: string, // Create a new cache JSON object cache = { cache_files: {}, responses_last_run: [] }; let prev_filenames: Array = []; - llms.forEach(llm_spec => { + llms.forEach((llm_spec: string | Dict) => { const fname = gen_unique_cache_filename(id, prev_filenames); llm_to_cache_filename[extract_llm_key(llm_spec)] = fname; cache.cache_files[fname] = llm_spec; @@ -625,10 +617,13 @@ export async function queryLLM(id: string, let progressProxy = new Proxy(progress, { set: function (target, key, value) { console.log(`${key.toString()} set to ${value.toString()}`); - // ... - // Call any callbacks here - // ... target[key] = value; + + // If the caller provided a callback, notify it + // of the changes to the 'progress' object: + if (progress_listener) + progress_listener(target); + return true; } }); diff --git a/chainforge/react-server/src/backend/cache.ts b/chainforge/react-server/src/backend/cache.ts index 34eadf0..810bb39 100644 --- a/chainforge/react-server/src/backend/cache.ts +++ b/chainforge/react-server/src/backend/cache.ts @@ -1,4 +1,5 @@ import { Dict } from "./typing"; +import LZString from 'lz-string'; /** * Singleton JSON cache that functions like a local filesystem in a Python backend, @@ -21,10 +22,10 @@ export default class StorageCache { return StorageCache.instance; } - private getCacheData(key: string): Dict { - return this.data[key] || {}; + private getCacheData(key: string): Dict | undefined { + return this.data[key] || undefined; } - public static get(key: string): Dict { + public static get(key: string): Dict | undefined { return StorageCache.getInstance().getCacheData(key); } @@ -34,4 +35,62 @@ export default class StorageCache { public static store(key: string, data: any): void { StorageCache.getInstance().storeCacheData(key, data); } + + private clearCache(): void { + this.data = {}; + } + public static clear(): void { + StorageCache.getInstance().clearCache(); + } + + /** + * Attempts to store the entire cache in localStorage. + * Performs lz-string compression (https://pieroxy.net/blog/pages/lz-string/index.html) + * before storing a JSON object in UTF encoding. + * + * Use loadFromLocalStorage to unpack the localStorage data. + * + * @param localStorageKey The key that will be used in localStorage (default='chainforge') + * @returns True if succeeded, false if failure (e.g., too big for localStorage). + */ + public static saveToLocalStorage(localStorageKey: string='chainforge'): boolean { + const data = StorageCache.getInstance().data; + const compressed = LZString.compressToUTF16(JSON.stringify(data)); + try { + localStorage.setItem(localStorageKey, compressed); + return true; + } catch (error) { + if (error instanceof DOMException && error.name === "QuotaExceededError") { + // Handle the error when storage quota is exceeded + console.warn("Storage quota exceeded"); + } else { + // Handle other types of storage-related errors + console.error("Error storing data in localStorage:", error.message); + } + return false; + } + } + + /** + * Attempts to load a previously stored cache JSON from localStorage. + * Performs lz-string decompression from UTF16 encoding. + * + * @param localStorageKey The key that will be used in localStorage (default='chainforge') + * @returns True if succeeded, false if failure (e.g., key not found). + */ + public static loadFromLocalStorage(localStorageKey: string='chainforge'): boolean { + const compressed = localStorage.getItem(localStorageKey); + if (!compressed) { + console.error(`Could not find cache data in localStorage with key ${localStorageKey}.`); + return false; + } + try { + let data = JSON.parse(LZString.decompressFromUTF16(compressed)); + StorageCache.getInstance().data = data; + return true; + } catch (error) { + console.error(error.message); + return false; + } + } } \ No newline at end of file diff --git a/chainforge/react-server/src/backend/query.ts b/chainforge/react-server/src/backend/query.ts index 97dce22..4a27eb0 100644 --- a/chainforge/react-server/src/backend/query.ts +++ b/chainforge/react-server/src/backend/query.ts @@ -214,7 +214,7 @@ export class PromptPipeline { * Useful for continuing if computation was interrupted halfway through. */ _load_cached_responses(): {[key: string]: LLMResponseObject} { - return StorageCache.get(this._storageKey); + return StorageCache.get(this._storageKey) || {}; } /** diff --git a/chainforge/react-server/src/backend/typing.ts b/chainforge/react-server/src/backend/typing.ts index 3b78201..7a03016 100644 --- a/chainforge/react-server/src/backend/typing.ts +++ b/chainforge/react-server/src/backend/typing.ts @@ -29,4 +29,15 @@ export interface LLMAPICall { n: number, temperature: number, params?: Dict): Promise<[Dict, Dict]> +} + +/** A standard response format expected by the front-end. */ +export interface StandardizedLLMResponse { + llm: string | Dict, + prompt: string, + responses: Array, + vars: Dict, + metavars: Dict, + tokens: Dict, + eval_res?: Dict, } \ No newline at end of file