Mirror of https://github.com/ianarawjo/ChainForge.git, synced 2025-03-15 00:36:29 +00:00

Tested queryLLM and StorageCache compressed saving/loading

commit 18885b4b89, parent 476e59365f
1  chainforge/react-server/package-lock.json (generated)

@@ -45,6 +45,7 @@
         "emoji-mart": "^5.5.2",
         "emoji-picker-react": "^4.4.9",
         "google-auth-library": "^8.8.0",
+        "lz-string": "^1.5.0",
         "mantine-contextmenu": "^1.2.15",
         "mantine-react-table": "^1.0.0-beta.8",
         "openai": "^3.3.0",
1  chainforge/react-server/package.json

@@ -40,6 +40,7 @@
     "emoji-mart": "^5.5.2",
     "emoji-picker-react": "^4.4.9",
     "google-auth-library": "^8.8.0",
+    "lz-string": "^1.5.0",
     "mantine-contextmenu": "^1.2.15",
     "mantine-react-table": "^1.0.0-beta.8",
     "openai": "^3.3.0",
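For context on the new dependency: lz-string round-trips a JSON string through a UTF-16-safe compressed form, which is what lets the StorageCache changes below persist the cache into localStorage. A minimal standalone sketch, not part of this diff:

import LZString from 'lz-string';

// Compress a JSON string into a UTF-16-safe string, then recover it.
// decompressFromUTF16 may return null/empty on malformed input, so guard it.
const json = JSON.stringify({ hello: { a: '土' }, world: 42 });
const packed = LZString.compressToUTF16(json);
const unpacked = LZString.decompressFromUTF16(packed);
if (unpacked) console.log(JSON.parse(unpacked).world);  // 42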
28  chainforge/react-server/src/backend/__test__/backend.test.ts (new file)

@@ -0,0 +1,28 @@
+/*
+ * @jest-environment node
+ */
+import { LLM } from '../models';
+import { expect, test } from '@jest/globals';
+import { queryLLM } from '../backend';
+import { StandardizedLLMResponse } from '../typing';
+
+test('call three LLMs with a single prompt', async () => {
+  // Setup params to call
+  const prompt = 'What is one major difference between French and English languages? Be brief.'
+  const llms = [LLM.OpenAI_ChatGPT, LLM.Claude_v1, LLM.PaLM2_Chat_Bison];
+  const n = 1;
+  const progress_listener = (progress: {[key: symbol]: any}) => {
+    console.log(JSON.stringify(progress));
+  };
+
+  // Call all three LLMs with the same prompt, n=1, and listen to progress
+  const {responses, errors} = await queryLLM('testid', llms, n, prompt, {}, undefined, progress_listener);
+
+  // Check responses
+  expect(responses).toHaveLength(3);
+  responses.forEach((resp_obj: StandardizedLLMResponse) => {
+    expect(resp_obj.prompt).toBe(prompt);
+    expect(resp_obj.responses).toHaveLength(1);  // since n = 1
+    expect(Object.keys(resp_obj.vars)).toHaveLength(0);
+  });
+}, 40000);
32  chainforge/react-server/src/backend/__test__/cache.test.ts (new file)

@@ -0,0 +1,32 @@
+import { expect, test } from '@jest/globals';
+import StorageCache from '../cache';
+
+test('saving and loading cache data from localStorage', () => {
+  // Store Unicode and numeric data into StorageCache
+  StorageCache.store('hello', {'a': '土', 'b': 'ہوا', 'c': '火'});
+  StorageCache.store('world', 42);
+
+  // Verify stored data:
+  let d = StorageCache.get('hello');
+  expect(d).toHaveProperty('a');
+  expect(d?.a).toBe('土');
+
+  // Save to localStorage
+  StorageCache.saveToLocalStorage('test');
+
+  // Remove all data in the cache
+  StorageCache.clear();
+
+  // Double-check there's no data:
+  d = StorageCache.get('hello');
+  expect(d).toBeUndefined();
+
+  // Load cache from localStorage
+  StorageCache.loadFromLocalStorage('test');
+
+  // Verify stored data:
+  d = StorageCache.get('hello');
+  expect(d).toHaveProperty('c');
+  expect(d?.c).toBe('火');
+  expect(StorageCache.get('world')).toBe(42);
+});
chainforge/react-server/src/backend/backend.ts

@@ -9,7 +9,7 @@
 // from chainforge.promptengine.template import PromptTemplate, PromptPermutationGenerator
 // from chainforge.promptengine.utils import LLM, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists, set_api_keys
 
-import { Dict, LLMResponseError, LLMResponseObject } from "./typing";
+import { Dict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse } from "./typing";
 import { LLM } from "./models";
 import { set_api_keys } from "./utils";
 import StorageCache from "./cache";
@@ -117,16 +117,6 @@ Object.entries(LLM).forEach(([key, value]) => {
 // md_ast_parser = mistune.create_markdown(renderer='ast')
 // return md_ast_parser(self.text)
 
-interface StandardizedLLMResponse {
-  llm: string | Dict,
-  prompt: string,
-  responses: Array<string>,
-  vars: Dict,
-  metavars: Dict,
-  tokens: Dict,
-  eval_res?: Dict,
-}
-
 function to_standard_format(r: LLMResponseObject | Dict): StandardizedLLMResponse {
   let resp_obj = {
     vars: r['info'],
@@ -555,6 +545,7 @@ export async function queryLLM(id: string,
                                prompt: string,
                                vars: Dict,
                                api_keys?: Dict,
+                               progress_listener?: (progress: {[key: symbol]: any}) => void,
                                no_cache?: boolean): Promise<Dict> {
   // Verify the integrity of the params
   if (typeof id !== 'string' || id.trim().length === 0)
@@ -582,7 +573,8 @@ export async function queryLLM(id: string,
 
   // Get the storage keys of any cache files for specific models + settings
   const llms = llm;
-  let cache = StorageCache.get(id); // returns {} if 'id' is not in the storage cache yet
+  let cache: Dict = StorageCache.get(id) || {};  // returns {} if 'id' is not in the storage cache yet
 
   let llm_to_cache_filename = {};
   let past_cache_files = {};
+  if (typeof cache === 'object' && cache.cache_files !== undefined) {
@@ -608,7 +600,7 @@ export async function queryLLM(id: string,
   // Create a new cache JSON object
   cache = { cache_files: {}, responses_last_run: [] };
   let prev_filenames: Array<string> = [];
-  llms.forEach(llm_spec => {
+  llms.forEach((llm_spec: string | Dict) => {
     const fname = gen_unique_cache_filename(id, prev_filenames);
     llm_to_cache_filename[extract_llm_key(llm_spec)] = fname;
     cache.cache_files[fname] = llm_spec;
@@ -625,10 +617,13 @@ export async function queryLLM(id: string,
   let progressProxy = new Proxy(progress, {
     set: function (target, key, value) {
       console.log(`${key.toString()} set to ${value.toString()}`);
-      // ...
-      // Call any callbacks here
-      // ...
       target[key] = value;
+
+      // If the caller provided a callback, notify it
+      // of the changes to the 'progress' object:
+      if (progress_listener)
+        progress_listener(target);
+
       return true;
     }
   });
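The Proxy above uses a 'set' trap: every write to the progress object first lands on the target, then notifies the caller's listener. A standalone sketch of the same pattern, with hypothetical names (not code from this commit):

// A Proxy whose 'set' trap forwards every property write to a listener.
type Progress = { [key: string]: number };

function makeProgressProxy(listener?: (p: Progress) => void): Progress {
  return new Proxy({} as Progress, {
    set(target, key, value) {
      target[key as string] = value;   // perform the actual write
      if (listener) listener(target);  // then notify the subscriber
      return true;                     // signal success to the runtime
    },
  });
}

// Usage: every assignment triggers the listener.
const progress = makeProgressProxy(p => console.log(JSON.stringify(p)));
progress['gpt-3.5-turbo'] = 1;  // logs {"gpt-3.5-turbo":1}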
chainforge/react-server/src/backend/cache.ts

@@ -1,4 +1,5 @@
 import { Dict } from "./typing";
+import LZString from 'lz-string';
 
 /**
  * Singleton JSON cache that functions like a local filesystem in a Python backend,
@@ -21,10 +22,10 @@ export default class StorageCache {
     return StorageCache.instance;
   }
 
-  private getCacheData(key: string): Dict {
-    return this.data[key] || {};
+  private getCacheData(key: string): Dict | undefined {
+    return this.data[key] || undefined;
   }
-  public static get(key: string): Dict {
+  public static get(key: string): Dict | undefined {
     return StorageCache.getInstance().getCacheData(key);
   }
 
@@ -34,4 +35,62 @@ export default class StorageCache {
   public static store(key: string, data: any): void {
     StorageCache.getInstance().storeCacheData(key, data);
   }
+
+  private clearCache(): void {
+    this.data = {};
+  }
+  public static clear(): void {
+    StorageCache.getInstance().clearCache();
+  }
+
+  /**
+   * Attempts to store the entire cache in localStorage.
+   * Performs lz-string compression (https://pieroxy.net/blog/pages/lz-string/index.html)
+   * before storing a JSON object in UTF encoding.
+   *
+   * Use loadFromLocalStorage to unpack the localStorage data.
+   *
+   * @param localStorageKey The key that will be used in localStorage (default='chainforge')
+   * @returns True if succeeded, false if failure (e.g., too big for localStorage).
+   */
+  public static saveToLocalStorage(localStorageKey: string='chainforge'): boolean {
+    const data = StorageCache.getInstance().data;
+    const compressed = LZString.compressToUTF16(JSON.stringify(data));
+    try {
+      localStorage.setItem(localStorageKey, compressed);
+      return true;
+    } catch (error) {
+      if (error instanceof DOMException && error.name === "QuotaExceededError") {
+        // Handle the error when storage quota is exceeded
+        console.warn("Storage quota exceeded");
+      } else {
+        // Handle other types of storage-related errors
+        console.error("Error storing data in localStorage:", error.message);
+      }
+      return false;
+    }
+  }
+
+  /**
+   * Attempts to load a previously stored cache JSON from localStorage.
+   * Performs lz-string decompression from UTF16 encoding.
+   *
+   * @param localStorageKey The key that will be used in localStorage (default='chainforge')
+   * @returns True if succeeded, false if failure (e.g., key not found).
+   */
+  public static loadFromLocalStorage(localStorageKey: string='chainforge'): boolean {
+    const compressed = localStorage.getItem(localStorageKey);
+    if (!compressed) {
+      console.error(`Could not find cache data in localStorage with key ${localStorageKey}.`);
+      return false;
+    }
+    try {
+      let data = JSON.parse(LZString.decompressFromUTF16(compressed));
+      StorageCache.getInstance().data = data;
+      return true;
+    } catch (error) {
+      console.error(error.message);
+      return false;
+    }
+  }
 }
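Note the API change above: StorageCache.get now returns undefined for a missing key instead of {}, so call sites that need an object must default explicitly, as the PromptPipeline hunk below does. The caller-side pattern, in brief:

import StorageCache from './cache';

// get() returns Dict | undefined after this commit; default to {}
// wherever an object shape is required ('some-id' is a placeholder key):
const cached = StorageCache.get('some-id') || {};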
chainforge/react-server/src/backend/query.ts

@@ -214,7 +214,7 @@ export class PromptPipeline {
    * Useful for continuing if computation was interrupted halfway through.
    */
   _load_cached_responses(): {[key: string]: LLMResponseObject} {
-    return StorageCache.get(this._storageKey);
+    return StorageCache.get(this._storageKey) || {};
   }
 
   /**
chainforge/react-server/src/backend/typing.ts

@@ -29,4 +29,15 @@ export interface LLMAPICall {
               n: number,
               temperature: number,
               params?: Dict): Promise<[Dict, Dict]>
 }
+
+/** A standard response format expected by the front-end. */
+export interface StandardizedLLMResponse {
+  llm: string | Dict,
+  prompt: string,
+  responses: Array<string>,
+  vars: Dict,
+  metavars: Dict,
+  tokens: Dict,
+  eval_res?: Dict,
+}
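For illustration, a value satisfying this interface might look like the following; the field contents are hypothetical, not taken from this commit:

import { StandardizedLLMResponse } from './typing';

// A hypothetical standardized response for one model and one prompt (n = 1).
const example: StandardizedLLMResponse = {
  llm: 'gpt-3.5-turbo',   // a model name string, or a settings Dict
  prompt: 'What is one major difference between French and English languages? Be brief.',
  responses: ['French nouns have grammatical gender; English nouns do not.'],
  vars: {},               // filled-in template variables
  metavars: {},           // metadata carried alongside the response
  tokens: {},             // token-usage info, typed loosely as Dict
  // eval_res is optional and omitted here
};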