Tested queryLLM and StorageCache compressed saving/loading

Ian Arawjo 2023-06-26 10:08:58 -04:00
parent 476e59365f
commit 18885b4b89
8 changed files with 147 additions and 20 deletions

View File

@@ -45,6 +45,7 @@
"emoji-mart": "^5.5.2",
"emoji-picker-react": "^4.4.9",
"google-auth-library": "^8.8.0",
"lz-string": "^1.5.0",
"mantine-contextmenu": "^1.2.15",
"mantine-react-table": "^1.0.0-beta.8",
"openai": "^3.3.0",

View File

@@ -40,6 +40,7 @@
"emoji-mart": "^5.5.2",
"emoji-picker-react": "^4.4.9",
"google-auth-library": "^8.8.0",
"lz-string": "^1.5.0",
"mantine-contextmenu": "^1.2.15",
"mantine-react-table": "^1.0.0-beta.8",
"openai": "^3.3.0",

View File

@@ -0,0 +1,28 @@
/*
 * @jest-environment node
 */
import { LLM } from '../models';
import { expect, test } from '@jest/globals';
import { queryLLM } from '../backend';
import { StandardizedLLMResponse } from '../typing';

test('call three LLMs with a single prompt', async () => {
  // Setup params to call
  const prompt = 'What is one major difference between French and English languages? Be brief.';
  const llms = [LLM.OpenAI_ChatGPT, LLM.Claude_v1, LLM.PaLM2_Chat_Bison];
  const n = 1;
  const progress_listener = (progress: {[key: symbol]: any}) => {
    console.log(JSON.stringify(progress));
  };

  // Call all three LLMs with the same prompt, n=1, and listen to progress
  const {responses, errors} = await queryLLM('testid', llms, n, prompt, {}, undefined, progress_listener);

  // Check responses
  expect(responses).toHaveLength(3);
  responses.forEach((resp_obj: StandardizedLLMResponse) => {
    expect(resp_obj.prompt).toBe(prompt);
    expect(resp_obj.responses).toHaveLength(1);  // since n = 1
    expect(Object.keys(resp_obj.vars)).toHaveLength(0);
  });
}, 40000);

View File

@@ -0,0 +1,32 @@
import { expect, test } from '@jest/globals';
import StorageCache from '../cache';

test('saving and loading cache data from localStorage', () => {
  // Store Unicode and numeric data into StorageCache
  StorageCache.store('hello', {'a': '土', 'b': 'ہوا', 'c': '火'});
  StorageCache.store('world', 42);

  // Verify stored data:
  let d = StorageCache.get('hello');
  expect(d).toHaveProperty('a');
  expect(d?.a).toBe('土');

  // Save to localStorage
  StorageCache.saveToLocalStorage('test');

  // Remove all data in the cache
  StorageCache.clear();

  // Double-check there's no data:
  d = StorageCache.get('hello');
  expect(d).toBeUndefined();

  // Load cache from localStorage
  StorageCache.loadFromLocalStorage('test');

  // Verify stored data:
  d = StorageCache.get('hello');
  expect(d).toHaveProperty('c');
  expect(d?.c).toBe('火');
  expect(StorageCache.get('world')).toBe(42);
});

View File

@@ -9,7 +9,7 @@
// from chainforge.promptengine.template import PromptTemplate, PromptPermutationGenerator
// from chainforge.promptengine.utils import LLM, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists, set_api_keys
import { Dict, LLMResponseError, LLMResponseObject } from "./typing";
import { Dict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse } from "./typing";
import { LLM } from "./models";
import { set_api_keys } from "./utils";
import StorageCache from "./cache";
@@ -117,16 +117,6 @@ Object.entries(LLM).forEach(([key, value]) => {
// md_ast_parser = mistune.create_markdown(renderer='ast')
// return md_ast_parser(self.text)
interface StandardizedLLMResponse {
  llm: string | Dict,
  prompt: string,
  responses: Array<string>,
  vars: Dict,
  metavars: Dict,
  tokens: Dict,
  eval_res?: Dict,
}
function to_standard_format(r: LLMResponseObject | Dict): StandardizedLLMResponse {
  let resp_obj = {
    vars: r['info'],
@@ -555,6 +545,7 @@ export async function queryLLM(id: string,
                               prompt: string,
                               vars: Dict,
                               api_keys?: Dict,
                               progress_listener?: (progress: {[key: symbol]: any}) => void,
                               no_cache?: boolean): Promise<Dict> {
  // Verify the integrity of the params
  if (typeof id !== 'string' || id.trim().length === 0)
@@ -582,7 +573,8 @@ export async function queryLLM(id: string,
  // Get the storage keys of any cache files for specific models + settings
  const llms = llm;
  let cache = StorageCache.get(id);  // returns {} if 'id' is not in the storage cache yet
  let cache: Dict = StorageCache.get(id) || {};  // returns {} if 'id' is not in the storage cache yet
  let llm_to_cache_filename = {};
  let past_cache_files = {};
  if (typeof cache === 'object' && cache.cache_files !== undefined) {
@@ -608,7 +600,7 @@ export async function queryLLM(id: string,
  // Create a new cache JSON object
  cache = { cache_files: {}, responses_last_run: [] };
  let prev_filenames: Array<string> = [];
  llms.forEach(llm_spec => {
  llms.forEach((llm_spec: string | Dict) => {
    const fname = gen_unique_cache_filename(id, prev_filenames);
    llm_to_cache_filename[extract_llm_key(llm_spec)] = fname;
    cache.cache_files[fname] = llm_spec;
@@ -625,10 +617,13 @@ export async function queryLLM(id: string,
  let progressProxy = new Proxy(progress, {
    set: function (target, key, value) {
      console.log(`${key.toString()} set to ${value.toString()}`);
      // ...
      // Call any callbacks here
      // ...
      target[key] = value;
      // If the caller provided a callback, notify it
      // of the changes to the 'progress' object:
      if (progress_listener)
        progress_listener(target);
      return true;
    }
  });
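For context, the hunk above wires an optional callback into a JavaScript Proxy set trap: every write lands on the target object first, then the listener, if provided, receives the full updated state. Below is a minimal, self-contained sketch of that pattern; the names makeProgressProxy and ProgressListener are illustrative and not part of the ChainForge API.

type ProgressListener = (progress: { [key: string]: any }) => void;

function makeProgressProxy(listener?: ProgressListener): { [key: string]: any } {
  const target: { [key: string]: any } = {};
  return new Proxy(target, {
    set: (obj, key, value) => {
      obj[key as string] = value;   // apply the write to the underlying object
      if (listener) listener(obj);  // then notify the caller, if one was provided
      return true;                  // tell the runtime the assignment succeeded
    },
  });
}

// Usage: each assignment surfaces immediately in the listener.
const progress = makeProgressProxy((p) => console.log(JSON.stringify(p)));
progress['model-A'] = { success: 1, error: 0 };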

View File

@@ -1,4 +1,5 @@
import { Dict } from "./typing";
import LZString from 'lz-string';
/**
* Singleton JSON cache that functions like a local filesystem in a Python backend,
@@ -21,10 +22,10 @@ export default class StorageCache {
    return StorageCache.instance;
  }

  private getCacheData(key: string): Dict {
    return this.data[key] || {};
  private getCacheData(key: string): Dict | undefined {
    return this.data[key] || undefined;
  }

  public static get(key: string): Dict {
  public static get(key: string): Dict | undefined {
    return StorageCache.getInstance().getCacheData(key);
  }
@@ -34,4 +35,62 @@ export default class StorageCache {
  public static store(key: string, data: any): void {
    StorageCache.getInstance().storeCacheData(key, data);
  }

  private clearCache(): void {
    this.data = {};
  }

  public static clear(): void {
    StorageCache.getInstance().clearCache();
  }

  /**
   * Attempts to store the entire cache in localStorage.
   * Performs lz-string compression (https://pieroxy.net/blog/pages/lz-string/index.html)
   * before storing the JSON object in UTF-16 encoding.
   *
   * Use loadFromLocalStorage to unpack the localStorage data.
   *
   * @param localStorageKey The key that will be used in localStorage (default='chainforge')
   * @returns True on success; false on failure (e.g., the data is too big for localStorage).
   */
  public static saveToLocalStorage(localStorageKey: string='chainforge'): boolean {
    const data = StorageCache.getInstance().data;
    const compressed = LZString.compressToUTF16(JSON.stringify(data));
    try {
      localStorage.setItem(localStorageKey, compressed);
      return true;
    } catch (error) {
      if (error instanceof DOMException && error.name === "QuotaExceededError") {
        // Handle the error when the storage quota is exceeded
        console.warn("Storage quota exceeded");
      } else {
        // Handle other types of storage-related errors
        console.error("Error storing data in localStorage:", error.message);
      }
      return false;
    }
  }

  /**
   * Attempts to load a previously stored cache JSON from localStorage.
   * Performs lz-string decompression from UTF-16 encoding.
   *
   * @param localStorageKey The key that will be used in localStorage (default='chainforge')
   * @returns True on success; false on failure (e.g., the key was not found).
   */
  public static loadFromLocalStorage(localStorageKey: string='chainforge'): boolean {
    const compressed = localStorage.getItem(localStorageKey);
    if (!compressed) {
      console.error(`Could not find cache data in localStorage with key ${localStorageKey}.`);
      return false;
    }
    try {
      let data = JSON.parse(LZString.decompressFromUTF16(compressed));
      StorageCache.getInstance().data = data;
      return true;
    } catch (error) {
      console.error(error.message);
      return false;
    }
  }
}
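To make the compression step concrete, here is a minimal round-trip sketch using only the two lz-string calls the methods above rely on, compressToUTF16 and decompressFromUTF16; the variable names are illustrative.

import LZString from 'lz-string';

// Pack: serialize to JSON, then compress into a UTF-16 string,
// which is safe to place in localStorage (it stores strings as UTF-16).
const original = { greeting: '土', answer: 42 };
const packed = LZString.compressToUTF16(JSON.stringify(original));

// Unpack: decompress (may return null on invalid input), then parse.
const json = LZString.decompressFromUTF16(packed);
const restored = json ? JSON.parse(json) : null;

console.log(restored);  // { greeting: '土', answer: 42 }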

View File

@@ -214,7 +214,7 @@ export class PromptPipeline {
   * Useful for continuing if computation was interrupted halfway through.
   */
  _load_cached_responses(): {[key: string]: LLMResponseObject} {
    return StorageCache.get(this._storageKey);
    return StorageCache.get(this._storageKey) || {};
  }

  /**

View File

@@ -29,4 +29,15 @@ export interface LLMAPICall {
   n: number,
   temperature: number,
   params?: Dict): Promise<[Dict, Dict]>
}

/** A standard response format expected by the front-end. */
export interface StandardizedLLMResponse {
  llm: string | Dict,
  prompt: string,
  responses: Array<string>,
  vars: Dict,
  metavars: Dict,
  tokens: Dict,
  eval_res?: Dict,
}
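For illustration only, here is a hypothetical object that satisfies StandardizedLLMResponse, showing how the fields fit together; the values are made up, not real model output.

import { StandardizedLLMResponse } from './typing';

const example: StandardizedLLMResponse = {
  llm: 'gpt-3.5-turbo',   // a model key string, or a settings Dict
  prompt: 'What is one major difference between French and English? Be brief.',
  responses: ['French nouns have grammatical gender; English nouns generally do not.'],
  vars: {},               // prompt template variables (empty in this example)
  metavars: {},           // associated metadata variables (empty in this example)
  tokens: { total: 0 },   // token-usage Dict; contents are illustrative
  // eval_res is optional and omitted here
};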