Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-15 00:36:29 +00:00)

Commit a66a5c73f2 (parent dca20c2d0e)
wip started utils.js conversion. Tested that OpenAI API call works
chainforge/react-server/package-lock.json (generated, 31 lines changed)
@@ -43,6 +43,7 @@
         "emoji-picker-react": "^4.4.9",
         "mantine-contextmenu": "^1.2.15",
         "mantine-react-table": "^1.0.0-beta.8",
+        "openai": "^3.3.0",
         "papaparse": "^5.4.1",
         "plotly.js": "^2.21.0",
         "react": "^18.2.0",
@@ -6535,6 +6536,14 @@
         "node": ">=4"
       }
     },
+    "node_modules/axios": {
+      "version": "0.26.1",
+      "resolved": "https://registry.npmjs.org/axios/-/axios-0.26.1.tgz",
+      "integrity": "sha512-fPwcX4EvnSHuInCMItEhAGnaSEXRBjtzh9fOtsE6E1G6p7vl7edEeZe11QHf18+6+9gR5PbKV/sGKNaD8YaMeA==",
+      "dependencies": {
+        "follow-redirects": "^1.14.8"
+      }
+    },
     "node_modules/axobject-query": {
       "version": "3.1.1",
       "resolved": "https://registry.npmjs.org/axobject-query/-/axobject-query-3.1.1.tgz",
@@ -15582,6 +15591,28 @@
         "url": "https://github.com/sponsors/sindresorhus"
       }
     },
+    "node_modules/openai": {
+      "version": "3.3.0",
+      "resolved": "https://registry.npmjs.org/openai/-/openai-3.3.0.tgz",
+      "integrity": "sha512-uqxI/Au+aPRnsaQRe8CojU0eCR7I0mBiKjD3sNMzY6DaC1ZVrc85u98mtJW6voDug8fgGN+DIZmTDxTthxb7dQ==",
+      "dependencies": {
+        "axios": "^0.26.0",
+        "form-data": "^4.0.0"
+      }
+    },
+    "node_modules/openai/node_modules/form-data": {
+      "version": "4.0.0",
+      "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.0.tgz",
+      "integrity": "sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==",
+      "dependencies": {
+        "asynckit": "^0.4.0",
+        "combined-stream": "^1.0.8",
+        "mime-types": "^2.1.12"
+      },
+      "engines": {
+        "node": ">= 6"
+      }
+    },
     "node_modules/optionator": {
       "version": "0.9.1",
       "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.1.tgz",

chainforge/react-server/package.json

@@ -38,6 +38,7 @@
     "emoji-picker-react": "^4.4.9",
     "mantine-contextmenu": "^1.2.15",
     "mantine-react-table": "^1.0.0-beta.8",
+    "openai": "^3.3.0",
     "papaparse": "^5.4.1",
     "plotly.js": "^2.21.0",
     "react": "^18.2.0",
@@ -7,18 +7,8 @@
 
 import { PromptTemplate, PromptPermutationGenerator } from "./template";
 import { LLM, RATE_LIMITS } from './models';
+import { Dict, LLMResponseError, LLMResponseObject, LLMAPICall } from "./typing";
 
-// Custom types
-interface Dict { [key: string]: any };
-interface LLMResponseType {
-  prompt: string;
-  query: Dict;
-  responses: string[];
-  raw_response: Dict;
-  llm: LLM;
-  info: Dict;
-  metavars: Dict;
-}
 interface _IntermediateLLMResponseType {
   prompt: PromptTemplate | string,
   query?: Dict,
@@ -62,14 +52,6 @@ class StorageCache {
   }
 }
-
-/** Raised when there is an error generating a single response from an LLM */
-class LLMResponseError extends Error {
-  constructor(message: string) {
-    super(message);
-    this.name = "LLMResponseError";
-  }
-}
 
 // From trincot @ SO: https://stackoverflow.com/a/76477994/1911342
 // Functions equivalently to Python's asyncio 'as_completed' method,
 // performing a Promise.race() but where all promises are yielded as they complete
@@ -137,13 +119,13 @@ class PromptPipeline {
     llm: LLM,
     n: number = 1,
     temperature: number = 1.0,
-    llm_params: Dict): Generator<LLMResponseType | LLMResponseError, boolean, undefined> {
+    llm_params: Dict): Generator<LLMResponseObject | LLMResponseError, boolean, undefined> {
     // Double-check that properties is the correct type (JSON dict):
     // Load any cache'd responses
     let responses = this._load_cached_responses();
 
     // Query LLM with each prompt, yield + cache the responses
-    let tasks: Array<Promise<_IntermediateLLMResponseType>> = [];
+    let tasks: Array<Promise<_IntermediateLLMResponseObject>> = [];
     const rate_limit = RATE_LIMITS[llm] || [1, 0];
     let [max_req, wait_secs] = rate_limit ? rate_limit : [1, 0];
     let num_queries_sent = -1;
@@ -206,7 +188,7 @@ class PromptPipeline {
     }
 
     // Create a response obj to represent the response
-    let resp_obj: LLMResponseType = {
+    let resp_obj: LLMResponseObject = {
       "prompt": prompt.toString(),
       "query": query,
       "responses": extract_responses(response, llm),
@@ -245,7 +227,7 @@ class PromptPipeline {
     let metavars = prompt.metavars;
 
     // Create a response obj to represent the response
-    let resp_obj: LLMResponseType = {
+    let resp_obj: LLMResponseObject = {
       "prompt": prompt.toString(),
       "query": query,
       "responses": extract_responses(response, llm),
@@ -318,7 +300,7 @@ class PromptPipeline {
     }
 
     // Get the correct API call for the given LLM:
-    let call_llm = null;
+    let call_llm: LLMAPICall | undefined;
     const llm_name = getEnumName(LLM, llm.toString());
     if (llm_name?.startsWith('OpenAI'))
       call_llm = call_chatgpt;

chainforge/react-server/src/backend/typing.ts (new file, 32 lines)
@@ -0,0 +1,32 @@
+import { LLM } from "./models";
+
+/** Raised when there is an error generating a single response from an LLM */
+export class LLMResponseError extends Error {
+  constructor(message: string) {
+    super(message);
+    this.name = "LLMResponseError";
+  }
+}
+
+export interface Dict { [key: string]: any };
+export interface StringDict { [key: string]: string };
+
+/** A JSON object describing an LLM response for the same prompt, with n responses (n>=1) */
+export interface LLMResponseObject {
+  prompt: string;
+  query: Dict;
+  responses: string[];
+  raw_response: Dict;
+  llm: LLM;
+  info: Dict;
+  metavars: Dict;
+}
+
+/** A standard function interface for calling an LLM */
+export interface LLMAPICall {
+  (prompt: string,
+   model: LLM,
+   n: number,
+   temperature: number,
+   params?: Dict): [Dict, Dict]
+}
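
As an aside, a minimal sketch of what a provider call conforming to the LLMAPICall interface above could look like (the mock function and canned response shape below are hypothetical and not part of this commit):

import { LLM } from "./models";
import { Dict, LLMAPICall } from "./typing";

// Hypothetical mock provider conforming to LLMAPICall: it builds the query dict
// and returns a canned, OpenAI-style response object with n identical choices.
const call_mock_llm: LLMAPICall = (prompt: string, model: LLM, n: number, temperature: number, params?: Dict): [Dict, Dict] => {
  const query: Dict = { model: model.toString(), n: n, temperature: temperature, prompt: prompt, ...params };
  const response: Dict = {
    choices: Array.from({ length: n }, () => ({ message: { role: "assistant", content: `(mock reply to: ${prompt})` } })),
  };
  return [query, response];
};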

chainforge/react-server/src/backend/utils.ts (new file, 536 lines)
@@ -0,0 +1,536 @@
+// from typing import Dict, Tuple, List, Union, Optional
+// import json, os, time, asyncio
+// from string import Template
+
+// from chainforge.promptengine.models import LLM
+import { LLM } from './models';
+import { Dict, StringDict, LLMAPICall } from './typing';
+import { env as process_env } from 'process';
+
+import { Configuration as OpenAIConfig, OpenAIApi } from "openai";
+
+function get_environ(key: string): string | undefined {
+  if (key in process_env)
+    return process_env[key];
+  return undefined;
+}
+
+let DALAI_MODEL: string | undefined;
+let DALAI_RESPONSE: Dict | undefined;
+
+let OPENAI_API_KEY = get_environ("OPENAI_API_KEY");
+let ANTHROPIC_API_KEY = get_environ("ANTHROPIC_API_KEY");
+let GOOGLE_PALM_API_KEY = get_environ("PALM_API_KEY");
+let AZURE_OPENAI_KEY = get_environ("AZURE_OPENAI_KEY");
+let AZURE_OPENAI_ENDPOINT = get_environ("AZURE_OPENAI_ENDPOINT");
+
+/**
+ * Sets the local API keys for the relevant LLM API(s).
+ */
+export function set_api_keys(api_keys: StringDict): void {
+  function key_is_present(name: string): boolean {
+    return name in api_keys && api_keys[name].trim().length > 0;
+  }
+  if (key_is_present('OpenAI'))
+    OPENAI_API_KEY = api_keys['OpenAI'];
+  if (key_is_present('Anthropic'))
+    ANTHROPIC_API_KEY = api_keys['Anthropic'];
+  if (key_is_present('Google'))
+    GOOGLE_PALM_API_KEY = api_keys['Google'];
+  if (key_is_present('Azure_OpenAI'))
+    AZURE_OPENAI_KEY = api_keys['Azure_OpenAI'];
+  if (key_is_present('Azure_OpenAI_Endpoint'))
+    AZURE_OPENAI_ENDPOINT = api_keys['Azure_OpenAI_Endpoint'];
+  // Soft fail for non-present keys
+}
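
A brief usage sketch for set_api_keys, assuming it is called from the front-end settings handler (all key values below are placeholders):

import { set_api_keys } from "./utils";

// Values below are placeholders. Blank or missing entries are skipped ("soft fail"),
// so providers that are not listed keep the keys read from the environment.
set_api_keys({
  OpenAI: "<your-openai-key>",
  Anthropic: "<your-anthropic-key>",
  Azure_OpenAI: "<your-azure-key>",
  Azure_OpenAI_Endpoint: "https://example.openai.azure.com/",
});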
+
+// async def make_sync_call_async(sync_method, *args, **params):
+//     """
+//         Makes a blocking synchronous call asynchronous, so that it can be awaited.
+//         NOTE: This is necessary for LLM APIs that do not yet support async (e.g. Google PaLM).
+//     """
+//     loop = asyncio.get_running_loop()
+//     method = sync_method
+//     if len(params) > 0:
+//         def partial_sync_meth(*a):
+//             return sync_method(*a, **params)
+//         method = partial_sync_meth
+//     return await loop.run_in_executor(None, method, *args)
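
The commented-out Python above defers a blocking call to an executor so it can be awaited; a rough sketch of how a TypeScript equivalent might look once ported (an assumption, not code in this commit):

// Hypothetical port sketch: defer a synchronous call onto a later tick of the
// event loop so that callers can await it, mirroring the Python helper above.
function make_sync_call_async<T>(sync_method: (...args: any[]) => T, ...args: any[]): Promise<T> {
  return new Promise<T>((resolve, reject) => {
    setTimeout(() => {
      try {
        resolve(sync_method(...args));
      } catch (err) {
        reject(err);
      }
    }, 0);
  });
}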
+
+/**
+ * Calls OpenAI models via OpenAI's API.
+ * @returns raw query and response JSON dicts.
+ */
+export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
+  const configuration = new OpenAIConfig({
+    apiKey: OPENAI_API_KEY,
+  });
+  const openai = new OpenAIApi(configuration);
+
+  let modelname: string = model.toString();
+  if (params?.stop && (!Array.isArray(params.stop) || params.stop.length === 0))
+    delete params.stop;
+  if (params?.functions && (!Array.isArray(params.functions) || params.functions.length === 0))
+    delete params?.functions;
+  if (params?.function_call && (!(typeof params.function_call === 'string') || params.function_call.trim().length === 0))
+    delete params.function_call;
+
+  console.log(`Querying OpenAI model '${model}' with prompt '${prompt}'...`);
+  const system_msg: string = params?.system_msg || "You are a helpful assistant.";
+
+  let query: Dict = {
+    model: modelname,
+    n: n,
+    temperature: temperature,
+    ...params,  // 'the rest' of the settings, passed from the front-end settings
+  };
+
+  // Get the correct function to call
+  let openai_call: any;
+  if (modelname.includes('davinci')) {  // text completions model
+    openai_call = openai.createCompletion.bind(openai);
+    query['prompt'] = prompt;
+  } else {  // chat model
+    openai_call = openai.createChatCompletion.bind(openai);
+    query['messages'] = [
+      {"role": "system", "content": system_msg},
+      {"role": "user", "content": prompt},
+    ]
+  }
+
+  // Try to call OpenAI
+  let response: Dict = {};
+  try {
+    const completion = await openai_call(query);
+    response = completion.data;
+  } catch (error: any) {
+    if (error?.response) {
+      throw new Error("Could not authenticate to OpenAI. Double-check that your API key is set in Settings or in your local Python environment.");
+      // throw new Error(error.response.status);
+    } else {
+      console.log(error?.message || error);
+      throw new Error(error?.message || error);
+    }
+  }
+
+  return [query, response];
+}
+
+function _test() {
+  // Call ChatGPT
+  call_chatgpt("Who invented modern playing cards?",
+               LLM.OpenAI_ChatGPT,
+               1, 1.0).then(([query, response]) => {
+    console.log(response.choices[0].message);
+  });
+}
+
+_test();
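
Beyond the _test() smoke test above, a hedged sketch of how call_chatgpt might be awaited from other backend code (the prompt, sample count, and max_tokens below are illustrative only):

async function example_usage(): Promise<void> {
  // Ask for two samples at a lower temperature; extra params are spread into the query.
  const [query, response] = await call_chatgpt("What is the capital of France?", LLM.OpenAI_ChatGPT, 2, 0.7, { max_tokens: 64 });
  console.log("Query sent:", query);
  // Chat models return one entry per requested sample under response.choices.
  for (const choice of response.choices)
    console.log(choice.message?.content);
}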
+
+// async def call_azure_openai(prompt: str, model: LLM, n: int = 1, temperature: float= 1.0,
+//                             deployment_name: str = 'gpt-35-turbo',
+//                             model_type: str = "chat-completion",
+//                             api_version: str = "2023-05-15",
+//                             system_msg: Optional[str]=None,
+//                             **params) -> Tuple[Dict, Dict]:
+//     """
+//         Calls an OpenAI chat model GPT3.5 or GPT4 via Microsoft Azure services.
+//         Returns raw query and response JSON dicts.
+
+//         NOTE: It is recommended to set an environment variables AZURE_OPENAI_KEY and AZURE_OPENAI_ENDPOINT
+//     """
+//     global AZURE_OPENAI_KEY, AZURE_OPENAI_ENDPOINT
+//     if AZURE_OPENAI_KEY is None:
+//         raise Exception("Could not find an Azure OpenAPI Key to use. Double-check that your key is set in Settings or in your local Python environment.")
+//     if AZURE_OPENAI_ENDPOINT is None:
+//         raise Exception("Could not find an Azure OpenAI Endpoint to use. Double-check that your endpoint is set in Settings or in your local Python environment.")
+
+//     import openai
+//     openai.api_type = "azure"
+//     openai.api_version = api_version
+//     openai.api_key = AZURE_OPENAI_KEY
+//     openai.api_base = AZURE_OPENAI_ENDPOINT
+
+//     if 'stop' in params and not isinstance(params['stop'], list) or len(params['stop']) == 0:
+//         del params['stop']
+//     if 'functions' in params and not isinstance(params['functions'], list) or len(params['functions']) == 0:
+//         del params['functions']
+//     if 'function_call' in params and not isinstance(params['function_call'], str) or len(params['function_call'].strip()) == 0:
+//         del params['function_call']
+
+//     print(f"Querying Azure OpenAI deployed model '{deployment_name}' at endpoint '{AZURE_OPENAI_ENDPOINT}' with prompt '{prompt}'...")
+//     system_msg = "You are a helpful assistant." if system_msg is None else system_msg
+
+//     query = {
+//         "engine": deployment_name,  # this differs from a basic OpenAI call
+//         "n": n,
+//         "temperature": temperature,
+//         **params,  # 'the rest' of the settings, passed from a front-end app
+//     }
+
+//     if model_type == 'text-completion':
+//         openai_call = openai.Completion.acreate
+//         query['prompt'] = prompt
+//     else:
+//         openai_call = openai.ChatCompletion.acreate
+//         query['messages'] = [
+//             {"role": "system", "content": system_msg},
+//             {"role": "user", "content": prompt},
+//         ]
+
+//     try:
+//         response = await openai_call(**query)
+//     except Exception as e:
+//         if (isinstance(e, openai.error.AuthenticationError)):
+//             raise Exception("Could not authenticate to OpenAI. Double-check that your API key is set in Settings or in your local Python environment.")
+//         raise e
+
+//     return query, response
+
+// async def call_anthropic(prompt: str, model: LLM, n: int = 1, temperature: float= 1.0,
+//                          max_tokens_to_sample=1024,
+//                          async_mode=False,
+//                          custom_prompt_wrapper: Optional[str]=None,
+//                          stop_sequences: Optional[List[str]]=["\n\nHuman:"],
+//                          **params) -> Tuple[Dict, Dict]:
+//     """
+//         Calls Anthropic API with the given model, passing in params.
+//         Returns raw query and response JSON dicts.
+
+//         Unique parameters:
+//             - custom_prompt_wrapper: Anthropic models expect prompts in form "\n\nHuman: ${prompt}\n\nAssistant". If you wish to
+//                 explore custom prompt wrappers that deviate, write a python Template that maps from 'prompt' to custom wrapper.
+//                 If set to None, defaults to Anthropic's suggested prompt wrapper.
+//             - max_tokens_to_sample: A maximum number of tokens to generate before stopping.
+//             - stop_sequences: A list of strings upon which to stop generating. Defaults to ["\n\nHuman:"], the cue for the next turn in the dialog agent.
+//             - async_mode: Evaluation access to Claude limits calls to 1 at a time, meaning we can't take advantage of async.
+//                 If you want to send all 'n' requests at once, you can set async_mode to True.
+
+//         NOTE: It is recommended to set an environment variable ANTHROPIC_API_KEY with your Anthropic API key
+//     """
+//     if ANTHROPIC_API_KEY is None:
+//         raise Exception("Could not find an API key for Anthropic models. Double-check that your API key is set in Settings or in your local Python environment.")
+
+//     import anthropic
+//     client = anthropic.Client(ANTHROPIC_API_KEY)
+
+//     # Wrap the prompt in the provided template, or use the default Anthropic one
+//     if custom_prompt_wrapper is None or '${prompt}' not in custom_prompt_wrapper:
+//         custom_prompt_wrapper = anthropic.HUMAN_PROMPT + " ${prompt}" + anthropic.AI_PROMPT
+//     prompt_wrapper_template = Template(custom_prompt_wrapper)
+//     wrapped_prompt = prompt_wrapper_template.substitute(prompt=prompt)
+
+//     # Format query
+//     query = {
+//         'model': model.value,
+//         'prompt': wrapped_prompt,
+//         'max_tokens_to_sample': max_tokens_to_sample,
+//         'stop_sequences': stop_sequences,
+//         'temperature': temperature,
+//         **params
+//     }
+
+//     print(f"Calling Anthropic model '{model.value}' with prompt '{prompt}' (n={n}). Please be patient...")
+
+//     # Request responses using the passed async_mode
+//     responses = []
+//     if async_mode:
+//         # Gather n responses by firing off all API requests at once
+//         tasks = [client.acompletion(**query) for _ in range(n)]
+//         responses = await asyncio.gather(*tasks)
+//     else:
+//         # Repeat call n times, waiting for each response to come in:
+//         while len(responses) < n:
+//             resp = await client.acompletion(**query)
+//             responses.append(resp)
+//             print(f'{model.value} response {len(responses)} of {n}:\n', resp)
+
+//     return query, responses
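
One part of the Anthropic port above that translates almost mechanically is the prompt wrapping; a speculative TypeScript sketch (the constants mirror Anthropic's documented prompt format, and the function name is hypothetical):

const ANTHROPIC_HUMAN_PROMPT = "\n\nHuman:";
const ANTHROPIC_AI_PROMPT = "\n\nAssistant:";

// Wrap a prompt in the provided template, or fall back to Anthropic's suggested wrapper.
// Note: "${prompt}" here is plain text (not a template literal), matching the Python Template syntax.
function wrap_anthropic_prompt(prompt: string, custom_prompt_wrapper?: string): string {
  if (!custom_prompt_wrapper || !custom_prompt_wrapper.includes("${prompt}"))
    custom_prompt_wrapper = ANTHROPIC_HUMAN_PROMPT + " ${prompt}" + ANTHROPIC_AI_PROMPT;
  return custom_prompt_wrapper.replace("${prompt}", prompt);
}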
+
+// async def call_google_palm(prompt: str, model: LLM, n: int = 1, temperature: float= 0.7,
+//                            max_output_tokens=800,
+//                            async_mode=False,
+//                            **params) -> Tuple[Dict, Dict]:
+//     """
+//         Calls a Google PaLM model.
+//         Returns raw query and response JSON dicts.
+//     """
+//     if GOOGLE_PALM_API_KEY is None:
+//         raise Exception("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local Python environment.")
+
+//     import google.generativeai as palm
+//     palm.configure(api_key=GOOGLE_PALM_API_KEY)
+
+//     is_chat_model = 'chat' in model.value
+
+//     query = {
+//         'model': f"models/{model.value}",
+//         'prompt': prompt,
+//         'candidate_count': n,
+//         'temperature': temperature,
+//         'max_output_tokens': max_output_tokens,
+//         **params,
+//     }
+
+//     # Remove erroneous parameters for text and chat models
+//     if 'top_k' in query and query['top_k'] <= 0:
+//         del query['top_k']
+//     if 'top_p' in query and query['top_p'] <= 0:
+//         del query['top_p']
+//     if is_chat_model and 'max_output_tokens' in query:
+//         del query['max_output_tokens']
+//     if is_chat_model and 'stop_sequences' in query:
+//         del query['stop_sequences']
+
+//     # Get the correct model's completions call
+//     palm_call = palm.chat if is_chat_model else palm.generate_text
+
+//     # Google PaLM's python API does not currently support async calls.
+//     # To make one, we need to wrap it in an asynchronous executor:
+//     completion = await make_sync_call_async(palm_call, **query)
+//     completion_dict = completion.to_dict()
+
+//     # Google PaLM, unlike other chat models, will output empty
+//     # responses for any response it deems unsafe (blocks). Although the text completions
+//     # API has a (relatively undocumented) 'safety_settings' parameter,
+//     # the current chat completions API provides users no control over the blocking.
+//     # We need to detect this and fill the response with the safety reasoning:
+//     if len(completion.filters) > 0:
+//         # Request was blocked. Output why in the response text,
+//         # repairing the candidate dict to mock up 'n' responses
+//         block_error_msg = f'[[BLOCKED_REQUEST]] Request was blocked because it triggered safety filters: {str(completion.filters)}'
+//         completion_dict['candidates'] = [{'author': 1, 'content':block_error_msg}] * n
+
+//     # Weirdly, google ignores candidate_count if temperature is 0.
+//     # We have to check for this and manually append the n-1 responses:
+//     if n > 1 and temperature == 0 and len(completion_dict['candidates']) == 1:
+//         copied_candidates = [completion_dict['candidates'][0]] * n
+//         completion_dict['candidates'] = copied_candidates
+
+//     return query, completion_dict
+
+// async def call_dalai(prompt: str, model: LLM, server: str="http://localhost:4000", n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
+//     """
+//         Calls a Dalai server running LLMs Alpaca, Llama, etc locally.
+//         Returns the raw query and response JSON dicts.
+
+//         Parameters:
+//             - model: The LLM model, whose value is the name known by Dalai; e.g. 'alpaca.7b'
+//             - port: The port of the local server where Dalai is running. By default 4000.
+//             - prompt: The prompt to pass to the LLM.
+//             - n: How many times to query. If n > 1, this will continue to query the LLM 'n' times and collect all responses.
+//             - temperature: The temperature to query at
+//             - params: Any other Dalai-specific params to pass. For more info, see below or https://cocktailpeanut.github.io/dalai/#/?id=syntax-1
+
+//         TODO: Currently, this uses a modified dalaipy library for simplicity; however, in the future we might remove this dependency.
+//     """
+//     # Import and load upon first run
+//     global DALAI_MODEL, DALAI_RESPONSE
+//     if not server or len(server.strip()) == 0:  # In case user passed a blank server name, revert to default on port 4000
+//         server = "http://localhost:4000"
+//     if DALAI_MODEL is None:
+//         from chainforge.promptengine.dalaipy import Dalai
+//         DALAI_MODEL = Dalai(server)
+//     elif DALAI_MODEL.server != server:  # if the port has changed, we need to create a new model
+//         DALAI_MODEL = Dalai(server)
+
+//     # Make sure server is connected
+//     DALAI_MODEL.connect()
+
+//     # Create settings dict to pass to Dalai as args
+//     def_params = {'n_predict':128, 'repeat_last_n':64, 'repeat_penalty':1.3, 'seed':-1, 'threads':4, 'top_k':40, 'top_p':0.9}
+//     for key in params:
+//         if key in def_params:
+//             def_params[key] = params[key]
+//         else:
+//             print(f"Attempted to pass unsupported param '{key}' to Dalai. Ignoring.")
+
+//     # Create full query to Dalai
+//     query = {
+//         'prompt': prompt,
+//         'model': model.value,
+//         'id': str(round(time.time()*1000)),
+//         'temp': temperature,
+//         **def_params
+//     }
+
+//     # Create spot to put response and a callback that sets it
+//     DALAI_RESPONSE = None
+//     def on_finish(r):
+//         global DALAI_RESPONSE
+//         DALAI_RESPONSE = r
+
+//     print(f"Calling Dalai model '{query['model']}' with prompt '{query['prompt']}' (n={n}). Please be patient...")
+
+//     # Repeat call n times
+//     responses = []
+//     while len(responses) < n:
+
+//         # Call the Dalai model
+//         req = DALAI_MODEL.generate_request(**query)
+//         sent_req_success = DALAI_MODEL.generate(req, on_finish=on_finish)
+
+//         if not sent_req_success:
+//             print("Something went wrong pinging the Dalai server. Returning None.")
+//             return None, None
+
+//         # Blocking --wait for request to complete:
+//         while DALAI_RESPONSE is None:
+//             await asyncio.sleep(0.01)
+
+//         response = DALAI_RESPONSE['response']
+//         if response[-5:] == '<end>':  # strip ending <end> tag, if present
+//             response = response[:-5]
+//         if response.index('\r\n') > -1:  # strip off the prompt, which is included in the result up to \r\n:
+//             response = response[(response.index('\r\n')+2):]
+//         DALAI_RESPONSE = None
+
+//         responses.append(response)
+//         print(f'Response {len(responses)} of {n}:\n', response)
+
+//     # Disconnect from the server
+//     DALAI_MODEL.disconnect()
+
+//     return query, responses
+
+// def _extract_openai_chat_choice_content(choice: dict) -> str:
+//     """
+//         Extracts the relevant portion of an OpenAI chat response.
+
+//         Note that chat choice objects can now include 'function_call' and a blank 'content' response.
+//         This method detects a 'function_call's presence, prepends [[FUNCTION]] and converts the function call into Python format.
+//     """
+//     if choice['finish_reason'] == 'function_call' or choice["message"]["content"] is None or \
+//        ('function_call' in choice['message'] and len(choice['message']['function_call']) > 0):
+//         func = choice['message']['function_call']
+//         return '[[FUNCTION]] ' + func['name'] + str(func['arguments'])
+//     else:
+//         return choice["message"]["content"]
+
+// def _extract_chatgpt_responses(response: dict) -> List[str]:
+//     """
+//         Extracts the text part of a response JSON from ChatGPT. If there is more
+//         than 1 response (e.g., asking the LLM to generate multiple responses),
+//         this produces a list of all returned responses.
+//     """
+//     choices = response["choices"]
+//     return [
+//         _extract_openai_chat_choice_content(c)
+//         for c in choices
+//     ]
+
+// def _extract_openai_completion_responses(response: dict) -> List[str]:
+//     """
+//         Extracts the text part of a response JSON from OpenAI completions models like Davinci. If there are more
+//         than 1 response (e.g., asking the LLM to generate multiple responses),
+//         this produces a list of all returned responses.
+//     """
+//     choices = response["choices"]
+//     return [
+//         c["text"].strip()
+//         for c in choices
+//     ]
+
+// def _extract_openai_responses(response: dict) -> List[str]:
+//     """
+//         Deduces the format of an OpenAI model response (completion or chat)
+//         and extracts the response text using the appropriate method.
+//     """
+//     if len(response["choices"]) == 0: return []
+//     first_choice = response["choices"][0]
+//     if "message" in first_choice:
+//         return _extract_chatgpt_responses(response)
+//     else:
+//         return _extract_openai_completion_responses(response)
+
+// def _extract_palm_responses(completion) -> List[str]:
+//     """
+//         Extracts the text part of a 'Completion' object from Google PaLM2 `generate_text` or `chat`
+
+//         NOTE: The candidate object for `generate_text` has a key 'output' which contains the response,
+//         while the `chat` API uses a key 'content'. This checks for either.
+//     """
+//     return [
+//         c['output'] if 'output' in c else c['content']
+//         for c in completion['candidates']
+//     ]
+
+// def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List[str]:
+//     """
+//         Given a LLM and a response object from its API, extract the
+//         text response(s) part of the response object.
+//     """
+//     llm_str = llm.name if isinstance(llm, LLM) else llm
+//     if llm_str[:6] == 'OpenAI':
+//         if 'davinci' in llm_str.lower():
+//             return _extract_openai_completion_responses(response)
+//         else:
+//             return _extract_chatgpt_responses(response)
+//     elif llm_str[:5] == 'Azure':
+//         return _extract_openai_responses(response)
+//     elif llm_str[:5] == 'PaLM2':
+//         return _extract_palm_responses(response)
+//     elif llm_str[:5] == 'Dalai':
+//         return response
+//     elif llm_str[:6] == 'Claude':
+//         return [r["completion"] for r in response]
+//     else:
+//         raise ValueError(f"LLM {llm_str} is unsupported.")
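
Since extract_responses is already referenced from the PromptPipeline diff earlier, here is a speculative sketch of how the OpenAI branch of this dispatch could look in TypeScript (a partial port under stated assumptions, not code in this commit):

import { Dict } from "./typing";

// Partial, hypothetical port: handle only OpenAI chat/completion response shapes for now.
export function extract_responses_sketch(response: Dict, llm_name: string): string[] {
  if (llm_name.startsWith('OpenAI')) {
    const choices: Dict[] = response["choices"] ?? [];
    return choices.map(c => 'message' in c ? c.message.content : c.text.trim());
  }
  throw new Error(`LLM ${llm_name} is unsupported (port in progress).`);
}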
+
+// def merge_response_objs(resp_obj_A: Union[dict, None], resp_obj_B: Union[dict, None]) -> dict:
+//     if resp_obj_B is None:
+//         return resp_obj_A
+//     elif resp_obj_A is None:
+//         return resp_obj_B
+//     raw_resp_A = resp_obj_A["raw_response"]
+//     raw_resp_B = resp_obj_B["raw_response"]
+//     if not isinstance(raw_resp_A, list):
+//         raw_resp_A = [ raw_resp_A ]
+//     if not isinstance(raw_resp_B, list):
+//         raw_resp_B = [ raw_resp_B ]
+//     C = {
+//         "responses": resp_obj_A["responses"] + resp_obj_B["responses"],
+//         "raw_response": raw_resp_A + raw_resp_B,
+//     }
+//     return {
+//         **C,
+//         "prompt": resp_obj_B['prompt'],
+//         "query": resp_obj_B['query'],
+//         "llm": resp_obj_B['llm'],
+//         "info": resp_obj_B['info'],
+//         "metavars": resp_obj_B['metavars'],
+//     }
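
A tentative TypeScript counterpart using the LLMResponseObject interface from typing.ts (the raw_response handling and return shape are assumptions about how the port could go):

import { Dict, LLMResponseObject } from "./typing";

// Hypothetical port: concatenate responses/raw_response from A and B, keep B's metadata.
function merge_response_objs(A?: LLMResponseObject, B?: LLMResponseObject): LLMResponseObject | undefined {
  if (!B) return A;
  if (!A) return B;
  const raw_A = Array.isArray(A.raw_response) ? A.raw_response : [A.raw_response];
  const raw_B = Array.isArray(B.raw_response) ? B.raw_response : [B.raw_response];
  return {
    ...B,
    responses: A.responses.concat(B.responses),
    raw_response: raw_A.concat(raw_B) as Dict,
  };
}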
+
+// def create_dir_if_not_exists(path: str) -> None:
+//     if not os.path.exists(path):
+//         os.makedirs(path)
+
+// def is_valid_filepath(filepath: str) -> bool:
+//     try:
+//         with open(filepath, 'r', encoding='utf-8'):
+//             pass
+//     except IOError:
+//         try:
+//             # Create the file if it doesn't exist, and write an empty json string to it
+//             with open(filepath, 'w+', encoding='utf-8') as f:
+//                 f.write("{}")
+//                 pass
+//         except IOError:
+//             return False
+//     return True
+
+// def is_valid_json(json_dict: dict) -> bool:
+//     if isinstance(json_dict, dict):
+//         try:
+//             json.dumps(json_dict)
+//             return True
+//         except Exception:
+//             pass
+//     return False
+
+// def get_files_at_dir(path: str) -> list:
+//     f = []
+//     for (dirpath, dirnames, filenames) in os.walk(path):
+//         f = filenames
+//         break
+//     return f
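
The remaining filesystem helpers map naturally onto Node's fs module; a hedged sketch of a possible port (an assumption about the eventual conversion, and only applicable when the backend runs under Node rather than in the browser):

import * as fs from 'fs';

// Create the directory (and any missing parents) if it does not already exist.
function create_dir_if_not_exists(path: string): void {
  if (!fs.existsSync(path))
    fs.mkdirSync(path, { recursive: true });
}

// List only the files (not subdirectories) at the top level of a directory.
function get_files_at_dir(path: string): string[] {
  return fs.readdirSync(path, { withFileTypes: true })
           .filter(entry => entry.isFile())
           .map(entry => entry.name);
}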