jest tests for query, utils, template.ts. Confirmed PromptPipeline works.

Ian Arawjo 2023-06-25 14:02:25 -04:00
parent dc02e8f44a
commit d61bd922ca
9 changed files with 2142 additions and 466 deletions

File diff suppressed because it is too large

View File

@ -9,6 +9,7 @@
"@codemirror/lang-python": "^6.1.2",
"@emoji-mart/data": "^1.1.2",
"@emoji-mart/react": "^1.1.1",
"@google-ai/generativelanguage": "^0.2.0",
"@mantine/core": "^6.0.9",
"@mantine/dates": "^6.0.13",
"@mantine/form": "^6.0.11",
@ -38,6 +39,7 @@
"dayjs": "^1.11.8",
"emoji-mart": "^5.5.2",
"emoji-picker-react": "^4.4.9",
"google-auth-library": "^8.8.0",
"mantine-contextmenu": "^1.2.15",
"mantine-react-table": "^1.0.0-beta.8",
"openai": "^3.3.0",
@ -82,5 +84,8 @@
"last 1 firefox version",
"last 1 safari version"
]
},
"devDependencies": {
"jest": "^27.5.1"
}
}

View File

@ -1,8 +0,0 @@
import { render, screen } from '@testing-library/react';
import App from './App';
test('renders learn react link', () => {
render(<App />);
const linkElement = screen.getByText(/learn react/i);
expect(linkElement).toBeInTheDocument();
});

View File

@ -0,0 +1,74 @@
/*
* @jest-environment node
*/
import { PromptPipeline } from '../query';
import { LLM } from '../models';
import { expect, test } from '@jest/globals';
import { LLMResponseError, LLMResponseObject } from '../typing';
async function prompt_model(model: LLM): Promise<void> {
const pipeline = new PromptPipeline('What is the oldest {thing} in the world? Keep your answer brief.', model.toString());
let responses: Array<LLMResponseObject | LLMResponseError> = [];
for await (const response of pipeline.gen_responses({thing: ['bar', 'tree', 'book']}, model, 1, 1.0)) {
responses.push(response);
}
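// With n=1, expect one response object per filled prompt ('bar', 'tree', 'book'):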
expect(responses).toHaveLength(3);
// Double-check the cache'd results
let cache = pipeline._load_cached_responses();
Object.entries(cache).forEach(([prompt, response]) => {
console.log(`Prompt: ${prompt}\nResponse: ${response.responses[0]}`);
});
expect(Object.keys(cache)).toHaveLength(3); // expect 3 prompts
// Now query the same model again, but set n=2 to force it to send off 1 more query per prompt.
responses = [];
for await (const response of pipeline.gen_responses({thing: ['bar', 'tree', 'book']}, model, 2, 1.0)) {
responses.push(response);
}
expect(responses).toHaveLength(3); // still 3
responses.forEach(resp_obj => {
if (resp_obj instanceof LLMResponseError) return;
expect(resp_obj.responses).toHaveLength(2); // each response object should have 2 candidates, as n=2
});
// Double-check the cache'd results
cache = pipeline._load_cached_responses();
Object.entries(cache).forEach(([prompt, resp_obj]) => {
console.log(`Prompt: ${prompt}\nResponses: ${JSON.stringify(resp_obj.responses)}`);
expect(resp_obj.responses).toHaveLength(2);
expect(resp_obj.raw_response).toHaveLength(2); // these should've been merged
});
expect(Object.keys(cache)).toHaveLength(3); // still expect 3 prompts
// Now send off the exact same query. It should use only the cache'd results:
responses = [];
for await (const response of pipeline.gen_responses({thing: ['bar', 'tree', 'book']}, model, 2, 1.0)) {
responses.push(response);
}
expect(responses).toHaveLength(3); // still 3
responses.forEach(resp_obj => {
if (resp_obj instanceof LLMResponseError) return;
expect(resp_obj.responses).toHaveLength(2); // each response object should have 2 candidates, as n=2
});
cache = pipeline._load_cached_responses();
Object.entries(cache).forEach(([prompt, resp_obj]) => {
expect(resp_obj.responses).toHaveLength(2);
expect(resp_obj.raw_response).toHaveLength(2); // these should've been merged
});
expect(Object.keys(cache)).toHaveLength(3); // still expect 3 prompts
}
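// Note: the third argument to each test() below is a custom jest timeout in milliseconds, since live API calls can be slow.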
test('basic prompt pipeline with chatgpt', async () => {
// Setup a simple pipeline with a prompt template, 1 variable and 3 input values
await prompt_model(LLM.OpenAI_ChatGPT);
}, 20000);
test('basic prompt pipeline with anthropic', async () => {
await prompt_model(LLM.Claude_v1);
}, 40000);
test('basic prompt pipeline with google palm2', async () => {
await prompt_model(LLM.PaLM2_Chat_Bison);
}, 40000);

View File

@ -0,0 +1,96 @@
import { StringTemplate, PromptTemplate, PromptPermutationGenerator } from '../template';
import { expect, test } from '@jest/globals';
test('string template', () => {
// Test regular string template
const st = new StringTemplate('{pronoun} favorite {thing} is...');
expect(st.has_vars()).toBe(true);
expect(st.has_vars(['thing'])).toBe(true);
expect(st.has_vars(['pronoun'])).toBe(true);
expect(st.has_vars(['tacos'])).toBe(false);
expect(st.safe_substitute({thing: 'food'})).toBe('{pronoun} favorite food is...');
expect(st.safe_substitute({pronoun: 'My'})).toBe('My favorite {thing} is...');
expect(st.safe_substitute({pronoun: 'My', thing: 'food'})).toBe('My favorite food is...');
expect(st.safe_substitute({meat: 'chorizo'})).toBe('{pronoun} favorite {thing} is...');
expect(new StringTemplate(st.safe_substitute({thing: 'programming language'})).has_vars()).toBe(true);
});
test('string template escaped group', () => {
const st = new StringTemplate('{pronoun} favorite \\{thing\\} is...');
expect(st.has_vars(['thing'])).toBe(false);
expect(st.has_vars(['pronoun'])).toBe(true);
expect(st.safe_substitute({thing: 'food'})).toBe('{pronoun} favorite \\{thing\\} is...'); // no substitution
expect(st.safe_substitute({pronoun: 'Our'})).toBe('Our favorite \\{thing\\} is...');
});
test('single template', () => {
let prompt_gen = new PromptPermutationGenerator('What is the {timeframe} when {person} was born?');
let vars: {[key: string]: any} = {
'timeframe': ['year', 'decade', 'century'],
'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding']
};
let num_prompts = 0;
for (const prompt of prompt_gen.generate(vars)) {
// console.log(prompt.toString());
expect(prompt.fill_history).toHaveProperty('timeframe');
expect(prompt.fill_history).toHaveProperty('person');
num_prompts += 1;
}
expect(num_prompts).toBe(9);
});
test('nested templates', () => {
let prompt_gen = new PromptPermutationGenerator('{prefix}... {suffix}');
let vars = {
'prefix': ['Who invented {tool}?', 'When was {tool} invented?', 'What can you do with {tool}?'],
'suffix': ['Phrase your answer in the form of a {response_type}', 'Respond with a {response_type}'],
'tool': ['the flashlight', 'CRISPR', 'rubber'],
'response_type': ['question', 'poem', 'nightmare']
};
let num_prompts = 0;
for (const prompt of prompt_gen.generate(vars)) {
// console.log(prompt.toString());
expect(prompt.fill_history).toHaveProperty('prefix');
expect(prompt.fill_history).toHaveProperty('suffix');
expect(prompt.fill_history).toHaveProperty('tool');
expect(prompt.fill_history).toHaveProperty('response_type');
num_prompts += 1;
}
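// 3 prefixes x 3 tool fills = 9, and 2 suffixes x 3 response_type fills = 6, so 9 x 6 = 54 total permutations: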
expect(num_prompts).toBe((3*3)*(2*3));
});
test('carry together vars', () => {
// 'Carry together' vars with 'metavar' data attached
// NOTE: This feature may be used when passing rows of a table, so that vars that have associated values,
// like 'inventor' with 'tool', 'carry together' when being filled into the prompt template.
// In addition, 'metavars' may be attached which are, commonly, the values of other columns for that row, but
// columns which weren't used to fill in the prompt template explicitly.
let prompt_gen = new PromptPermutationGenerator('What {timeframe} did {inventor} invent the {tool}?')
let vars = {
'inventor': [
{'text': "Thomas Edison", "fill_history": {}, "associate_id": "A", "metavars": { "year": 1879 }},
{'text': "Alexander Fleming", "fill_history": {}, "associate_id": "B", "metavars": { "year": 1928 }},
{'text': "William Shockley", "fill_history": {}, "associate_id": "C", "metavars": { "year": 1947 }},
],
'tool': [
{'text': "lightbulb", "fill_history": {}, "associate_id": "A"},
{'text': "penicillin", "fill_history": {}, "associate_id": "B"},
{'text': "transistor", "fill_history": {}, "associate_id": "C"},
],
'timeframe': [ "year", "decade", "century" ]
};
let num_prompts = 0;
for (const prompt of prompt_gen.generate(vars)) {
const prompt_str = prompt.toString();
// console.log(prompt_str, prompt.metavars)
expect(prompt.metavars).toHaveProperty('year');
if (prompt_str.includes('Edison'))
expect(prompt_str.includes('lightbulb')).toBe(true);
else if (prompt_str.includes('Fleming'))
expect(prompt_str.includes('penicillin')).toBe(true);
else if (prompt_str.includes('Shockley'))
expect(prompt_str.includes('transistor')).toBe(true);
num_prompts += 1;
}
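// The shared associate_id carries each inventor together with their tool, so there are only 3 (inventor, tool) pairs x 3 timeframes = 9 prompts (not 3*3*3):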
expect(num_prompts).toBe(3*3);
});

View File

@ -0,0 +1,116 @@
/*
* @jest-environment node
*/
import { call_anthropic, call_chatgpt, call_google_palm, extract_responses, merge_response_objs } from '../utils';
import { LLM } from '../models';
import { expect, test } from '@jest/globals';
import { LLMResponseObject } from '../typing';
test('merge response objects', () => {
// Merging two response objects
const A: LLMResponseObject = {
responses: ['x', 'y', 'z'],
raw_response: ['x', 'y', 'z'],
prompt: 'this is a test',
query: {},
llm: LLM.OpenAI_ChatGPT,
info: { var1: 'value1', var2: 'value2' },
metavars: { meta1: 'meta1' },
};
const B: LLMResponseObject = {
responses: ['a', 'b', 'c'],
raw_response: {B: 'B'},
prompt: 'this is a test 2',
query: {},
llm: LLM.OpenAI_ChatGPT,
info: { varB1: 'valueB1', varB2: 'valueB2' },
metavars: { metaB1: 'metaB1' },
};
const C = merge_response_objs(A, B) as LLMResponseObject;
expect(C.responses).toHaveLength(6);
expect(JSON.stringify(C.responses)).toBe(JSON.stringify(['x', 'y', 'z', 'a', 'b', 'c']));
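// A's raw_response is an array of 3 items; B's is a non-array object that merging wraps as a single element, so 3 + 1 = 4: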
expect(C.raw_response).toHaveLength(4);
expect(Object.keys(C.info)).toHaveLength(2);
expect(Object.keys(C.info)).toContain('varB1');
expect(Object.keys(C.metavars)).toHaveLength(1);
expect(Object.keys(C.metavars)).toContain('metaB1');
// Merging one empty object should return the non-empty object
expect(merge_response_objs(A, undefined)).toBe(A);
expect(merge_response_objs(undefined, B)).toBe(B);
})
test('UNCOMMENT BELOW API CALL TESTS WHEN READY', () => {
// NOTE: THE API CALL TESTS BELOW ASSUME YOUR API KEY ENVIRONMENT VARIABLES ARE SET!
});
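// A minimal guard sketch, not part of the commit above: one way to auto-skip the live API tests
// below when no keys are set. The env var names used here (OPENAI_API_KEY, ANTHROPIC_API_KEY,
// PALM_API_KEY) are assumptions -- substitute whatever names this project actually reads.
const haveApiKeys = ['OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'PALM_API_KEY']
.every(k => (process.env[k] ?? '').length > 0);
const apiTest = haveApiKeys ? test : test.skip; // then write apiTest(...) in place of test(...) below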
test('openai chat completions', async () => {
// Call ChatGPT with a basic question, and n=2
const [query, response] = await call_chatgpt("Who invented modern playing cards? Keep your answer brief.", LLM.OpenAI_ChatGPT, 2, 1.0);
console.log(response.choices[0].message);
expect(response.choices).toHaveLength(2);
expect(query).toHaveProperty('temperature');
// Extract responses, check their type
const resps = extract_responses(response, LLM.OpenAI_ChatGPT);
expect(resps).toHaveLength(2);
expect(typeof resps[0]).toBe('string');
}, 20000);
test('openai text completions', async () => {
// Call an OpenAI text completions model (Davinci) with a basic question, and n=2
const [query, response] = await call_chatgpt("Who invented modern playing cards? The answer is ", LLM.OpenAI_Davinci003, 2, 1.0);
console.log(response.choices[0].text);
expect(response.choices).toHaveLength(2);
expect(query).toHaveProperty('n');
// Extract responses, check their type
const resps = extract_responses(response, LLM.OpenAI_Davinci003);
expect(resps).toHaveLength(2);
expect(typeof resps[0]).toBe('string');
}, 20000);
test('anthropic models', async () => {
// Call Anthropic's Claude with a basic question
const [query, response] = await call_anthropic("Who invented modern playing cards?", LLM.Claude_v1, 1, 1.0);
console.log(response);
expect(response).toHaveLength(1);
expect(query).toHaveProperty('max_tokens_to_sample');
// Extract responses, check their type
const resps = extract_responses(response, LLM.Claude_v1);
expect(resps).toHaveLength(1);
expect(typeof resps[0]).toBe('string');
}, 20000);
test('google palm2 models', async () => {
// Call Google's PaLM Chat API with a basic question
let [query, response] = await call_google_palm("Who invented modern playing cards?", LLM.PaLM2_Chat_Bison, 3, 0.7);
expect(response.candidates).toHaveLength(3);
expect(query).toHaveProperty('candidateCount');
// Extract responses, check their type
let resps = extract_responses(response, LLM.PaLM2_Chat_Bison);
expect(resps).toHaveLength(3);
expect(typeof resps[0]).toBe('string');
console.log(JSON.stringify(resps));
// Call Google's PaLM Text Completions API with a basic question
[query, response] = await call_google_palm("Who invented modern playing cards? The answer ", LLM.PaLM2_Text_Bison, 3, 0.7);
expect(response.candidates).toHaveLength(3);
expect(query).toHaveProperty('maxOutputTokens');
// Extract responses, check their type
resps = extract_responses(response, LLM.PaLM2_Text_Bison);
expect(resps).toHaveLength(3);
expect(typeof resps[0]).toBe('string');
console.log(JSON.stringify(resps));
}, 40000);
// test('call_', async () => {
// // Call Anthropic's Claude with a basic question
// const [query, response] = await call_anthropic("Who invented modern playing cards? Keep your answer brief.", LLM.Claude_v1, 1, 1.0);
// console.log(response);
// expect(response).toHaveLength(1);
// expect(query).toHaveProperty('max_tokens_to_sample');
// }, 20000);

View File

@ -8,13 +8,13 @@
import { PromptTemplate, PromptPermutationGenerator } from "./template";
import { LLM, RATE_LIMITS } from './models';
import { Dict, LLMResponseError, LLMResponseObject, LLMAPICall } from "./typing";
import { call_chatgpt, call_anthropic, call_google_palm, call_azure_openai, call_dalai, getEnumName } from "./utils";
import { call_chatgpt, call_anthropic, call_google_palm, call_azure_openai, call_dalai, getEnumName, extract_responses, merge_response_objs } from "./utils";
interface _IntermediateLLMResponseType {
prompt: PromptTemplate | string,
query?: Dict,
response?: Dict | LLMResponseError,
past_resp_obj?: Dict,
past_resp_obj?: LLMResponseObject | undefined,
}
/**
@ -76,9 +76,9 @@ function sleep(ms: number): Promise<void> {
/**
* Abstract class that captures a generic querying interface to prompt LLMs
*/
class PromptPipeline {
_storageKey: string
_template: string
export class PromptPipeline {
private _storageKey: string;
private _template: string;
constructor(template: string, storageKey: string) {
this._template = template;
@ -110,13 +110,13 @@ class PromptPipeline {
llm: LLM,
n: number = 1,
temperature: number = 1.0,
llm_params: Dict): AsyncGenerator<LLMResponseObject | LLMResponseError, boolean, undefined> {
llm_params?: Dict): AsyncGenerator<LLMResponseObject | LLMResponseError, boolean, undefined> {
// Double-check that properties is the correct type (JSON dict):
// Load any cache'd responses
let responses = this._load_cached_responses();
// Query LLM with each prompt, yield + cache the responses
let tasks: Array<Promise<_IntermediateLLMResponseObject>> = [];
let tasks: Array<Promise<_IntermediateLLMResponseType>> = [];
const rate_limit = RATE_LIMITS[llm] || [1, 0];
let [max_req, wait_secs] = rate_limit ? rate_limit : [1, 0];
let num_queries_sent = -1;
@ -129,7 +129,7 @@ class PromptPipeline {
let info = prompt.fill_history;
let metavars = prompt.metavars;
let cached_resp = prompt_str in responses ? responses[prompt_str] : null;
let cached_resp = prompt_str in responses ? responses[prompt_str] : undefined;
let extracted_resps: Array<any> = cached_resp ? cached_resp["responses"] : [];
// First check if there is already a response for this item under these settings. If so, we can save an LLM call:
@ -154,23 +154,18 @@ class PromptPipeline {
if (max_req > 1) {
// Call the LLM asynchronously to generate a response, sending off
// requests in batches of size 'max_req' separated by seconds 'wait_secs' to avoid hitting rate limit
tasks.push(this._prompt_llm(llm=llm,
prompt=prompt,
n=n,
temperature=temperature,
past_resp_obj=cached_resp,
query_number=num_queries_sent,
rate_limit_batch_size=max_req,
rate_limit_wait_secs=wait_secs,
**llm_params));
tasks.push(this._prompt_llm(llm, prompt, n, temperature,
cached_resp,
num_queries_sent,
max_req,
wait_secs,
llm_params));
} else {
// Block. Await + yield a single LLM call.
let {_, query, response, past_resp_obj} = await this._prompt_llm(llm=llm,
prompt=prompt,
n=n,
temperature=temperature,
past_resp_obj=cached_resp,
**llm_params);
let result = await this._prompt_llm(llm, prompt, n, temperature, cached_resp,
undefined, undefined, undefined,
llm_params);
let { query, response, past_resp_obj } = result;
// Check for selective failure
if (!query && response instanceof LLMResponseError) {
@ -178,20 +173,24 @@ class PromptPipeline {
continue;
}
// We now know there was a response; type it correctly:
query = query as Dict;
response = response as Dict;
// Create a response obj to represent the response
let resp_obj: LLMResponseObject = {
"prompt": prompt.toString(),
"query": query,
"responses": extract_responses(response, llm),
"raw_response": response,
"llm": llm,
"info": info,
"metavars": metavars,
prompt: prompt.toString(),
query: query,
responses: extract_responses(response, llm),
raw_response: response,
llm: llm,
info: info,
metavars: metavars,
}
// Merge the response obj with the past one, if necessary
if (past_resp_obj)
resp_obj = merge_response_objs(resp_obj, past_resp_obj);
resp_obj = merge_response_objs(resp_obj, past_resp_obj) as LLMResponseObject;
// Save the current state of cache'd responses to a JSON file
responses[resp_obj["prompt"]] = resp_obj;
@ -230,7 +229,7 @@ class PromptPipeline {
// Merge the response obj with the past one, if necessary
if (past_resp_obj)
resp_obj = merge_response_objs(resp_obj, past_resp_obj);
resp_obj = merge_response_objs(resp_obj, past_resp_obj) as LLMResponseObject;
// Save the current state of cache'd responses to a JSON file
// NOTE: We do this to save money --in case something breaks between calls, can ensure we got the data!
@ -249,7 +248,7 @@ class PromptPipeline {
* Loads cache'd responses from the JSON cache.
* Useful for continuing if computation was interrupted halfway through.
*/
_load_cached_responses(): Dict {
_load_cached_responses(): {[key: string]: LLMResponseObject} {
return StorageCache.get(this._storageKey);
}
@ -265,7 +264,7 @@ class PromptPipeline {
prompt: PromptTemplate,
n: number = 1,
temperature: number = 1.0,
past_resp_obj?: Dict,
past_resp_obj?: LLMResponseObject,
query_number?: number,
rate_limit_batch_size?: number,
rate_limit_wait_secs?: number,
@ -303,7 +302,6 @@ class PromptPipeline {
call_llm = call_dalai;
else if (llm.toString().startsWith('claude'))
call_llm = call_anthropic;
if (!call_llm)
throw new LLMResponseError(`Language model ${llm} is not supported.`);
@ -313,19 +311,17 @@ class PromptPipeline {
let response: Dict | LLMResponseError;
try {
[query, response] = await call_llm(prompt.toString(), llm, n=n, temperature, llm_params);
} catch(e: Error) {
} catch(err) {
return { prompt: prompt,
query: undefined,
response: new LLMResponseError(e.toString()),
response: new LLMResponseError(err.toString()),
past_resp_obj: undefined };
}
return {
prompt,
query,
response,
past_resp_obj
};
return { prompt,
query,
response,
past_resp_obj };
}
}

View File

@ -97,6 +97,10 @@ export class StringTemplate {
}
return false;
}
toString(): string {
return this.val;
}
}
export class PromptTemplate {
@ -335,79 +339,3 @@ export class PromptPermutationGenerator {
return true; // done
}
}
function assert(condition: boolean, message?: string) {
if (!condition) {
throw new Error(message || "Assertion failed");
}
}
/**
* Run test cases on `PromptPermutationGenerator`.
*/
function _test() {
// Single template
let prompt_gen = new PromptPermutationGenerator('What is the {timeframe} when {person} was born?');
let vars: {[key: string]: any} = {
'timeframe': ['year', 'decade', 'century'],
'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding']
};
let num_prompts = 0;
for (const prompt of prompt_gen.generate(vars)) {
console.log(prompt.toString());
num_prompts += 1;
}
assert(num_prompts === 9);
// Nested templates
prompt_gen = new PromptPermutationGenerator('{prefix}... {suffix}');
vars = {
'prefix': ['Who invented {tool}?', 'When was {tool} invented?', 'What can you do with {tool}?'],
'suffix': ['Phrase your answer in the form of a {response_type}', 'Respond with a {response_type}'],
'tool': ['the flashlight', 'CRISPR', 'rubber'],
'response_type': ['question', 'poem', 'nightmare']
};
num_prompts = 0;
for (const prompt of prompt_gen.generate(vars)) {
console.log(prompt.toString());
num_prompts += 1;
}
assert(num_prompts === (3*3)*(2*3));
// # 'Carry together' vars with 'metavar' data attached
// # NOTE: This feature may be used when passing rows of a table, so that vars that have associated values,
// # like 'inventor' with 'tool', 'carry together' when being filled into the prompt template.
// # In addition, 'metavars' may be attached which are, commonly, the values of other columns for that row, but
// # columns which weren't used to fill in the prompt template explcitly.
prompt_gen = new PromptPermutationGenerator('What {timeframe} did {inventor} invent the {tool}?')
vars = {
'inventor': [
{'text': "Thomas Edison", "fill_history": {}, "associate_id": "A", "metavars": { "year": 1879 }},
{'text': "Alexander Fleming", "fill_history": {}, "associate_id": "B", "metavars": { "year": 1928 }},
{'text': "William Shockley", "fill_history": {}, "associate_id": "C", "metavars": { "year": 1947 }},
],
'tool': [
{'text': "lightbulb", "fill_history": {}, "associate_id": "A"},
{'text': "penicillin", "fill_history": {}, "associate_id": "B"},
{'text': "transistor", "fill_history": {}, "associate_id": "C"},
],
'timeframe': [ "year", "decade", "century" ]
};
num_prompts = 0;
for (const prompt of prompt_gen.generate(vars)) {
const prompt_str = prompt.toString();
console.log(prompt_str, prompt.metavars)
assert("year" in prompt.metavars);
if (prompt_str.includes('Edison'))
assert(prompt_str.includes('lightbulb'));
else if (prompt_str.includes('Fleming'))
assert(prompt_str.includes('penicillin'));
else if (prompt_str.includes('Shockley'))
assert(prompt_str.includes('transistor'));
num_prompts += 1;
}
assert(num_prompts === 3*3);
}
// Uncomment and run 'ts-node template.ts' to test:
// _test();

View File

@ -4,13 +4,16 @@
// from chainforge.promptengine.models import LLM
import { LLM } from './models';
import { Dict, StringDict, LLMAPICall } from './typing';
import { Dict, StringDict, LLMAPICall, LLMResponseObject } from './typing';
import { env as process_env } from 'process';
import { StringTemplate } from './template';
/* LLM API SDKs */
import { Configuration as OpenAIConfig, OpenAIApi } from "openai";
import { OpenAIClient as AzureOpenAIClient, AzureKeyCredential } from "@azure/openai";
import { AI_PROMPT, Client as AnthropicClient, HUMAN_PROMPT } from "@anthropic-ai/sdk";
import { DiscussServiceClient, TextServiceClient } from "@google-ai/generativelanguage";
import { GoogleAuth } from "google-auth-library";
function get_environ(key: string): string | undefined {
if (key in process_env)
@ -48,7 +51,7 @@ export function set_api_keys(api_keys: StringDict): void {
}
/** Equivalent to a Python enum's .name property */
function getEnumName(enumObject: any, enumValue: any): string | undefined {
export function getEnumName(enumObject: any, enumValue: any): string | undefined {
for (const key in enumObject) {
if (enumObject[key] === enumValue) {
return key;
@ -129,17 +132,6 @@ export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, te
return [query, response];
}
function _test() {
// Call ChatGPT
call_chatgpt("Who invented modern playing cards?",
LLM.OpenAI_ChatGPT,
1, 1.0).then(([query, response]) => {
console.log(response.choices[0].message);
});
}
_test();
/**
* Calls OpenAI models hosted on Microsoft Azure services.
* Returns raw query and response JSON dicts.
@ -247,7 +239,7 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
...params,
};
console.log(`Calling Anthropic model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`)
console.log(`Calling Anthropic model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`);
// Repeat call n times, waiting for each response to come in:
let responses: Array<Dict> = [];
@ -260,67 +252,94 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
return [query, responses];
}
// async def call_google_palm(prompt: str, model: LLM, n: int = 1, temperature: float= 0.7,
// max_output_tokens=800,
// async_mode=False,
// **params) -> Tuple[Dict, Dict]:
// """
// Calls a Google PaLM model.
// Returns raw query and response JSON dicts.
// """
// if GOOGLE_PALM_API_KEY is None:
// raise Exception("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local Python environment.")
/**
* Calls a Google PaLM model.
* Returns raw query and response JSON dicts.
*/
export async function call_google_palm(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
if (!GOOGLE_PALM_API_KEY)
throw Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment.");
// import google.generativeai as palm
// palm.configure(api_key=GOOGLE_PALM_API_KEY)
const is_chat_model = model.toString().includes('chat');
const client = new (is_chat_model ? DiscussServiceClient : TextServiceClient)({
authClient: new GoogleAuth().fromAPIKey(GOOGLE_PALM_API_KEY),
});
// is_chat_model = 'chat' in model.value
// Required non-standard params
const max_output_tokens = params?.max_output_tokens || 800;
// query = {
// 'model': f"models/{model.value}",
// 'prompt': prompt,
// 'candidate_count': n,
// 'temperature': temperature,
// 'max_output_tokens': max_output_tokens,
// **params,
// }
let query: Dict = {
model: `models/${model}`,
candidate_count: n,
temperature: temperature,
max_output_tokens: max_output_tokens,
...params,
};
// # Remove erroneous parameters for text and chat models
// if 'top_k' in query and query['top_k'] <= 0:
// del query['top_k']
// if 'top_p' in query and query['top_p'] <= 0:
// del query['top_p']
// if is_chat_model and 'max_output_tokens' in query:
// del query['max_output_tokens']
// if is_chat_model and 'stop_sequences' in query:
// del query['stop_sequences']
// # Get the correct model's completions call
// palm_call = palm.chat if is_chat_model else palm.generate_text
// Remove erroneous parameters for text and chat models
if (query.top_k !== undefined && query.top_k <= 0)
delete query.top_k;
if (query.top_p !== undefined && query.top_p <= 0)
delete query.top_p;
if (is_chat_model && query.max_output_tokens !== undefined)
delete query.max_output_tokens;
if (is_chat_model && query.stop_sequences !== undefined)
delete query.stop_sequences;
// # Google PaLM's python API does not currently support async calls.
// # To make one, we need to wrap it in an asynchronous executor:
// completion = await make_sync_call_async(palm_call, **query)
// completion_dict = completion.to_dict()
// Google needs to be special: its API params have different names --snake_case in the Python API, camelCase in the Node.js API.
// ChainForge needs a consistent name, so we must convert snake_case keys to camelCase before sending the request:
const casemap = {
safety_settings: 'safetySettings',
stop_sequences: 'stopSequences',
candidate_count: 'candidateCount',
max_output_tokens: 'maxOutputTokens',
top_p: 'topP',
top_k: 'topK',
};
Object.entries(casemap).forEach(([key, val]) => {
if (key in query) {
query[val] = query[key];
delete query[key];
}
});
// # Google PaLM, unlike other chat models, will output empty
// # responses for any response it deems unsafe (blocks). Although the text completions
// # API has a (relatively undocumented) 'safety_settings' parameter,
// # the current chat completions API provides users no control over the blocking.
// # We need to detect this and fill the response with the safety reasoning:
// if len(completion.filters) > 0:
// # Request was blocked. Output why in the response text,
// # repairing the candidate dict to mock up 'n' responses
// block_error_msg = f'[[BLOCKED_REQUEST]] Request was blocked because it triggered safety filters: {str(completion.filters)}'
// completion_dict['candidates'] = [{'author': 1, 'content':block_error_msg}] * n
console.log(`Calling Google PaLM model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`);
// Call the correct model client
let completion;
if (is_chat_model) {
// Chat completions
query.prompt = { messages: [{content: prompt}] };
completion = await (client as DiscussServiceClient).generateMessage(query);
} else {
// Text completions
query.prompt = { text: prompt };
completion = await (client as TextServiceClient).generateText(query);
}
// # Weirdly, google ignores candidate_count if temperature is 0.
// # We have to check for this and manually append the n-1 responses:
// if n > 1 and temperature == 0 and len(completion_dict['candidates']) == 1:
// copied_candidates = [completion_dict['candidates'][0]] * n
// completion_dict['candidates'] = copied_candidates
// Google PaLM, unlike other chat models, will output empty
// responses for any response it deems unsafe (blocks). Although the text completions
// API has a (relatively undocumented) 'safety_settings' parameter,
// the current chat completions API provides users no control over the blocking.
// We need to detect this and fill the response with the safety reasoning:
if (completion[0].filters.length > 0) {
// Request was blocked. Output why in the response text, repairing the candidate dict to mock up 'n' responses
const block_error_msg = `[[BLOCKED_REQUEST]] Request was blocked because it triggered safety filters: ${JSON.stringify(completion.filters)}`
completion[0].candidates = new Array(n).fill({'author': '1', 'content':block_error_msg});
}
// return query, completion_dict
// Weirdly, google ignores candidate_count if temperature is 0.
// We have to check for this and manually append the n-1 responses:
// if n > 1 and temperature == 0 and len(completion_dict['candidates']) == 1:
// copied_candidates = [completion_dict['candidates'][0]] * n
// completion_dict['candidates'] = copied_candidates
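// Illustrative sketch only (not part of this commit): a direct TS port of the Python workaround in
// the comment above, duplicating the lone candidate to mock up 'n' responses:
if (n > 1 && temperature === 0 && completion[0].candidates?.length === 1)
completion[0].candidates = new Array(n).fill(completion[0].candidates[0]);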
return [query, completion[0]];
}
export async function call_dalai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
throw Error("Dalai support in JS backend is not yet implemented.");
}
// async def call_dalai(prompt: str, model: LLM, server: str="http://localhost:4000", n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
// """
@ -463,11 +482,18 @@ function _extract_palm_responses(completion: Dict): Array<string> {
return completion['candidates'].map((c: Dict) => c.output || c.content);
}
/**
* Extracts the text part of an Anthropic text completion.
*/
function _extract_anthropic_responses(response: Array<Dict>): Array<string> {
return response.map((r: Dict) => r.completion.trim());
}
/**
* Given an LLM and a response object from its API, extract the
* text response(s) part of the response object.
*/
function extract_responses(response: Array<string> | Dict, llm: LLM | string): Array<string> {
export function extract_responses(response: Array<string | Dict> | Dict, llm: LLM | string): Array<string> {
const llm_name = getEnumName(LLM, llm.toString());
if (llm_name?.startsWith('OpenAI')) {
if (llm_name.toLowerCase().includes('davinci'))
@ -475,42 +501,41 @@ function extract_responses(response: Array<string> | Dict, llm: LLM | string): A
else
return _extract_chatgpt_responses(response);
} else if (llm_name?.startsWith('Azure'))
return _extract_openai_responses(response)
return _extract_openai_responses(response);
else if (llm_name?.startsWith('PaLM2'))
return _extract_palm_responses(response)
return _extract_palm_responses(response);
else if (llm_name?.startsWith('Dalai'))
return [response.toString()];
else if (llm.toString().startsWith('claude'))
return response.map((r: Dict) => r.completion);
return _extract_anthropic_responses(response as Dict[]);
else
throw new Error(`LLM ${llm_str} is unsupported.`)
throw new Error(`No method defined to extract responses for LLM ${llm}.`)
}
function merge_response_objs(resp_obj_A: Dict | undefined, resp_obj_B: Dict | undefined): Dict {
export function merge_response_objs(resp_obj_A: LLMResponseObject | undefined, resp_obj_B: LLMResponseObject | undefined): LLMResponseObject | undefined {
if (!resp_obj_A && !resp_obj_B) {
console.warn('Warning: Merging two undefined response objects.')
return {};
return undefined;
} else if (!resp_obj_B && resp_obj_A)
return resp_obj_A;
else if (!resp_obj_A && resp_obj_B)
return resp_obj_B;
let raw_resp_A = resp_obj_A?.raw_response;
let raw_resp_B = resp_obj_B?.raw_response;
resp_obj_A = resp_obj_A as LLMResponseObject; // required by typescript
resp_obj_B = resp_obj_B as LLMResponseObject;
let raw_resp_A = resp_obj_A.raw_response;
let raw_resp_B = resp_obj_B.raw_response;
if (!Array.isArray(raw_resp_A))
raw_resp_A = [ raw_resp_A ];
if (!Array.isArray(raw_resp_B))
raw_resp_B = [ raw_resp_B ];
const C: Dict = {
responses: resp_obj_A?.responses.concat(resp_obj_B?.responses),
raw_response: raw_resp_A.concat(raw_resp_B),
};
return {
...C,
prompt: resp_obj_B?.prompt,
query: resp_obj_B?.query,
llm: resp_obj_B?.llm,
info: resp_obj_B?.info,
metavars: resp_obj_B?.metavars,
responses: resp_obj_A.responses.concat(resp_obj_B.responses),
raw_response: raw_resp_A.concat(raw_resp_B),
prompt: resp_obj_B.prompt,
query: resp_obj_B.query,
llm: resp_obj_B.llm,
info: resp_obj_B.info,
metavars: resp_obj_B.metavars,
};
}