Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-15 00:36:29 +00:00)

Commit 34884345d9: Refactored to LLMProvider to streamline model additions
Parent: d401216744
@@ -1,20 +1,9 @@
-// import json, os, asyncio, sys, traceback
-// from dataclasses import dataclass
-// from enum import Enum
-// from typing import Union, List
-// from statistics import mean, median, stdev
-// from flask import Flask, request, jsonify, render_template
-// from flask_cors import CORS
-// from chainforge.promptengine.query import PromptLLM, PromptLLMDummy, LLMResponseException
-// from chainforge.promptengine.template import PromptTemplate, PromptPermutationGenerator
-// from chainforge.promptengine.utils import LLM, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists, set_api_keys
-
 import { mean as __mean, std as __std, median as __median } from "mathjs";
 import markdownIt from "markdown-it";

 import { Dict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse } from "./typing";
-import { LLM } from "./models";
-import { APP_IS_RUNNING_LOCALLY, getEnumName, set_api_keys, FLASK_BASE_URL, call_flask_backend } from "./utils";
+import { LLM, getEnumName } from "./models";
+import { APP_IS_RUNNING_LOCALLY, set_api_keys, FLASK_BASE_URL, call_flask_backend } from "./utils";
 import StorageCache from "./cache";
 import { PromptPipeline } from "./query";
 import { PromptPermutationGenerator, PromptTemplate } from "./template";
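The notable change in this hunk, besides dropping the leftover commented-out Python imports, is that getEnumName is now imported from ./models rather than ./utils; later hunks in this diff add that helper beside the LLM enum and remove the copy that sat next to set_api_keys.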
@@ -2,44 +2,75 @@
  * A list of all model APIs natively supported by ChainForge.
  */
 export enum LLM {
   // OpenAI Chat
   OpenAI_ChatGPT = "gpt-3.5-turbo",
   OpenAI_ChatGPT_16k = "gpt-3.5-turbo-16k",
   OpenAI_ChatGPT_16k_0613 = "gpt-3.5-turbo-16k-0613",
   OpenAI_ChatGPT_0301 = "gpt-3.5-turbo-0301",
   OpenAI_ChatGPT_0613 = "gpt-3.5-turbo-0613",
   OpenAI_GPT4 = "gpt-4",
   OpenAI_GPT4_0314 = "gpt-4-0314",
   OpenAI_GPT4_0613 = "gpt-4-0613",
   OpenAI_GPT4_32k = "gpt-4-32k",
   OpenAI_GPT4_32k_0314 = "gpt-4-32k-0314",
   OpenAI_GPT4_32k_0613 = "gpt-4-32k-0613",

   // OpenAI Text Completions
   OpenAI_Davinci003 = "text-davinci-003",
   OpenAI_Davinci002 = "text-davinci-002",

   // Azure OpenAI Endpoints
   Azure_OpenAI = "azure-openai",

   // Dalai-served models (Alpaca and Llama)
   Dalai_Alpaca_7B = "alpaca.7B",
   Dalai_Alpaca_13B = "alpaca.13B",
   Dalai_Llama_7B = "llama.7B",
   Dalai_Llama_13B = "llama.13B",
   Dalai_Llama_30B = "llama.30B",
   Dalai_Llama_65B = "llama.65B",

   // Anthropic
   Claude_v1 = "claude-v1",
   Claude_v1_0 = "claude-v1.0",
   Claude_v1_2 = "claude-v1.2",
   Claude_v1_3 = "claude-v1.3",
   Claude_v1_instant = "claude-instant-v1",

   // Google models
   PaLM2_Text_Bison = "text-bison-001",  // it's really models/text-bison-001, but that's confusing
   PaLM2_Chat_Bison = "chat-bison-001",
 }
+
+/**
+ * A list of model providers
+ */
+export enum LLMProvider {
+  OpenAI = "openai",
+  Azure_OpenAI = "azure",
+  Dalai = "dalai",
+  Anthropic = "anthropic",
+  Google = "google",
+}
+
+/**
+ * Given an LLM, return what the model provider is.
+ * @param llm the specific large language model
+ * @returns an `LLMProvider` describing what provider hosts the model
+ */
+export function getProvider(llm: LLM): LLMProvider | undefined {
+  const llm_name = getEnumName(LLM, llm.toString());
+  if (llm_name?.startsWith('OpenAI'))
+    return LLMProvider.OpenAI;
+  else if (llm_name?.startsWith('Azure'))
+    return LLMProvider.Azure_OpenAI;
+  else if (llm_name?.startsWith('PaLM2'))
+    return LLMProvider.Google;
+  else if (llm_name?.startsWith('Dalai'))
+    return LLMProvider.Dalai;
+  else if (llm.toString().startsWith('claude'))
+    return LLMProvider.Anthropic;
+  return undefined;
+}
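A minimal usage sketch of the new provider mapping (illustrative, not part of this diff), using only the exports shown above:

    import { LLM, LLMProvider, getProvider } from "./models";

    // Each concrete model resolves to the provider that hosts it:
    console.assert(getProvider(LLM.OpenAI_GPT4) === LLMProvider.OpenAI);
    console.assert(getProvider(LLM.PaLM2_Chat_Bison) === LLMProvider.Google);
    console.assert(getProvider(LLM.Claude_v1) === LLMProvider.Anthropic);

    // A value with no matching prefix falls through to undefined:
    console.assert(getProvider("not-a-known-model" as LLM) === undefined);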
@@ -52,10 +83,25 @@ export enum LLM {
 export const RATE_LIMITS: { [key in LLM]?: [number, number] } = {
   [LLM.OpenAI_ChatGPT]: [30, 10],  // max 30 requests a batch; wait 10 seconds between
   [LLM.OpenAI_ChatGPT_0301]: [30, 10],
   [LLM.OpenAI_ChatGPT_0613]: [30, 10],
   [LLM.OpenAI_ChatGPT_16k]: [30, 10],
   [LLM.OpenAI_ChatGPT_16k_0613]: [30, 10],
   [LLM.OpenAI_GPT4]: [4, 15],  // max 4 requests a batch; wait 15 seconds between
   [LLM.OpenAI_GPT4_0314]: [4, 15],
   [LLM.OpenAI_GPT4_0613]: [4, 15],
   [LLM.OpenAI_GPT4_32k]: [4, 15],
   [LLM.OpenAI_GPT4_32k_0314]: [4, 15],
   [LLM.OpenAI_GPT4_32k_0613]: [4, 15],
   [LLM.Azure_OpenAI]: [30, 10],
   [LLM.PaLM2_Text_Bison]: [4, 10],  // max 30 requests per minute; so do 4 per batch, 10 seconds between (conservative)
   [LLM.PaLM2_Chat_Bison]: [4, 10],
 };

+/** Equivalent to a Python enum's .name property */
+export function getEnumName(enumObject: any, enumValue: any): string | undefined {
+  for (const key in enumObject)
+    if (enumObject[key] === enumValue)
+      return key;
+  return undefined;
+}
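Each RATE_LIMITS entry pairs a maximum batch size with a wait time in seconds. A small sketch of reading these values (illustrative, not part of this diff; the [1, 0] fallback is an assumption, the diff defines no default):

    import { LLM, RATE_LIMITS, getEnumName } from "./models";

    // Each entry is [max requests per batch, seconds to wait between batches].
    const [maxPerBatch, waitSecs] = RATE_LIMITS[LLM.OpenAI_GPT4] ?? [1, 0];
    console.log(`send ${maxPerBatch} at a time, wait ${waitSecs}s between batches`);

    // getEnumName mirrors a Python enum's .name: it maps a value back to its key.
    console.assert(getEnumName(LLM, "gpt-4") === "OpenAI_GPT4");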
@@ -3,7 +3,7 @@
 // from string import Template

 // from chainforge.promptengine.models import LLM
-import { LLM } from './models';
+import { LLM, LLMProvider, getProvider } from './models';
 import { Dict, StringDict, LLMAPICall, LLMResponseObject } from './typing';
 import { env as process_env } from 'process';
 import { StringTemplate } from './template';
@@ -30,6 +30,28 @@ export async function call_flask_backend(route: string, params: Dict | string):
   });
 }

+// We only calculate whether the app is running locally once upon load, and store it here:
+let _APP_IS_RUNNING_LOCALLY: boolean | undefined = undefined;
+
+/**
+ * Tries to determine if the ChainForge front-end is running on user's local machine (and hence has access to Flask backend).
+ * @returns `true` if we think the app is running locally (on localhost or equivalent); `false` if not.
+ */
+export function APP_IS_RUNNING_LOCALLY(): boolean {
+  if (_APP_IS_RUNNING_LOCALLY === undefined) {
+    // Calculate whether we're running the app locally or not, and save the result
+    try {
+      const location = window.location;
+      _APP_IS_RUNNING_LOCALLY = location.hostname === "localhost" || location.hostname === "127.0.0.1" || location.hostname === "";
+    } catch (e) {
+      // ReferenceError -- window or location does not exist.
+      // We must not be running client-side in a browser, in this case (e.g., we are running a Node.js server)
+      _APP_IS_RUNNING_LOCALLY = false;
+    }
+  }
+  return _APP_IS_RUNNING_LOCALLY;
+}
+
 /**
  * Equivalent to a 'fetch' call, but routes it to the backend Flask server in
  * case we are running a local server and prefer to not deal with CORS issues making API calls client-side.
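The locality check is now computed once and cached in _APP_IS_RUNNING_LOCALLY. A caller sketch (illustrative, not part of this diff; the route name and payload are assumptions):

    import { APP_IS_RUNNING_LOCALLY, call_flask_backend } from "./utils";

    async function exampleLocalCall() {
      // The hostname check runs once; later calls return the cached boolean.
      if (APP_IS_RUNNING_LOCALLY()) {
        // When running locally, route requests through the Flask backend
        // instead of making cross-origin API calls from the browser.
        await call_flask_backend("exampleRoute", { prompt: "Hello" });
      }
    }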
@@ -93,29 +115,6 @@ export function set_api_keys(api_keys: StringDict): void {
   // Soft fail for non-present keys
 }

-/** Equivalent to a Python enum's .name property */
-export function getEnumName(enumObject: any, enumValue: any): string | undefined {
-  for (const key in enumObject) {
-    if (enumObject[key] === enumValue) {
-      return key;
-    }
-  }
-  return undefined;
-}
-
-// async def make_sync_call_async(sync_method, *args, **params):
-//     """
-//         Makes a blocking synchronous call asynchronous, so that it can be awaited.
-//         NOTE: This is necessary for LLM APIs that do not yet support async (e.g. Google PaLM).
-//     """
-//     loop = asyncio.get_running_loop()
-//     method = sync_method
-//     if len(params) > 0:
-//         def partial_sync_meth(*a):
-//             return sync_method(*a, **params)
-//         method = partial_sync_meth
-//     return await loop.run_in_executor(None, method, *args)

 /**
  * Calls OpenAI models via OpenAI's API.
  @returns raw query and response JSON dicts.
@@ -311,7 +310,6 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
   while (responses.length < n) {
     const resp = await route_fetch(url, 'POST', headers, query);
     responses.push(resp);
     console.log(`${model} response ${responses.length} of ${n}:\n${resp}`);
   }

   return [query, responses];
@@ -506,20 +504,21 @@ export async function call_dalai(prompt: string, model: LLM, n: number = 1, temp
 export async function call_llm(llm: LLM, prompt: string, n: number, temperature: number, params?: Dict): Promise<[Dict, Dict]> {
   // Get the correct API call for the given LLM:
   let call_api: LLMAPICall | undefined;
-  const llm_name = getEnumName(LLM, llm.toString());
-  if (llm_name?.startsWith('OpenAI'))
-    call_api = call_chatgpt;
-  else if (llm_name?.startsWith('Azure'))
-    call_api = call_azure_openai;
-  else if (llm_name?.startsWith('PaLM2'))
-    call_api = call_google_palm;
-  else if (llm_name?.startsWith('Dalai'))
-    call_api = call_dalai;
-  else if (llm.toString().startsWith('claude'))
-    call_api = call_anthropic;
-
-  if (!call_api)
+  let llm_provider: LLMProvider = getProvider(llm);
+
+  if (llm_provider === undefined)
     throw new Error(`Language model ${llm} is not supported.`);
+
+  if (llm_provider === LLMProvider.OpenAI)
+    call_api = call_chatgpt;
+  else if (llm_provider === LLMProvider.Azure_OpenAI)
+    call_api = call_azure_openai;
+  else if (llm_provider === LLMProvider.Google)
+    call_api = call_google_palm;
+  else if (llm_provider === LLMProvider.Dalai)
+    call_api = call_dalai;
+  else if (llm_provider === LLMProvider.Anthropic)
+    call_api = call_anthropic;

   return call_api(prompt, llm, n, temperature, params);
 }
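With dispatch keyed on LLMProvider, adding another model from an existing provider only needs a new LLM enum entry; the call path stays untouched. A caller sketch (illustrative, not part of this diff; it assumes call_llm is exported from './utils' like the other helpers in this file):

    import { LLM } from "./models";
    import { call_llm } from "./utils";

    async function queryOnce() {
      // Returns the raw [query, responses] pair from the provider-specific API call.
      const [query, responses] = await call_llm(LLM.OpenAI_ChatGPT, "Say hello.", 1, 0.7);
      console.log(query, responses);
    }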
@@ -594,24 +593,33 @@ function _extract_anthropic_responses(response: Array<Dict>): Array<string> {
  * text response(s) part of the response object.
  */
 export function extract_responses(response: Array<string | Dict> | Dict, llm: LLM | string): Array<string> {
-  const llm_name = getEnumName(LLM, llm.toString());
-  if (llm_name?.startsWith('OpenAI')) {
-    if (llm_name.toLowerCase().includes('davinci'))
-      return _extract_openai_completion_responses(response);
-    else
-      return _extract_chatgpt_responses(response);
-  } else if (llm_name?.startsWith('Azure'))
-    return _extract_openai_responses(response);
-  else if (llm_name?.startsWith('PaLM2'))
-    return _extract_palm_responses(response);
-  else if (llm_name?.startsWith('Dalai'))
-    return [response.toString()];
-  else if (llm.toString().startsWith('claude'))
-    return _extract_anthropic_responses(response as Dict[]);
-  else
-    throw new Error(`No method defined to extract responses for LLM ${llm}.`)
+  let llm_provider: LLMProvider = getProvider(llm as LLM);
+
+  switch (llm_provider) {
+    case LLMProvider.OpenAI:
+      if (llm.toString().toLowerCase().includes('davinci'))
+        return _extract_openai_completion_responses(response);
+      else
+        return _extract_chatgpt_responses(response);
+    case LLMProvider.Azure_OpenAI:
+      return _extract_openai_responses(response);
+    case LLMProvider.Google:
+      return _extract_palm_responses(response);
+    case LLMProvider.Dalai:
+      return [response.toString()];
+    case LLMProvider.Anthropic:
+      return _extract_anthropic_responses(response as Dict[]);
+    default:
+      throw new Error(`No method defined to extract responses for LLM ${llm}.`)
+  }
 }
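The same provider lookup now drives response extraction. A small check grounded in the Dalai branch above, which stringifies the raw response (illustrative, not part of this diff; it assumes extract_responses is exported from './utils'):

    import { LLM } from "./models";
    import { extract_responses } from "./utils";

    // For Dalai-served models, the switch returns the stringified raw response:
    const texts = extract_responses(["some generated text"], LLM.Dalai_Alpaca_7B);
    console.assert(texts.length === 1 && texts[0] === "some generated text");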

 /**
  * Marge the 'responses' and 'raw_response' properties of two LLMResponseObjects,
  * keeping all the other params from the second argument (llm, query, etc).
  *
  * If one object is undefined or null, returns the object that is defined, unaltered.
  */
 export function merge_response_objs(resp_obj_A: LLMResponseObject | undefined, resp_obj_B: LLMResponseObject | undefined): LLMResponseObject | undefined {
   if (!resp_obj_A && !resp_obj_B) {
     console.warn('Warning: Merging two undefined response objects.')
@@ -638,48 +646,3 @@ export function merge_response_objs(resp_obj_A: LLMResponseObject | undefined, r
     metavars: resp_obj_B.metavars,
   };
 }
-
-export function APP_IS_RUNNING_LOCALLY(): boolean {
-  try {
-    const location = window.location;
-    return location.hostname === "localhost" || location.hostname === "127.0.0.1" || location.hostname === "";
-  } catch (e) {
-    // ReferenceError -- window or location does not exist.
-    // We must not be running client-side in a browser, in this case (e.g., we are running a Node.js server)
-    return false;
-  }
-}
-
-// def create_dir_if_not_exists(path: str) -> None:
-//     if not os.path.exists(path):
-//         os.makedirs(path)
-
-// def is_valid_filepath(filepath: str) -> bool:
-//     try:
-//         with open(filepath, 'r', encoding='utf-8'):
-//             pass
-//     except IOError:
-//         try:
-//             # Create the file if it doesn't exist, and write an empty json string to it
-//             with open(filepath, 'w+', encoding='utf-8') as f:
-//                 f.write("{}")
-//                 pass
-//         except IOError:
-//             return False
-//     return True
-
-// def is_valid_json(json_dict: dict) -> bool:
-//     if isinstance(json_dict, dict):
-//         try:
-//             json.dumps(json_dict)
-//             return True
-//         except Exception:
-//             pass
-//     return False
-
-// def get_files_at_dir(path: str) -> list:
-//     f = []
-//     for (dirpath, dirnames, filenames) in os.walk(path):
-//         f = filenames
-//         break
-//     return f