Refactored to LLMProvider to streamline model additions

This commit is contained in:
Ian Arawjo 2023-06-29 17:29:10 -04:00
parent d401216744
commit 34884345d9
3 changed files with 144 additions and 146 deletions

View File

@ -1,20 +1,9 @@
// import json, os, asyncio, sys, traceback
// from dataclasses import dataclass
// from enum import Enum
// from typing import Union, List
// from statistics import mean, median, stdev
// from flask import Flask, request, jsonify, render_template
// from flask_cors import CORS
// from chainforge.promptengine.query import PromptLLM, PromptLLMDummy, LLMResponseException
// from chainforge.promptengine.template import PromptTemplate, PromptPermutationGenerator
// from chainforge.promptengine.utils import LLM, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists, set_api_keys
import { mean as __mean, std as __std, median as __median } from "mathjs";
import markdownIt from "markdown-it";
import { Dict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse } from "./typing";
import { LLM } from "./models";
import { APP_IS_RUNNING_LOCALLY, getEnumName, set_api_keys, FLASK_BASE_URL, call_flask_backend } from "./utils";
import { LLM, getEnumName } from "./models";
import { APP_IS_RUNNING_LOCALLY, set_api_keys, FLASK_BASE_URL, call_flask_backend } from "./utils";
import StorageCache from "./cache";
import { PromptPipeline } from "./query";
import { PromptPermutationGenerator, PromptTemplate } from "./template";

View File

@ -2,44 +2,75 @@
* A list of all model APIs natively supported by ChainForge.
*/
export enum LLM {
// OpenAI Chat
OpenAI_ChatGPT = "gpt-3.5-turbo",
OpenAI_ChatGPT_16k = "gpt-3.5-turbo-16k",
OpenAI_ChatGPT_16k_0613 = "gpt-3.5-turbo-16k-0613",
OpenAI_ChatGPT_0301 = "gpt-3.5-turbo-0301",
OpenAI_ChatGPT_0613 = "gpt-3.5-turbo-0613",
OpenAI_GPT4 = "gpt-4",
OpenAI_GPT4_0314 = "gpt-4-0314",
OpenAI_GPT4_0613 = "gpt-4-0613",
OpenAI_GPT4_32k = "gpt-4-32k",
OpenAI_GPT4_32k_0314 = "gpt-4-32k-0314",
OpenAI_GPT4_32k_0613 = "gpt-4-32k-0613",
// OpenAI Text Completions
OpenAI_Davinci003 = "text-davinci-003",
OpenAI_Davinci002 = "text-davinci-002",
// Azure OpenAI Endpoints
Azure_OpenAI = "azure-openai",
// Dalai-served models (Alpaca and Llama)
Dalai_Alpaca_7B = "alpaca.7B",
Dalai_Alpaca_13B = "alpaca.13B",
Dalai_Llama_7B = "llama.7B",
Dalai_Llama_13B = "llama.13B",
Dalai_Llama_30B = "llama.30B",
Dalai_Llama_65B = "llama.65B",
// Anthropic
Claude_v1 = "claude-v1",
Claude_v1_0 = "claude-v1.0",
Claude_v1_2 = "claude-v1.2",
Claude_v1_3 = "claude-v1.3",
Claude_v1_instant = "claude-instant-v1",
// Google models
PaLM2_Text_Bison = "text-bison-001", // it's really models/text-bison-001, but that's confusing
PaLM2_Chat_Bison = "chat-bison-001",
}
/**
* A list of model providers
*/
export enum LLMProvider {
OpenAI = "openai",
Azure_OpenAI = "azure",
Dalai = "dalai",
Anthropic = "anthropic",
Google = "google",
}
/**
* Given an LLM, return what the model provider is.
* @param llm the specific large language model
* @returns an `LLMProvider` describing what provider hosts the model
*/
export function getProvider(llm: LLM): LLMProvider | undefined {
const llm_name = getEnumName(LLM, llm.toString());
if (llm_name?.startsWith('OpenAI'))
return LLMProvider.OpenAI;
else if (llm_name?.startsWith('Azure'))
return LLMProvider.Azure_OpenAI;
else if (llm_name?.startsWith('PaLM2'))
return LLMProvider.Google;
else if (llm_name?.startsWith('Dalai'))
return LLMProvider.Dalai;
else if (llm.toString().startsWith('claude'))
return LLMProvider.Anthropic;
return undefined;
}
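For illustration only (not part of the commit), a minimal usage sketch of the new provider lookup from a hypothetical caller module; the dispatcher in utils below branches on the same mapping:

// Hypothetical usage sketch of getProvider:
import { LLM, LLMProvider, getProvider } from "./models";

const provider = getProvider(LLM.OpenAI_GPT4);    // LLMProvider.OpenAI ("openai")
if (provider === undefined)
  throw new Error("Model is not supported.");
console.log(getProvider(LLM.Claude_v1));          // "anthropic"
console.log(getProvider(LLM.PaLM2_Chat_Bison));   // "google"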
@ -52,10 +83,25 @@ export enum LLM {
export const RATE_LIMITS: { [key in LLM]?: [number, number] } = {
[LLM.OpenAI_ChatGPT]: [30, 10], // max 30 requests a batch; wait 10 seconds between
[LLM.OpenAI_ChatGPT_0301]: [30, 10],
[LLM.OpenAI_ChatGPT_0613]: [30, 10],
[LLM.OpenAI_ChatGPT_16k]: [30, 10],
[LLM.OpenAI_ChatGPT_16k_0613]: [30, 10],
[LLM.OpenAI_GPT4]: [4, 15], // max 4 requests a batch; wait 15 seconds between
[LLM.OpenAI_GPT4_0314]: [4, 15],
[LLM.OpenAI_GPT4_0613]: [4, 15],
[LLM.OpenAI_GPT4_32k]: [4, 15],
[LLM.OpenAI_GPT4_32k_0314]: [4, 15],
[LLM.OpenAI_GPT4_32k_0613]: [4, 15],
[LLM.Azure_OpenAI]: [30, 10],
[LLM.PaLM2_Text_Bison]: [4, 10], // max 30 requests per minute; so do 4 per batch, 10 seconds between (conservative)
[LLM.PaLM2_Chat_Bison]: [4, 10],
};
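As an aside, a hedged sketch (not from the commit) of how a scheduler could consume a RATE_LIMITS entry; the helper name and batching policy are assumptions, and call_llm is the provider dispatcher shown in utils below:

import { LLM, RATE_LIMITS } from "./models";
import { call_llm } from "./utils";

// Hypothetical helper: send prompts in batches of at most `maxBatch`,
// waiting `waitSecs` seconds between batches, per the [batch size, wait seconds] tuple above.
async function throttledCalls(llm: LLM, prompts: string[]): Promise<void> {
  const [maxBatch, waitSecs] = RATE_LIMITS[llm] ?? [1, 0];  // fall back to one-at-a-time if unlisted
  for (let i = 0; i < prompts.length; i += maxBatch) {
    const batch = prompts.slice(i, i + maxBatch);
    await Promise.all(batch.map(p => call_llm(llm, p, 1, 1.0)));
    if (i + maxBatch < prompts.length)
      await new Promise(res => setTimeout(res, waitSecs * 1000));
  }
}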
/** Equivalent to a Python enum's .name property */
export function getEnumName(enumObject: any, enumValue: any): string | undefined {
for (const key in enumObject)
if (enumObject[key] === enumValue)
return key;
return undefined;
}
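For example (values taken from the enum above):

getEnumName(LLM, "gpt-4");         // -> "OpenAI_GPT4"
getEnumName(LLM, LLM.Claude_v1);   // -> "Claude_v1"
getEnumName(LLM, "not-a-model");   // -> undefined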

View File

@ -3,7 +3,7 @@
// from string import Template
// from chainforge.promptengine.models import LLM
import { LLM } from './models';
import { LLM, LLMProvider, getProvider } from './models';
import { Dict, StringDict, LLMAPICall, LLMResponseObject } from './typing';
import { env as process_env } from 'process';
import { StringTemplate } from './template';
@ -30,6 +30,28 @@ export async function call_flask_backend(route: string, params: Dict | string):
});
}
// We only calculate whether the app is running locally once upon load, and store it here:
let _APP_IS_RUNNING_LOCALLY: boolean | undefined = undefined;
/**
* Tries to determine if the ChainForge front-end is running on the user's local machine (and hence has access to the Flask backend).
* @returns `true` if we think the app is running locally (on localhost or equivalent); `false` if not.
*/
export function APP_IS_RUNNING_LOCALLY(): boolean {
if (_APP_IS_RUNNING_LOCALLY === undefined) {
// Calculate whether we're running the app locally or not, and save the result
try {
const location = window.location;
_APP_IS_RUNNING_LOCALLY = location.hostname === "localhost" || location.hostname === "127.0.0.1" || location.hostname === "";
} catch (e) {
// ReferenceError: window or location does not exist.
// We must not be running client-side in a browser, in this case (e.g., we are running a Node.js server)
_APP_IS_RUNNING_LOCALLY = false;
}
}
return _APP_IS_RUNNING_LOCALLY;
}
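A brief usage sketch (illustrative, not part of the commit): callers can gate Flask-only code paths on this check before attempting a backend call:

// Hypothetical helper: only talk to the local Flask server when one can exist.
async function fetchFromBackendIfLocal(route: string, params: Dict) {
  if (!APP_IS_RUNNING_LOCALLY())
    return undefined;  // app is served remotely; there is no local Flask backend to reach
  return call_flask_backend(route, params);
}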
/**
* Equivalent to a 'fetch' call, but routes it to the backend Flask server in
* case we are running a local server and prefer to not deal with CORS issues making API calls client-side.
@ -93,29 +115,6 @@ export function set_api_keys(api_keys: StringDict): void {
// Soft fail for non-present keys
}
/** Equivalent to a Python enum's .name property */
export function getEnumName(enumObject: any, enumValue: any): string | undefined {
for (const key in enumObject) {
if (enumObject[key] === enumValue) {
return key;
}
}
return undefined;
}
// async def make_sync_call_async(sync_method, *args, **params):
// """
// Makes a blocking synchronous call asynchronous, so that it can be awaited.
// NOTE: This is necessary for LLM APIs that do not yet support async (e.g. Google PaLM).
// """
// loop = asyncio.get_running_loop()
// method = sync_method
// if len(params) > 0:
// def partial_sync_meth(*a):
// return sync_method(*a, **params)
// method = partial_sync_meth
// return await loop.run_in_executor(None, method, *args)
/**
* Calls OpenAI models via OpenAI's API.
@returns raw query and response JSON dicts.
@ -311,7 +310,6 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
while (responses.length < n) {
const resp = await route_fetch(url, 'POST', headers, query);
responses.push(resp);
console.log(`${model} response ${responses.length} of ${n}:\n${resp}`);
}
return [query, responses];
@ -506,20 +504,21 @@ export async function call_dalai(prompt: string, model: LLM, n: number = 1, temp
export async function call_llm(llm: LLM, prompt: string, n: number, temperature: number, params?: Dict): Promise<[Dict, Dict]> {
// Get the correct API call for the given LLM:
let call_api: LLMAPICall | undefined;
const llm_name = getEnumName(LLM, llm.toString());
if (llm_name?.startsWith('OpenAI'))
call_api = call_chatgpt;
else if (llm_name?.startsWith('Azure'))
call_api = call_azure_openai;
else if (llm_name?.startsWith('PaLM2'))
call_api = call_google_palm;
else if (llm_name?.startsWith('Dalai'))
call_api = call_dalai;
else if (llm.toString().startsWith('claude'))
call_api = call_anthropic;
if (!call_api)
let llm_provider: LLMProvider | undefined = getProvider(llm);
if (llm_provider === undefined)
throw new Error(`Language model ${llm} is not supported.`);
if (llm_provider === LLMProvider.OpenAI)
call_api = call_chatgpt;
else if (llm_provider === LLMProvider.Azure_OpenAI)
call_api = call_azure_openai;
else if (llm_provider === LLMProvider.Google)
call_api = call_google_palm;
else if (llm_provider === LLMProvider.Dalai)
call_api = call_dalai;
else if (llm_provider === LLMProvider.Anthropic)
call_api = call_anthropic;
return call_api(prompt, llm, n, temperature, params);
}
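A short usage sketch of the refactored dispatcher (the prompt, n, and temperature values are illustrative):

import { LLM } from "./models";
import { call_llm, extract_responses } from "./utils";

// One code path for any supported model; provider-specific details stay inside call_llm.
const [query, rawResponse] = await call_llm(LLM.OpenAI_ChatGPT, "Say hello in one word.", 3, 0.7);
const texts: string[] = extract_responses(rawResponse, LLM.OpenAI_ChatGPT);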
@ -594,24 +593,33 @@ function _extract_anthropic_responses(response: Array<Dict>): Array<string> {
* text response(s) part of the response object.
*/
export function extract_responses(response: Array<string | Dict> | Dict, llm: LLM | string): Array<string> {
const llm_name = getEnumName(LLM, llm.toString());
if (llm_name?.startsWith('OpenAI')) {
if (llm_name.toLowerCase().includes('davinci'))
return _extract_openai_completion_responses(response);
else
return _extract_chatgpt_responses(response);
} else if (llm_name?.startsWith('Azure'))
return _extract_openai_responses(response);
else if (llm_name?.startsWith('PaLM2'))
return _extract_palm_responses(response);
else if (llm_name?.startsWith('Dalai'))
return [response.toString()];
else if (llm.toString().startsWith('claude'))
return _extract_anthropic_responses(response as Dict[]);
else
throw new Error(`No method defined to extract responses for LLM ${llm}.`)
let llm_provider: LLMProvider | undefined = getProvider(llm as LLM);
switch (llm_provider) {
case LLMProvider.OpenAI:
if (llm.toString().toLowerCase().includes('davinci'))
return _extract_openai_completion_responses(response);
else
return _extract_chatgpt_responses(response);
case LLMProvider.Azure_OpenAI:
return _extract_openai_responses(response);
case LLMProvider.Google:
return _extract_palm_responses(response);
case LLMProvider.Dalai:
return [response.toString()];
case LLMProvider.Anthropic:
return _extract_anthropic_responses(response as Dict[]);
default:
throw new Error(`No method defined to extract responses for LLM ${llm}.`)
}
}
/**
* Merge the 'responses' and 'raw_response' properties of two LLMResponseObjects,
* keeping all the other params from the second argument (llm, query, etc).
*
* If one object is undefined or null, returns the object that is defined, unaltered.
*/
export function merge_response_objs(resp_obj_A: LLMResponseObject | undefined, resp_obj_B: LLMResponseObject | undefined): LLMResponseObject | undefined {
if (!resp_obj_A && !resp_obj_B) {
console.warn('Warning: Merging two undefined response objects.')
@ -638,48 +646,3 @@ export function merge_response_objs(resp_obj_A: LLMResponseObject | undefined, r
metavars: resp_obj_B.metavars,
};
}
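Illustrative behavior, following the doc comment above (resp_A and resp_B stand for two previously obtained LLMResponseObjects):

merge_response_objs(undefined, resp_B);  // -> resp_B, unaltered
merge_response_objs(resp_A, undefined);  // -> resp_A, unaltered
merge_response_objs(resp_A, resp_B);     // -> combined 'responses' and 'raw_response'; llm, query, metavars taken from resp_B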
export function APP_IS_RUNNING_LOCALLY(): boolean {
try {
const location = window.location;
return location.hostname === "localhost" || location.hostname === "127.0.0.1" || location.hostname === "";
} catch (e) {
// ReferenceError --window or location does not exist.
// We must not be running client-side in a browser, in this case (e.g., we are running a Node.js server)
return false;
}
}
// def create_dir_if_not_exists(path: str) -> None:
// if not os.path.exists(path):
// os.makedirs(path)
// def is_valid_filepath(filepath: str) -> bool:
// try:
// with open(filepath, 'r', encoding='utf-8'):
// pass
// except IOError:
// try:
// # Create the file if it doesn't exist, and write an empty json string to it
// with open(filepath, 'w+', encoding='utf-8') as f:
// f.write("{}")
// pass
// except IOError:
// return False
// return True
// def is_valid_json(json_dict: dict) -> bool:
// if isinstance(json_dict, dict):
// try:
// json.dumps(json_dict)
// return True
// except Exception:
// pass
// return False
// def get_files_at_dir(path: str) -> list:
// f = []
// for (dirpath, dirnames, filenames) in os.walk(path):
// f = filenames
// break
// return f