From 34884345d94a7fd0497c175b7f09fabcf68f9207 Mon Sep 17 00:00:00 2001
From: Ian Arawjo
Date: Thu, 29 Jun 2023 17:29:10 -0400
Subject: [PATCH] Refactored to LLMProvider to streamline model additions

---
 .../react-server/src/backend/backend.ts       |  15 +-
 chainforge/react-server/src/backend/models.ts | 114 +++++++++----
 chainforge/react-server/src/backend/utils.ts  | 161 +++++++-----------
 3 files changed, 144 insertions(+), 146 deletions(-)

diff --git a/chainforge/react-server/src/backend/backend.ts b/chainforge/react-server/src/backend/backend.ts
index 0c37134..8ac7669 100644
--- a/chainforge/react-server/src/backend/backend.ts
+++ b/chainforge/react-server/src/backend/backend.ts
@@ -1,20 +1,9 @@
-// import json, os, asyncio, sys, traceback
-// from dataclasses import dataclass
-// from enum import Enum
-// from typing import Union, List
-// from statistics import mean, median, stdev
-// from flask import Flask, request, jsonify, render_template
-// from flask_cors import CORS
-// from chainforge.promptengine.query import PromptLLM, PromptLLMDummy, LLMResponseException
-// from chainforge.promptengine.template import PromptTemplate, PromptPermutationGenerator
-// from chainforge.promptengine.utils import LLM, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists, set_api_keys
-
 import { mean as __mean, std as __std, median as __median } from "mathjs";
 import markdownIt from "markdown-it";
 
 import { Dict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse } from "./typing";
-import { LLM } from "./models";
-import { APP_IS_RUNNING_LOCALLY, getEnumName, set_api_keys, FLASK_BASE_URL, call_flask_backend } from "./utils";
+import { LLM, getEnumName } from "./models";
+import { APP_IS_RUNNING_LOCALLY, set_api_keys, FLASK_BASE_URL, call_flask_backend } from "./utils";
 import StorageCache from "./cache";
 import { PromptPipeline } from "./query";
 import { PromptPermutationGenerator, PromptTemplate } from "./template";
diff --git a/chainforge/react-server/src/backend/models.ts b/chainforge/react-server/src/backend/models.ts
index 2d847b4..42f9e14 100644
--- a/chainforge/react-server/src/backend/models.ts
+++ b/chainforge/react-server/src/backend/models.ts
@@ -2,44 +2,75 @@
  * A list of all model APIs natively supported by ChainForge.
  */
 export enum LLM {
-  // OpenAI Chat
-  OpenAI_ChatGPT = "gpt-3.5-turbo",
-  OpenAI_ChatGPT_16k = "gpt-3.5-turbo-16k",
-  OpenAI_ChatGPT_16k_0613 = "gpt-3.5-turbo-16k-0613",
-  OpenAI_ChatGPT_0301 = "gpt-3.5-turbo-0301",
-  OpenAI_ChatGPT_0613 = "gpt-3.5-turbo-0613",
-  OpenAI_GPT4 = "gpt-4",
-  OpenAI_GPT4_0314 = "gpt-4-0314",
-  OpenAI_GPT4_0613 = "gpt-4-0613",
-  OpenAI_GPT4_32k = "gpt-4-32k",
-  OpenAI_GPT4_32k_0314 = "gpt-4-32k-0314",
-  OpenAI_GPT4_32k_0613 = "gpt-4-32k-0613",
+  // OpenAI Chat
+  OpenAI_ChatGPT = "gpt-3.5-turbo",
+  OpenAI_ChatGPT_16k = "gpt-3.5-turbo-16k",
+  OpenAI_ChatGPT_16k_0613 = "gpt-3.5-turbo-16k-0613",
+  OpenAI_ChatGPT_0301 = "gpt-3.5-turbo-0301",
+  OpenAI_ChatGPT_0613 = "gpt-3.5-turbo-0613",
+  OpenAI_GPT4 = "gpt-4",
+  OpenAI_GPT4_0314 = "gpt-4-0314",
+  OpenAI_GPT4_0613 = "gpt-4-0613",
+  OpenAI_GPT4_32k = "gpt-4-32k",
+  OpenAI_GPT4_32k_0314 = "gpt-4-32k-0314",
+  OpenAI_GPT4_32k_0613 = "gpt-4-32k-0613",
 
-  // OpenAI Text Completions
-  OpenAI_Davinci003 = "text-davinci-003",
-  OpenAI_Davinci002 = "text-davinci-002",
+  // OpenAI Text Completions
+  OpenAI_Davinci003 = "text-davinci-003",
+  OpenAI_Davinci002 = "text-davinci-002",
 
-  // Azure OpenAI Endpoints
-  Azure_OpenAI = "azure-openai",
+  // Azure OpenAI Endpoints
+  Azure_OpenAI = "azure-openai",
 
-  // Dalai-served models (Alpaca and Llama)
-  Dalai_Alpaca_7B = "alpaca.7B",
-  Dalai_Alpaca_13B = "alpaca.13B",
-  Dalai_Llama_7B = "llama.7B",
-  Dalai_Llama_13B = "llama.13B",
-  Dalai_Llama_30B = "llama.30B",
-  Dalai_Llama_65B = "llama.65B",
+  // Dalai-served models (Alpaca and Llama)
+  Dalai_Alpaca_7B = "alpaca.7B",
+  Dalai_Alpaca_13B = "alpaca.13B",
+  Dalai_Llama_7B = "llama.7B",
+  Dalai_Llama_13B = "llama.13B",
+  Dalai_Llama_30B = "llama.30B",
+  Dalai_Llama_65B = "llama.65B",
 
-  // Anthropic
-  Claude_v1 = "claude-v1",
-  Claude_v1_0 = "claude-v1.0",
-  Claude_v1_2 = "claude-v1.2",
-  Claude_v1_3 = "claude-v1.3",
-  Claude_v1_instant = "claude-instant-v1",
+  // Anthropic
+  Claude_v1 = "claude-v1",
+  Claude_v1_0 = "claude-v1.0",
+  Claude_v1_2 = "claude-v1.2",
+  Claude_v1_3 = "claude-v1.3",
+  Claude_v1_instant = "claude-instant-v1",
 
-  // Google models
-  PaLM2_Text_Bison = "text-bison-001", // it's really models/text-bison-001, but that's confusing
-  PaLM2_Chat_Bison = "chat-bison-001",
+  // Google models
+  PaLM2_Text_Bison = "text-bison-001", // it's really models/text-bison-001, but that's confusing
+  PaLM2_Chat_Bison = "chat-bison-001",
+}
+
+/**
+ * A list of model providers
+ */
+export enum LLMProvider {
+  OpenAI = "openai",
+  Azure_OpenAI = "azure",
+  Dalai = "dalai",
+  Anthropic = "anthropic",
+  Google = "google",
+}
+
+/**
+ * Given an LLM, return what the model provider is.
+ * @param llm the specific large language model
+ * @returns an `LLMProvider` describing what provider hosts the model
+ */
+export function getProvider(llm: LLM): LLMProvider | undefined {
+  const llm_name = getEnumName(LLM, llm.toString());
+  if (llm_name?.startsWith('OpenAI'))
+    return LLMProvider.OpenAI;
+  else if (llm_name?.startsWith('Azure'))
+    return LLMProvider.Azure_OpenAI;
+  else if (llm_name?.startsWith('PaLM2'))
+    return LLMProvider.Google;
+  else if (llm_name?.startsWith('Dalai'))
+    return LLMProvider.Dalai;
+  else if (llm.toString().startsWith('claude'))
+    return LLMProvider.Anthropic;
+  return undefined;
 }
@@ -52,10 +83,25 @@ export enum LLM {
 export const RATE_LIMITS: { [key in LLM]?: [number, number] } = {
   [LLM.OpenAI_ChatGPT]: [30, 10], // max 30 requests a batch; wait 10 seconds between
   [LLM.OpenAI_ChatGPT_0301]: [30, 10],
+  [LLM.OpenAI_ChatGPT_0613]: [30, 10],
+  [LLM.OpenAI_ChatGPT_16k]: [30, 10],
+  [LLM.OpenAI_ChatGPT_16k_0613]: [30, 10],
   [LLM.OpenAI_GPT4]: [4, 15], // max 4 requests a batch; wait 15 seconds between
   [LLM.OpenAI_GPT4_0314]: [4, 15],
+  [LLM.OpenAI_GPT4_0613]: [4, 15],
   [LLM.OpenAI_GPT4_32k]: [4, 15],
   [LLM.OpenAI_GPT4_32k_0314]: [4, 15],
+  [LLM.OpenAI_GPT4_32k_0613]: [4, 15],
+  [LLM.Azure_OpenAI]: [30, 10],
   [LLM.PaLM2_Text_Bison]: [4, 10], // max 30 requests per minute; so do 4 per batch, 10 seconds between (conservative)
   [LLM.PaLM2_Chat_Bison]: [4, 10],
-};
\ No newline at end of file
+};
+
+
+/** Equivalent to a Python enum's .name property */
+export function getEnumName(enumObject: any, enumValue: any): string | undefined {
+  for (const key in enumObject)
+    if (enumObject[key] === enumValue)
+      return key;
+  return undefined;
+}
\ No newline at end of file
diff --git a/chainforge/react-server/src/backend/utils.ts b/chainforge/react-server/src/backend/utils.ts
index 2ef43d9..d1996e3 100644
--- a/chainforge/react-server/src/backend/utils.ts
+++ b/chainforge/react-server/src/backend/utils.ts
@@ -3,7 +3,7 @@
 // from string import Template
 // from chainforge.promptengine.models import LLM
 
-import { LLM } from './models';
+import { LLM, LLMProvider, getProvider } from './models';
 import { Dict, StringDict, LLMAPICall, LLMResponseObject } from './typing';
 import { env as process_env } from 'process';
 import { StringTemplate } from './template';
@@ -30,6 +30,28 @@ export async function call_flask_backend(route: string, params: Dict | string):
   });
 }
 
+// We only calculate whether the app is running locally once upon load, and store it here:
+let _APP_IS_RUNNING_LOCALLY: boolean | undefined = undefined;
+
+/**
+ * Tries to determine if the ChainForge front-end is running on user's local machine (and hence has access to Flask backend).
+ * @returns `true` if we think the app is running locally (on localhost or equivalent); `false` if not.
+ */
+export function APP_IS_RUNNING_LOCALLY(): boolean {
+  if (_APP_IS_RUNNING_LOCALLY === undefined) {
+    // Calculate whether we're running the app locally or not, and save the result
+    try {
+      const location = window.location;
+      _APP_IS_RUNNING_LOCALLY = location.hostname === "localhost" || location.hostname === "127.0.0.1" || location.hostname === "";
+    } catch (e) {
+      // ReferenceError --window or location does not exist.
+      // We must not be running client-side in a browser, in this case (e.g., we are running a Node.js server)
+      _APP_IS_RUNNING_LOCALLY = false;
+    }
+  }
+  return _APP_IS_RUNNING_LOCALLY;
+}
+
 /**
  * Equivalent to a 'fetch' call, but routes it to the backend Flask server in
  * case we are running a local server and prefer to not deal with CORS issues making API calls client-side.
@@ -93,29 +115,6 @@ export function set_api_keys(api_keys: StringDict): void {
   // Soft fail for non-present keys
 }
 
-/** Equivalent to a Python enum's .name property */
-export function getEnumName(enumObject: any, enumValue: any): string | undefined {
-  for (const key in enumObject) {
-    if (enumObject[key] === enumValue) {
-      return key;
-    }
-  }
-  return undefined;
-}
-
-// async def make_sync_call_async(sync_method, *args, **params):
-//     """
-//         Makes a blocking synchronous call asynchronous, so that it can be awaited.
-//         NOTE: This is necessary for LLM APIs that do not yet support async (e.g. Google PaLM).
-//     """
-//     loop = asyncio.get_running_loop()
-//     method = sync_method
-//     if len(params) > 0:
-//         def partial_sync_meth(*a):
-//             return sync_method(*a, **params)
-//         method = partial_sync_meth
-//     return await loop.run_in_executor(None, method, *args)
-
 /**
  * Calls OpenAI models via OpenAI's API.
   @returns raw query and response JSON dicts.
@@ -311,7 +310,6 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
   while (responses.length < n) {
     const resp = await route_fetch(url, 'POST', headers, query);
     responses.push(resp);
-    console.log(`${model} response ${responses.length} of ${n}:\n${resp}`);
   }
 
   return [query, responses];
@@ -506,20 +504,21 @@ export async function call_dalai(prompt: string, model: LLM, n: number = 1, temp
 export async function call_llm(llm: LLM, prompt: string, n: number, temperature: number, params?: Dict): Promise<[Dict, Dict]> {
   // Get the correct API call for the given LLM:
   let call_api: LLMAPICall | undefined;
-  const llm_name = getEnumName(LLM, llm.toString());
-  if (llm_name?.startsWith('OpenAI'))
-    call_api = call_chatgpt;
-  else if (llm_name?.startsWith('Azure'))
-    call_api = call_azure_openai;
-  else if (llm_name?.startsWith('PaLM2'))
-    call_api = call_google_palm;
-  else if (llm_name?.startsWith('Dalai'))
-    call_api = call_dalai;
-  else if (llm.toString().startsWith('claude'))
-    call_api = call_anthropic;
-
-  if (!call_api)
+  let llm_provider: LLMProvider = getProvider(llm);
+
+  if (llm_provider === undefined)
     throw new Error(`Language model ${llm} is not supported.`);
+
+  if (llm_provider === LLMProvider.OpenAI)
+    call_api = call_chatgpt;
+  else if (llm_provider === LLMProvider.Azure_OpenAI)
+    call_api = call_azure_openai;
+  else if (llm_provider === LLMProvider.Google)
+    call_api = call_google_palm;
+  else if (llm_provider === LLMProvider.Dalai)
+    call_api = call_dalai;
+  else if (llm_provider === LLMProvider.Anthropic)
+    call_api = call_anthropic;
 
   return call_api(prompt, llm, n, temperature, params);
 }
@@ -594,24 +593,33 @@ function _extract_anthropic_responses(response: Array): Array {
  * text response(s) part of the response object.
  */
 export function extract_responses(response: Array | Dict, llm: LLM | string): Array {
-  const llm_name = getEnumName(LLM, llm.toString());
-  if (llm_name?.startsWith('OpenAI')) {
-    if (llm_name.toLowerCase().includes('davinci'))
-      return _extract_openai_completion_responses(response);
-    else
-      return _extract_chatgpt_responses(response);
-  } else if (llm_name?.startsWith('Azure'))
-    return _extract_openai_responses(response);
-  else if (llm_name?.startsWith('PaLM2'))
-    return _extract_palm_responses(response);
-  else if (llm_name?.startsWith('Dalai'))
-    return [response.toString()];
-  else if (llm.toString().startsWith('claude'))
-    return _extract_anthropic_responses(response as Dict[]);
-  else
-    throw new Error(`No method defined to extract responses for LLM ${llm}.`)
+  let llm_provider: LLMProvider = getProvider(llm as LLM);
+
+  switch (llm_provider) {
+    case LLMProvider.OpenAI:
+      if (llm.toString().toLowerCase().includes('davinci'))
+        return _extract_openai_completion_responses(response);
+      else
+        return _extract_chatgpt_responses(response);
+    case LLMProvider.Azure_OpenAI:
+      return _extract_openai_responses(response);
+    case LLMProvider.Google:
+      return _extract_palm_responses(response);
+    case LLMProvider.Dalai:
+      return [response.toString()];
+    case LLMProvider.Anthropic:
+      return _extract_anthropic_responses(response as Dict[]);
+    default:
+      throw new Error(`No method defined to extract responses for LLM ${llm}.`)
+  }
 }
 
+/**
+ * Merge the 'responses' and 'raw_response' properties of two LLMResponseObjects,
+ * keeping all the other params from the second argument (llm, query, etc).
+ *
+ * If one object is undefined or null, returns the object that is defined, unaltered.
+ */
 export function merge_response_objs(resp_obj_A: LLMResponseObject | undefined, resp_obj_B: LLMResponseObject | undefined): LLMResponseObject | undefined {
   if (!resp_obj_A && !resp_obj_B) {
     console.warn('Warning: Merging two undefined response objects.')
     return undefined;
@@ -638,48 +646,3 @@ export function merge_response_objs(resp_obj_A: LLMResponseObject | undefined, r
     metavars: resp_obj_B.metavars,
   };
 }
-
-export function APP_IS_RUNNING_LOCALLY(): boolean {
-  try {
-    const location = window.location;
-    return location.hostname === "localhost" || location.hostname === "127.0.0.1" || location.hostname === "";
-  } catch (e) {
-    // ReferenceError --window or location does not exist.
-    // We must not be running client-side in a browser, in this case (e.g., we are running a Node.js server)
-    return false;
-  }
-}
-
-// def create_dir_if_not_exists(path: str) -> None:
-//     if not os.path.exists(path):
-//         os.makedirs(path)
-
-// def is_valid_filepath(filepath: str) -> bool:
-//     try:
-//         with open(filepath, 'r', encoding='utf-8'):
-//             pass
-//     except IOError:
-//         try:
-//             # Create the file if it doesn't exist, and write an empty json string to it
-//             with open(filepath, 'w+', encoding='utf-8') as f:
-//                 f.write("{}")
-//                 pass
-//         except IOError:
-//             return False
-//     return True
-
-// def is_valid_json(json_dict: dict) -> bool:
-//     if isinstance(json_dict, dict):
-//         try:
-//             json.dumps(json_dict)
-//             return True
-//         except Exception:
-//             pass
-//     return False
-
-// def get_files_at_dir(path: str) -> list:
-//     f = []
-//     for (dirpath, dirnames, filenames) in os.walk(path):
-//         f = filenames
-//         break
-//     return f
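
Reviewer note (illustrative, not part of the patch): the new LLMProvider enum and getProvider() helper in models.ts let call sites branch on which provider hosts a model instead of string-matching LLM enum names, which is how call_llm and extract_responses in utils.ts now dispatch. A minimal sketch of a call site using the same pattern follows; the function name describeProvider is hypothetical and only for illustration:

import { LLM, LLMProvider, getProvider } from "./models";

// Map a model to a human-readable description of who serves it.
// getProvider() returns undefined for models it does not recognize.
function describeProvider(llm: LLM): string {
  const provider = getProvider(llm);
  switch (provider) {
    case LLMProvider.OpenAI:
      return `${llm} is served by the OpenAI API`;
    case LLMProvider.Anthropic:
      return `${llm} is served by the Anthropic API`;
    case undefined:
      return `${llm} has no registered provider`;
    default:
      return `${llm} is served by provider "${provider}"`;
  }
}

// Example: describeProvider(LLM.Claude_v1) === 'claude-v1 is served by the Anthropic API'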