Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-15 00:36:29 +00:00)

jest tests for query, utils, template.ts. Confirmed PromptPipeline works.

This commit is contained in:
parent dc02e8f44a
commit d61bd922ca

chainforge/react-server/package-lock.json (generated, 1950)
File diff suppressed because it is too large.
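For reference, the PromptPipeline API that these tests exercise can be driven roughly as follows. This is a minimal sketch based only on the signatures visible in the diffs below (a constructor taking a template and a storage key, and gen_responses(vars, llm, n, temperature) yielding LLMResponseObject | LLMResponseError). The import paths and the 'demo' storage key are illustrative assumptions, and an OpenAI API key must be set in the environment, as the tests themselves assume.

// Minimal usage sketch (assumed paths; mirrors the new query.test.ts below)
import { PromptPipeline } from './chainforge/react-server/src/backend/query';
import { LLM } from './chainforge/react-server/src/backend/models';
import { LLMResponseError } from './chainforge/react-server/src/backend/typing';

async function demo(): Promise<void> {
  // One template variable, two values => two prompts; responses are cached under the storage key.
  const pipeline = new PromptPipeline('What is the oldest {thing} in the world?', 'demo');
  for await (const resp of pipeline.gen_responses({ thing: ['tree', 'book'] }, LLM.OpenAI_ChatGPT, 1, 1.0)) {
    if (resp instanceof LLMResponseError) console.error('LLM error:', resp);
    else console.log(resp.responses[0]);
  }
}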
chainforge/react-server/package.json
@@ -9,6 +9,7 @@
    "@codemirror/lang-python": "^6.1.2",
    "@emoji-mart/data": "^1.1.2",
    "@emoji-mart/react": "^1.1.1",
    "@google-ai/generativelanguage": "^0.2.0",
    "@mantine/core": "^6.0.9",
    "@mantine/dates": "^6.0.13",
    "@mantine/form": "^6.0.11",
@@ -38,6 +39,7 @@
    "dayjs": "^1.11.8",
    "emoji-mart": "^5.5.2",
    "emoji-picker-react": "^4.4.9",
    "google-auth-library": "^8.8.0",
    "mantine-contextmenu": "^1.2.15",
    "mantine-react-table": "^1.0.0-beta.8",
    "openai": "^3.3.0",
@@ -82,5 +84,8 @@
      "last 1 firefox version",
      "last 1 safari version"
    ]
  },
  "devDependencies": {
    "jest": "^27.5.1"
  }
}
chainforge/react-server/src/App.test.js
@@ -1,8 +0,0 @@
import { render, screen } from '@testing-library/react';
import App from './App';

test('renders learn react link', () => {
  render(<App />);
  const linkElement = screen.getByText(/learn react/i);
  expect(linkElement).toBeInTheDocument();
});
chainforge/react-server/src/backend/__test__/query.test.ts (new file, 74 lines)
@@ -0,0 +1,74 @@
/*
 * @jest-environment node
 */
import { PromptPipeline } from '../query';
import { LLM } from '../models';
import { expect, test } from '@jest/globals';
import { LLMResponseError, LLMResponseObject } from '../typing';

async function prompt_model(model: LLM): Promise<void> {
  const pipeline = new PromptPipeline('What is the oldest {thing} in the world? Keep your answer brief.', model.toString());
  let responses: Array<LLMResponseObject | LLMResponseError> = [];
  for await (const response of pipeline.gen_responses({thing: ['bar', 'tree', 'book']}, model, 1, 1.0)) {
    responses.push(response);
  }
  expect(responses).toHaveLength(3);

  // Double-check the cache'd results
  let cache = pipeline._load_cached_responses();
  Object.entries(cache).forEach(([prompt, response]) => {
    console.log(`Prompt: ${prompt}\nResponse: ${response.responses[0]}`);
  });
  expect(Object.keys(cache)).toHaveLength(3); // expect 3 prompts

  // Now query ChatGPT again, but set n=2 to force it to send off 1 query per prompt.
  responses = [];
  for await (const response of pipeline.gen_responses({thing: ['bar', 'tree', 'book']}, model, 2, 1.0)) {
    responses.push(response);
  }
  expect(responses).toHaveLength(3); // still 3
  responses.forEach(resp_obj => {
    if (resp_obj instanceof LLMResponseError) return;
    expect(resp_obj.responses).toHaveLength(2); // each response object should have 2 candidates, as n=2
  });

  // Double-check the cache'd results
  cache = pipeline._load_cached_responses();
  Object.entries(cache).forEach(([prompt, resp_obj]) => {
    console.log(`Prompt: ${prompt}\nResponses: ${JSON.stringify(resp_obj.responses)}`);
    expect(resp_obj.responses).toHaveLength(2);
    expect(resp_obj.raw_response).toHaveLength(2); // these should've been merged
  });
  expect(Object.keys(cache)).toHaveLength(3); // still expect 3 prompts

  // Now send off the exact same query. It should use only the cache'd results:
  responses = [];
  for await (const response of pipeline.gen_responses({thing: ['bar', 'tree', 'book']}, model, 2, 1.0)) {
    responses.push(response);
  }
  expect(responses).toHaveLength(3); // still 3
  responses.forEach(resp_obj => {
    if (resp_obj instanceof LLMResponseError) return;
    expect(resp_obj.responses).toHaveLength(2); // each response object should have 2 candidates, as n=2
  });

  cache = pipeline._load_cached_responses();
  Object.entries(cache).forEach(([prompt, resp_obj]) => {
    expect(resp_obj.responses).toHaveLength(2);
    expect(resp_obj.raw_response).toHaveLength(2); // these should've been merged
  });
  expect(Object.keys(cache)).toHaveLength(3); // still expect 3 prompts
}

test('basic prompt pipeline with chatgpt', async () => {
  // Setup a simple pipeline with a prompt template, 1 variable and 3 input values
  await prompt_model(LLM.OpenAI_ChatGPT);
}, 20000);

test('basic prompt pipeline with anthropic', async () => {
  await prompt_model(LLM.Claude_v1);
}, 40000);

test('basic prompt pipeline with google palm2', async () => {
  await prompt_model(LLM.PaLM2_Chat_Bison);
}, 40000);
chainforge/react-server/src/backend/__test__/template.test.ts (new file, 96 lines)
@@ -0,0 +1,96 @@
import { StringTemplate, PromptTemplate, PromptPermutationGenerator } from '../template';
import { expect, test } from '@jest/globals';

test('string template', () => {
  // Test regular string template
  const st = new StringTemplate('{pronoun} favorite {thing} is...');
  expect(st.has_vars()).toBe(true);
  expect(st.has_vars(['thing'])).toBe(true);
  expect(st.has_vars(['pronoun'])).toBe(true);
  expect(st.has_vars(['tacos'])).toBe(false);
  expect(st.safe_substitute({thing: 'food'})).toBe('{pronoun} favorite food is...');
  expect(st.safe_substitute({pronoun: 'My'})).toBe('My favorite {thing} is...');
  expect(st.safe_substitute({pronoun: 'My', thing: 'food'})).toBe('My favorite food is...');
  expect(st.safe_substitute({meat: 'chorizo'})).toBe('{pronoun} favorite {thing} is...');
  expect(new StringTemplate(st.safe_substitute({thing: 'programming language'})).has_vars()).toBe(true);
});

test('string template escaped group', () => {
  const st = new StringTemplate('{pronoun} favorite \\{thing\\} is...');
  expect(st.has_vars(['thing'])).toBe(false);
  expect(st.has_vars(['pronoun'])).toBe(true);
  expect(st.safe_substitute({thing: 'food'})).toBe('{pronoun} favorite \\{thing\\} is...'); // no substitution
  expect(st.safe_substitute({pronoun: 'Our'})).toBe('Our favorite \\{thing\\} is...');
});

test('single template', () => {
  let prompt_gen = new PromptPermutationGenerator('What is the {timeframe} when {person} was born?');
  let vars: {[key: string]: any} = {
    'timeframe': ['year', 'decade', 'century'],
    'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding']
  };
  let num_prompts = 0;
  for (const prompt of prompt_gen.generate(vars)) {
    // console.log(prompt.toString());
    expect(prompt.fill_history).toHaveProperty('timeframe');
    expect(prompt.fill_history).toHaveProperty('person');
    num_prompts += 1;
  }
  expect(num_prompts).toBe(9);
});

test('nested templates', () => {
  let prompt_gen = new PromptPermutationGenerator('{prefix}... {suffix}');
  let vars = {
    'prefix': ['Who invented {tool}?', 'When was {tool} invented?', 'What can you do with {tool}?'],
    'suffix': ['Phrase your answer in the form of a {response_type}', 'Respond with a {response_type}'],
    'tool': ['the flashlight', 'CRISPR', 'rubber'],
    'response_type': ['question', 'poem', 'nightmare']
  };
  let num_prompts = 0;
  for (const prompt of prompt_gen.generate(vars)) {
    // console.log(prompt.toString());
    expect(prompt.fill_history).toHaveProperty('prefix');
    expect(prompt.fill_history).toHaveProperty('suffix');
    expect(prompt.fill_history).toHaveProperty('tool');
    expect(prompt.fill_history).toHaveProperty('response_type');
    num_prompts += 1;
  }
  expect(num_prompts).toBe((3*3)*(2*3));
});

test('carry together vars', () => {
  // 'Carry together' vars with 'metavar' data attached
  // NOTE: This feature may be used when passing rows of a table, so that vars that have associated values,
  // like 'inventor' with 'tool', 'carry together' when being filled into the prompt template.
  // In addition, 'metavars' may be attached which are, commonly, the values of other columns for that row, but
  // columns which weren't used to fill in the prompt template explicitly.
  let prompt_gen = new PromptPermutationGenerator('What {timeframe} did {inventor} invent the {tool}?');
  let vars = {
    'inventor': [
      {'text': "Thomas Edison", "fill_history": {}, "associate_id": "A", "metavars": { "year": 1879 }},
      {'text': "Alexander Fleming", "fill_history": {}, "associate_id": "B", "metavars": { "year": 1928 }},
      {'text': "William Shockley", "fill_history": {}, "associate_id": "C", "metavars": { "year": 1947 }},
    ],
    'tool': [
      {'text': "lightbulb", "fill_history": {}, "associate_id": "A"},
      {'text': "penicillin", "fill_history": {}, "associate_id": "B"},
      {'text': "transistor", "fill_history": {}, "associate_id": "C"},
    ],
    'timeframe': [ "year", "decade", "century" ]
  };
  let num_prompts = 0;
  for (const prompt of prompt_gen.generate(vars)) {
    const prompt_str = prompt.toString();
    // console.log(prompt_str, prompt.metavars)
    expect(prompt.metavars).toHaveProperty('year');
    if (prompt_str.includes('Edison'))
      expect(prompt_str.includes('lightbulb')).toBe(true);
    else if (prompt_str.includes('Fleming'))
      expect(prompt_str.includes('penicillin')).toBe(true);
    else if (prompt_str.includes('Shockley'))
      expect(prompt_str.includes('transistor')).toBe(true);
    num_prompts += 1;
  }
  expect(num_prompts).toBe(3*3);
});
chainforge/react-server/src/backend/__test__/utils.test.ts (new file, 116 lines)
@@ -0,0 +1,116 @@
/*
 * @jest-environment node
 */
import { call_anthropic, call_chatgpt, call_google_palm, extract_responses, merge_response_objs } from '../utils';
import { LLM } from '../models';
import { expect, test } from '@jest/globals';
import { LLMResponseObject } from '../typing';

test('merge response objects', () => {
  // Merging two response objects
  const A: LLMResponseObject = {
    responses: ['x', 'y', 'z'],
    raw_response: ['x', 'y', 'z'],
    prompt: 'this is a test',
    query: {},
    llm: LLM.OpenAI_ChatGPT,
    info: { var1: 'value1', var2: 'value2' },
    metavars: { meta1: 'meta1' },
  };
  const B: LLMResponseObject = {
    responses: ['a', 'b', 'c'],
    raw_response: {B: 'B'},
    prompt: 'this is a test 2',
    query: {},
    llm: LLM.OpenAI_ChatGPT,
    info: { varB1: 'valueB1', varB2: 'valueB2' },
    metavars: { metaB1: 'metaB1' },
  };
  const C = merge_response_objs(A, B) as LLMResponseObject;
  expect(C.responses).toHaveLength(6);
  expect(JSON.stringify(C.responses)).toBe(JSON.stringify(['x', 'y', 'z', 'a', 'b', 'c']));
  expect(C.raw_response).toHaveLength(4);
  expect(Object.keys(C.info)).toHaveLength(2);
  expect(Object.keys(C.info)).toContain('varB1');
  expect(Object.keys(C.metavars)).toHaveLength(1);
  expect(Object.keys(C.metavars)).toContain('metaB1');

  // Merging one empty object should return the non-empty object
  expect(merge_response_objs(A, undefined)).toBe(A);
  expect(merge_response_objs(undefined, B)).toBe(B);
})

test('UNCOMMENT BELOW API CALL TESTS WHEN READY', () => {
  // NOTE: API CALL TESTS ASSUME YOUR ENVIRONMENT VARIABLE IS SET!
});

test('openai chat completions', async () => {
  // Call ChatGPT with a basic question, and n=2
  const [query, response] = await call_chatgpt("Who invented modern playing cards? Keep your answer brief.", LLM.OpenAI_ChatGPT, 2, 1.0);
  console.log(response.choices[0].message);
  expect(response.choices).toHaveLength(2);
  expect(query).toHaveProperty('temperature');

  // Extract responses, check their type
  const resps = extract_responses(response, LLM.OpenAI_ChatGPT);
  expect(resps).toHaveLength(2);
  expect(typeof resps[0]).toBe('string');
}, 20000);

test('openai text completions', async () => {
  // Call an OpenAI text completion model with a basic question, and n=2
  const [query, response] = await call_chatgpt("Who invented modern playing cards? The answer is ", LLM.OpenAI_Davinci003, 2, 1.0);
  console.log(response.choices[0].text);
  expect(response.choices).toHaveLength(2);
  expect(query).toHaveProperty('n');

  // Extract responses, check their type
  const resps = extract_responses(response, LLM.OpenAI_Davinci003);
  expect(resps).toHaveLength(2);
  expect(typeof resps[0]).toBe('string');
}, 20000);

test('anthropic models', async () => {
  // Call Anthropic's Claude with a basic question
  const [query, response] = await call_anthropic("Who invented modern playing cards?", LLM.Claude_v1, 1, 1.0);
  console.log(response);
  expect(response).toHaveLength(1);
  expect(query).toHaveProperty('max_tokens_to_sample');

  // Extract responses, check their type
  const resps = extract_responses(response, LLM.Claude_v1);
  expect(resps).toHaveLength(1);
  expect(typeof resps[0]).toBe('string');
}, 20000);

test('google palm2 models', async () => {
  // Call Google's PaLM Chat API with a basic question
  let [query, response] = await call_google_palm("Who invented modern playing cards?", LLM.PaLM2_Chat_Bison, 3, 0.7);
  expect(response.candidates).toHaveLength(3);
  expect(query).toHaveProperty('candidateCount');

  // Extract responses, check their type
  let resps = extract_responses(response, LLM.PaLM2_Chat_Bison);
  expect(resps).toHaveLength(3);
  expect(typeof resps[0]).toBe('string');
  console.log(JSON.stringify(resps));

  // Call Google's PaLM Text Completions API with a basic question
  [query, response] = await call_google_palm("Who invented modern playing cards? The answer ", LLM.PaLM2_Text_Bison, 3, 0.7);
  expect(response.candidates).toHaveLength(3);
  expect(query).toHaveProperty('maxOutputTokens');

  // Extract responses, check their type
  resps = extract_responses(response, LLM.PaLM2_Chat_Bison);
  expect(resps).toHaveLength(3);
  expect(typeof resps[0]).toBe('string');
  console.log(JSON.stringify(resps));
}, 40000);

// test('call_', async () => {
//   // Call Anthropic's Claude with a basic question
//   const [query, response] = await call_anthropic("Who invented modern playing cards? Keep your answer brief.", LLM.Claude_v1, 1, 1.0);
//   console.log(response);
//   expect(response).toHaveLength(1);
//   expect(query).toHaveProperty('max_tokens_to_sample');
// }, 20000);
chainforge/react-server/src/backend/query.ts
@@ -8,13 +8,13 @@
import { PromptTemplate, PromptPermutationGenerator } from "./template";
import { LLM, RATE_LIMITS } from './models';
import { Dict, LLMResponseError, LLMResponseObject, LLMAPICall } from "./typing";
import { call_chatgpt, call_anthropic, call_google_palm, call_azure_openai, call_dalai, getEnumName } from "./utils";
import { call_chatgpt, call_anthropic, call_google_palm, call_azure_openai, call_dalai, getEnumName, extract_responses, merge_response_objs } from "./utils";

interface _IntermediateLLMResponseType {
  prompt: PromptTemplate | string,
  query?: Dict,
  response?: Dict | LLMResponseError,
  past_resp_obj?: Dict,
  past_resp_obj?: LLMResponseObject | undefined,
}

/**
@@ -76,9 +76,9 @@ function sleep(ms: number): Promise<void> {
/**
 * Abstract class that captures a generic querying interface to prompt LLMs
 */
class PromptPipeline {
  _storageKey: string
  _template: string
export class PromptPipeline {
  private _storageKey: string;
  private _template: string;

  constructor(template: string, storageKey: string) {
    this._template = template;
@@ -110,13 +110,13 @@ class PromptPipeline {
    llm: LLM,
    n: number = 1,
    temperature: number = 1.0,
    llm_params: Dict): AsyncGenerator<LLMResponseObject | LLMResponseError, boolean, undefined> {
    llm_params?: Dict): AsyncGenerator<LLMResponseObject | LLMResponseError, boolean, undefined> {
    // Double-check that properties is the correct type (JSON dict):
    // Load any cache'd responses
    let responses = this._load_cached_responses();

    // Query LLM with each prompt, yield + cache the responses
    let tasks: Array<Promise<_IntermediateLLMResponseObject>> = [];
    let tasks: Array<Promise<_IntermediateLLMResponseType>> = [];
    const rate_limit = RATE_LIMITS[llm] || [1, 0];
    let [max_req, wait_secs] = rate_limit ? rate_limit : [1, 0];
    let num_queries_sent = -1;
@@ -129,7 +129,7 @@ class PromptPipeline {
    let info = prompt.fill_history;
    let metavars = prompt.metavars;

    let cached_resp = prompt_str in responses ? responses[prompt_str] : null;
    let cached_resp = prompt_str in responses ? responses[prompt_str] : undefined;
    let extracted_resps: Array<any> = cached_resp ? cached_resp["responses"] : [];

    // First check if there is already a response for this item under these settings. If so, we can save an LLM call:
@@ -154,23 +154,18 @@ class PromptPipeline {
    if (max_req > 1) {
      // Call the LLM asynchronously to generate a response, sending off
      // requests in batches of size 'max_req' separated by seconds 'wait_secs' to avoid hitting rate limit
      tasks.push(this._prompt_llm(llm=llm,
                                  prompt=prompt,
                                  n=n,
                                  temperature=temperature,
                                  past_resp_obj=cached_resp,
                                  query_number=num_queries_sent,
                                  rate_limit_batch_size=max_req,
                                  rate_limit_wait_secs=wait_secs,
                                  **llm_params));
      tasks.push(this._prompt_llm(llm, prompt, n, temperature,
                                  cached_resp,
                                  num_queries_sent,
                                  max_req,
                                  wait_secs,
                                  llm_params));
    } else {
      // Block. Await + yield a single LLM call.
      let {_, query, response, past_resp_obj} = await this._prompt_llm(llm=llm,
                                                                       prompt=prompt,
                                                                       n=n,
                                                                       temperature=temperature,
                                                                       past_resp_obj=cached_resp,
                                                                       **llm_params);
      let result = await this._prompt_llm(llm, prompt, n, temperature, cached_resp,
                                          undefined, undefined, undefined,
                                          llm_params);
      let { query, response, past_resp_obj } = result;

      // Check for selective failure
      if (!query && response instanceof LLMResponseError) {
@@ -178,20 +173,24 @@
        continue;
      }

      // We now know there was a response; type it correctly:
      query = query as Dict;
      response = response as Dict;

      // Create a response obj to represent the response
      let resp_obj: LLMResponseObject = {
        "prompt": prompt.toString(),
        "query": query,
        "responses": extract_responses(response, llm),
        "raw_response": response,
        "llm": llm,
        "info": info,
        "metavars": metavars,
        prompt: prompt.toString(),
        query: query,
        responses: extract_responses(response, llm),
        raw_response: response,
        llm: llm,
        info: info,
        metavars: metavars,
      }

      // Merge the response obj with the past one, if necessary
      if (past_resp_obj)
        resp_obj = merge_response_objs(resp_obj, past_resp_obj);
        resp_obj = merge_response_objs(resp_obj, past_resp_obj) as LLMResponseObject;

      // Save the current state of cache'd responses to a JSON file
      responses[resp_obj["prompt"]] = resp_obj;
@@ -230,7 +229,7 @@

      // Merge the response obj with the past one, if necessary
      if (past_resp_obj)
        resp_obj = merge_response_objs(resp_obj, past_resp_obj);
        resp_obj = merge_response_objs(resp_obj, past_resp_obj) as LLMResponseObject;

      // Save the current state of cache'd responses to a JSON file
      // NOTE: We do this to save money --in case something breaks between calls, can ensure we got the data!
@@ -249,7 +248,7 @@ class PromptPipeline {
   * Loads cache'd responses of JSON.
   * Useful for continuing if computation was interrupted halfway through.
   */
  _load_cached_responses(): Dict {
  _load_cached_responses(): {[key: string]: LLMResponseObject} {
    return StorageCache.get(this._storageKey);
  }

@@ -265,7 +264,7 @@ class PromptPipeline {
    prompt: PromptTemplate,
    n: number = 1,
    temperature: number = 1.0,
    past_resp_obj?: Dict,
    past_resp_obj?: LLMResponseObject,
    query_number?: number,
    rate_limit_batch_size?: number,
    rate_limit_wait_secs?: number,
@@ -303,7 +302,6 @@ class PromptPipeline {
      call_llm = call_dalai;
    else if (llm.toString().startsWith('claude'))
      call_llm = call_anthropic;

    if (!call_llm)
      throw new LLMResponseError(`Language model ${llm} is not supported.`);

@@ -313,19 +311,17 @@ class PromptPipeline {
    let response: Dict | LLMResponseError;
    try {
      [query, response] = await call_llm(prompt.toString(), llm, n=n, temperature, llm_params);
    } catch(e: Error) {
    } catch(err) {
      return { prompt: prompt,
               query: undefined,
               response: new LLMResponseError(e.toString()),
               response: new LLMResponseError(err.toString()),
               past_resp_obj: undefined };
    }

    return {
      prompt,
      query,
      response,
      past_resp_obj
    };
    return { prompt,
             query,
             response,
             past_resp_obj };
  }
}
chainforge/react-server/src/backend/template.ts
@@ -97,6 +97,10 @@ export class StringTemplate {
    }
    return false;
  }

  toString(): string {
    return this.val;
  }
}

export class PromptTemplate {
@@ -335,79 +339,3 @@ export class PromptPermutationGenerator {
    return true; // done
  }
}

function assert(condition: boolean, message?: string) {
  if (!condition) {
    throw new Error(message || "Assertion failed");
  }
}

/**
 * Run test cases on `PromptPermutationGenerator`.
 */
function _test() {
  // Single template
  let prompt_gen = new PromptPermutationGenerator('What is the {timeframe} when {person} was born?');
  let vars: {[key: string]: any} = {
    'timeframe': ['year', 'decade', 'century'],
    'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding']
  };
  let num_prompts = 0;
  for (const prompt of prompt_gen.generate(vars)) {
    console.log(prompt.toString());
    num_prompts += 1;
  }
  assert(num_prompts === 9);

  // Nested templates
  prompt_gen = new PromptPermutationGenerator('{prefix}... {suffix}');
  vars = {
    'prefix': ['Who invented {tool}?', 'When was {tool} invented?', 'What can you do with {tool}?'],
    'suffix': ['Phrase your answer in the form of a {response_type}', 'Respond with a {response_type}'],
    'tool': ['the flashlight', 'CRISPR', 'rubber'],
    'response_type': ['question', 'poem', 'nightmare']
  };
  num_prompts = 0;
  for (const prompt of prompt_gen.generate(vars)) {
    console.log(prompt.toString());
    num_prompts += 1;
  }
  assert(num_prompts === (3*3)*(2*3));

  // # 'Carry together' vars with 'metavar' data attached
  // # NOTE: This feature may be used when passing rows of a table, so that vars that have associated values,
  // # like 'inventor' with 'tool', 'carry together' when being filled into the prompt template.
  // # In addition, 'metavars' may be attached which are, commonly, the values of other columns for that row, but
  // # columns which weren't used to fill in the prompt template explcitly.
  prompt_gen = new PromptPermutationGenerator('What {timeframe} did {inventor} invent the {tool}?')
  vars = {
    'inventor': [
      {'text': "Thomas Edison", "fill_history": {}, "associate_id": "A", "metavars": { "year": 1879 }},
      {'text': "Alexander Fleming", "fill_history": {}, "associate_id": "B", "metavars": { "year": 1928 }},
      {'text': "William Shockley", "fill_history": {}, "associate_id": "C", "metavars": { "year": 1947 }},
    ],
    'tool': [
      {'text': "lightbulb", "fill_history": {}, "associate_id": "A"},
      {'text': "penicillin", "fill_history": {}, "associate_id": "B"},
      {'text': "transistor", "fill_history": {}, "associate_id": "C"},
    ],
    'timeframe': [ "year", "decade", "century" ]
  };
  num_prompts = 0;
  for (const prompt of prompt_gen.generate(vars)) {
    const prompt_str = prompt.toString();
    console.log(prompt_str, prompt.metavars)
    assert("year" in prompt.metavars);
    if (prompt_str.includes('Edison'))
      assert(prompt_str.includes('lightbulb'));
    else if (prompt_str.includes('Fleming'))
      assert(prompt_str.includes('penicillin'));
    else if (prompt_str.includes('Shockley'))
      assert(prompt_str.includes('transistor'));
    num_prompts += 1;
  }
  assert(num_prompts === 3*3);
}

// Uncomment and run 'ts-node template.ts' to test:
// _test();
chainforge/react-server/src/backend/utils.ts
@@ -4,13 +4,16 @@

// from chainforge.promptengine.models import LLM
import { LLM } from './models';
import { Dict, StringDict, LLMAPICall } from './typing';
import { Dict, StringDict, LLMAPICall, LLMResponseObject } from './typing';
import { env as process_env } from 'process';
import { StringTemplate } from './template';

/* LLM API SDKs */
import { Configuration as OpenAIConfig, OpenAIApi } from "openai";
import { OpenAIClient as AzureOpenAIClient, AzureKeyCredential } from "@azure/openai";
import { AI_PROMPT, Client as AnthropicClient, HUMAN_PROMPT } from "@anthropic-ai/sdk";
import { DiscussServiceClient, TextServiceClient } from "@google-ai/generativelanguage";
import { GoogleAuth } from "google-auth-library";

function get_environ(key: string): string | undefined {
  if (key in process_env)
@@ -48,7 +51,7 @@ export function set_api_keys(api_keys: StringDict): void {
}

/** Equivalent to a Python enum's .name property */
function getEnumName(enumObject: any, enumValue: any): string | undefined {
export function getEnumName(enumObject: any, enumValue: any): string | undefined {
  for (const key in enumObject) {
    if (enumObject[key] === enumValue) {
      return key;
@@ -129,17 +132,6 @@ export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, te
  return [query, response];
}

function _test() {
  // Call ChatGPT
  call_chatgpt("Who invented modern playing cards?",
               LLM.OpenAI_ChatGPT,
               1, 1.0).then(([query, response]) => {
    console.log(response.choices[0].message);
  });
}

_test();

/**
 * Calls OpenAI models hosted on Microsoft Azure services.
 * Returns raw query and response JSON dicts.
@@ -247,7 +239,7 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
    ...params,
  };

  console.log(`Calling Anthropic model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`)
  console.log(`Calling Anthropic model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`);

  // Repeat call n times, waiting for each response to come in:
  let responses: Array<Dict> = [];
@@ -260,67 +252,94 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
  return [query, responses];
}

// async def call_google_palm(prompt: str, model: LLM, n: int = 1, temperature: float= 0.7,
// max_output_tokens=800,
// async_mode=False,
// **params) -> Tuple[Dict, Dict]:
// """
// Calls a Google PaLM model.
// Returns raw query and response JSON dicts.
// """
// if GOOGLE_PALM_API_KEY is None:
// raise Exception("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local Python environment.")
/**
 * Calls a Google PaLM model.
   Returns raw query and response JSON dicts.
 */
export async function call_google_palm(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
  if (!GOOGLE_PALM_API_KEY)
    throw Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment.");

  // import google.generativeai as palm
  // palm.configure(api_key=GOOGLE_PALM_API_KEY)
  const is_chat_model = model.toString().includes('chat');
  const client = new (is_chat_model ? DiscussServiceClient : TextServiceClient)({
    authClient: new GoogleAuth().fromAPIKey(GOOGLE_PALM_API_KEY),
  });

  // is_chat_model = 'chat' in model.value
  // Required non-standard params
  const max_output_tokens = params?.max_output_tokens || 800;

  // query = {
  // 'model': f"models/{model.value}",
  // 'prompt': prompt,
  // 'candidate_count': n,
  // 'temperature': temperature,
  // 'max_output_tokens': max_output_tokens,
  // **params,
  // }
  let query: Dict = {
    model: `models/${model}`,
    candidate_count: n,
    temperature: temperature,
    max_output_tokens: max_output_tokens,
    ...params,
  };

  // # Remove erroneous parameters for text and chat models
  // if 'top_k' in query and query['top_k'] <= 0:
  // del query['top_k']
  // if 'top_p' in query and query['top_p'] <= 0:
  // del query['top_p']
  // if is_chat_model and 'max_output_tokens' in query:
  // del query['max_output_tokens']
  // if is_chat_model and 'stop_sequences' in query:
  // del query['stop_sequences']

  // # Get the correct model's completions call
  // palm_call = palm.chat if is_chat_model else palm.generate_text
  // Remove erroneous parameters for text and chat models
  if (query.top_k !== undefined && query.top_k <= 0)
    delete query.top_k;
  if (query.top_p !== undefined && query.top_p <= 0)
    delete query.top_p;
  if (is_chat_model && query.max_output_tokens !== undefined)
    delete query.max_output_tokens;
  if (is_chat_model && query.stop_sequences !== undefined)
    delete query.stop_sequences;

  // # Google PaLM's python API does not currently support async calls.
  // # To make one, we need to wrap it in an asynchronous executor:
  // completion = await make_sync_call_async(palm_call, **query)
  // completion_dict = completion.to_dict()
  // For some reason Google needs to be special and have its API params be different names --camel or snake-case
  // --depending on if it's the Python or Node JS API. ChainForge needs a consistent name, so we must convert snake to camel:
  const casemap = {
    safety_settings: 'safetySettings',
    stop_sequences: 'stopSequences',
    candidate_count: 'candidateCount',
    max_output_tokens: 'maxOutputTokens',
    top_p: 'topP',
    top_k: 'topK',
  };
  Object.entries(casemap).forEach(([key, val]) => {
    if (key in query) {
      query[val] = query[key];
      delete query[key];
    }
  });

  // # Google PaLM, unlike other chat models, will output empty
  // # responses for any response it deems unsafe (blocks). Although the text completions
  // # API has a (relatively undocumented) 'safety_settings' parameter,
  // # the current chat completions API provides users no control over the blocking.
  // # We need to detect this and fill the response with the safety reasoning:
  // if len(completion.filters) > 0:
  // # Request was blocked. Output why in the response text,
  // # repairing the candidate dict to mock up 'n' responses
  // block_error_msg = f'[[BLOCKED_REQUEST]] Request was blocked because it triggered safety filters: {str(completion.filters)}'
  // completion_dict['candidates'] = [{'author': 1, 'content':block_error_msg}] * n
  console.log(`Calling Google PaLM model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`);

  // Call the correct model client
  let completion;
  if (is_chat_model) {
    // Chat completions
    query.prompt = { messages: [{content: prompt}] };
    completion = await (client as DiscussServiceClient).generateMessage(query);
  } else {
    // Text completions
    query.prompt = { text: prompt };
    completion = await (client as TextServiceClient).generateText(query);
  }

  // # Weirdly, google ignores candidate_count if temperature is 0.
  // # We have to check for this and manually append the n-1 responses:
  // if n > 1 and temperature == 0 and len(completion_dict['candidates']) == 1:
  // copied_candidates = [completion_dict['candidates'][0]] * n
  // completion_dict['candidates'] = copied_candidates
  // Google PaLM, unlike other chat models, will output empty
  // responses for any response it deems unsafe (blocks). Although the text completions
  // API has a (relatively undocumented) 'safety_settings' parameter,
  // the current chat completions API provides users no control over the blocking.
  // We need to detect this and fill the response with the safety reasoning:
  if (completion[0].filters.length > 0) {
    // Request was blocked. Output why in the response text, repairing the candidate dict to mock up 'n' responses
    const block_error_msg = `[[BLOCKED_REQUEST]] Request was blocked because it triggered safety filters: ${JSON.stringify(completion.filters)}`
    completion[0].candidates = new Array(n).fill({'author': '1', 'content':block_error_msg});
  }

  // return query, completion_dict
  // Weirdly, google ignores candidate_count if temperature is 0.
  // We have to check for this and manually append the n-1 responses:
  // if n > 1 and temperature == 0 and len(completion_dict['candidates']) == 1:
  // copied_candidates = [completion_dict['candidates'][0]] * n
  // completion_dict['candidates'] = copied_candidates

  return [query, completion[0]];
}

export async function call_dalai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
  throw Error("Dalai support in JS backend is not yet implemented.");
}

// async def call_dalai(prompt: str, model: LLM, server: str="http://localhost:4000", n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
// """
@@ -463,11 +482,18 @@ function _extract_palm_responses(completion: Dict): Array<string> {
  return completion['candidates'].map((c: Dict) => c.output || c.content);
}

/**
 * Extracts the text part of an Anthropic text completion.
 */
function _extract_anthropic_responses(response: Array<Dict>): Array<string> {
  return response.map((r: Dict) => r.completion.trim());
}

/**
 * Given a LLM and a response object from its API, extract the
 * text response(s) part of the response object.
 */
function extract_responses(response: Array<string> | Dict, llm: LLM | string): Array<string> {
export function extract_responses(response: Array<string | Dict> | Dict, llm: LLM | string): Array<string> {
  const llm_name = getEnumName(LLM, llm.toString());
  if (llm_name?.startsWith('OpenAI')) {
    if (llm_name.toLowerCase().includes('davinci'))
@@ -475,42 +501,41 @@ function extract_responses(response: Array<string> | Dict, llm: LLM | string): A
    else
      return _extract_chatgpt_responses(response);
  } else if (llm_name?.startsWith('Azure'))
    return _extract_openai_responses(response)
    return _extract_openai_responses(response);
  else if (llm_name?.startsWith('PaLM2'))
    return _extract_palm_responses(response)
    return _extract_palm_responses(response);
  else if (llm_name?.startsWith('Dalai'))
    return [response.toString()];
  else if (llm.toString().startsWith('claude'))
    return response.map((r: Dict) => r.completion);
    return _extract_anthropic_responses(response as Dict[]);
  else
    throw new Error(`LLM ${llm_str} is unsupported.`)
    throw new Error(`No method defined to extract responses for LLM ${llm}.`)
}

function merge_response_objs(resp_obj_A: Dict | undefined, resp_obj_B: Dict | undefined): Dict {
export function merge_response_objs(resp_obj_A: LLMResponseObject | undefined, resp_obj_B: LLMResponseObject | undefined): LLMResponseObject | undefined {
  if (!resp_obj_A && !resp_obj_B) {
    console.warn('Warning: Merging two undefined response objects.')
    return {};
    return undefined;
  } else if (!resp_obj_B && resp_obj_A)
    return resp_obj_A;
  else if (!resp_obj_A && resp_obj_B)
    return resp_obj_B;
  let raw_resp_A = resp_obj_A?.raw_response;
  let raw_resp_B = resp_obj_B?.raw_response;
  resp_obj_A = resp_obj_A as LLMResponseObject; // required by typescript
  resp_obj_B = resp_obj_B as LLMResponseObject;
  let raw_resp_A = resp_obj_A.raw_response;
  let raw_resp_B = resp_obj_B.raw_response;
  if (!Array.isArray(raw_resp_A))
    raw_resp_A = [ raw_resp_A ];
  if (!Array.isArray(raw_resp_B))
    raw_resp_B = [ raw_resp_B ];
  const C: Dict = {
    responses: resp_obj_A?.responses.concat(resp_obj_B?.responses),
    raw_response: raw_resp_A.concat(raw_resp_B),
  };
  return {
    ...C,
    prompt: resp_obj_B?.prompt,
    query: resp_obj_B?.query,
    llm: resp_obj_B?.llm,
    info: resp_obj_B?.info,
    metavars: resp_obj_B?.metavars,
    responses: resp_obj_A.responses.concat(resp_obj_B.responses),
    raw_response: raw_resp_A.concat(raw_resp_B),
    prompt: resp_obj_B.prompt,
    query: resp_obj_B.query,
    llm: resp_obj_B.llm,
    info: resp_obj_B.info,
    metavars: resp_obj_B.metavars,
  };
}