diff --git a/chainforge/react-server/src/App.js b/chainforge/react-server/src/App.js
index 2aaa9a0..255255c 100644
--- a/chainforge/react-server/src/App.js
+++ b/chainforge/react-server/src/App.js
@@ -279,13 +279,13 @@ const App = () => {
   const importCache = useCallback((cache_data) => {
     return fetch_from_backend('importCache', {
       'files': cache_data,
-    }, handleError).then(function(json) {
+    }).then(function(json) {
         if (!json || json.result === undefined)
           throw new Error('Request to import cache data was sent and received by backend server, but there was no response.');
         else if (json.error || json.result === false)
           throw new Error('Error importing cache data:' + json.error);
       // Done!
-    }, handleError).catch(handleError);
+    }).catch(handleError);
   }, [handleError]);
 
   const importFlowFromJSON = useCallback((flowJSON) => {
diff --git a/chainforge/react-server/src/GlobalSettingsModal.js b/chainforge/react-server/src/GlobalSettingsModal.js
index ebb5879..d498554 100644
--- a/chainforge/react-server/src/GlobalSettingsModal.js
+++ b/chainforge/react-server/src/GlobalSettingsModal.js
@@ -4,6 +4,8 @@ import { useDisclosure } from '@mantine/hooks';
 import { useForm } from '@mantine/form';
 import useStore from './store';
 
+const _LINK_STYLE = {color: '#1E90FF', textDecoration: 'none'};
+
 const GlobalSettingsModal = forwardRef((props, ref) => {
   const [opened, { open, close }] = useDisclosure(false);
   const setAPIKeys = useStore((state) => state.setAPIKeys);
@@ -15,6 +17,7 @@ const GlobalSettingsModal = forwardRef((props, ref) => {
       Google: '',
       Azure_OpenAI: '',
       Azure_OpenAI_Endpoint: '',
+      HuggingFace: '',
     },
 
     validate: {
@@ -38,16 +41,26 @@ const GlobalSettingsModal = forwardRef((props, ref) => {
 
   return (
+
+          Note: We do not store your API keys —not in a cookie, localStorage, or server.
+          Because of this, you must set your API keys every time you load ChainForge. If you prefer not to worry about it,
+          we recommend installing ChainForge locally and
+          setting your API keys as environment variables.
+
-            Note: We do not store your API keys in a cookie or file.
-            Because of this, you must set your API keys every time you load ChainForge.
-            If you don't want to worry about it, we recommend setting the API key as an environment variable.}
             placeholder="Paste your OpenAI API key here" {...form.getInputProps('OpenAI')} />
+
+
+
{ @@ -629,6 +630,140 @@ const AzureOpenAISettings = { postprocessors: ChatGPTSettings.postprocessors, }; +const HuggingFaceTextInferenceSettings = { + fullName: "HuggingFace-hosted text generation models", + schema: { + "type": "object", + "required": [ + "shortname", + ], + "properties": { + "shortname": { + "type": "string", + "title": "Nickname", + "description": "Unique identifier to appear in ChainForge. Keep it short.", + "default": "Falcon.7B", + }, + "model": { + "type": "string", + "title": "Model", + "description": "Select a suggested HuggingFace-hosted model to query using the Inference API. For more details, check out https://huggingface.co/inference-api", + "enum": ["gpt2", "bigscience/bloom-560m", "tiiuae/falcon-7b-instruct", "bigcode/santacoder", "bigcode/starcoder", "Other (HuggingFace)"], + "default": "tiiuae/falcon-7b-instruct", + }, + "custom_model": { + "type": "string", + "title": "Custom HF model endpoint", + "description": "(Only used if you select 'Other' above.) Enter the HuggingFace ID of the text generation model you wish to query via the inference API.", + "default": "", + }, + "temperature": { + "type": "number", + "title": "temperature", + "description": "Controls the 'creativity' or randomness of the response.", + "default": 1.0, + "minimum": 0, + "maximum": 5.0, + "multipleOf": 0.01, + }, + "top_k": { + "type": "integer", + "title": "top_k", + "description": "Sets the maximum number of tokens to sample from on each step. Set to -1 to remain unspecified.", + "minimum": -1, + "default": -1, + }, + "top_p": { + "type": "number", + "title": "top_p", + "description": "Sets the maximum cumulative probability of tokens to sample from (from 0 to 1.0). Set to -1 to remain unspecified.", + "default": -1, + "minimum": -1, + "maximum": 1, + "multipleOf": 0.001, + }, + "repetition_penalty": { + "type": "number", + "title": "repetition_penalty", + "description": "The more a token is used within generation the more it is penalized to not be picked in successive generation passes. Set to -1 to remain unspecified.", + "minimum": -1, + "default": -1, + "maximum": 100, + "multipleOf": 0.01, + }, + "max_new_tokens": { + "type": "integer", + "title": "max_new_tokens", + "description": "The amount of new tokens to be generated, from 0 to 250 tokens. Set to -1 to remain unspecified.", + "default": -1, + "minimum": -1, + "maximum": 250, + }, + "do_sample": { + "type": "boolean", + "title": "do_sample", + "description": "Whether or not to use sampling. Default is True; uses greedy decoding otherwise.", + "enum": [true, false], + "default": true, + }, + "use_cache": { + "type": "boolean", + "title": "use_cache", + "description": "Whether or not to fetch from HF's cache. There is a cache layer on the inference API to speedup requests HF has already seen. Most models can use those results as is as models are deterministic (meaning the results will be the same anyway). However if you use a non-deterministic model, you can set this parameter to prevent the caching mechanism from being used resulting in a real new query.", + "enum": [true, false], + "default": false, + }, + } + }, + + uiSchema: { + 'ui:submitButtonOptions': { + props: { + disabled: false, + className: 'mantine-UnstyledButton-root mantine-Button-root', + }, + norender: false, + submitText: 'Submit', + }, + "shortname": { + "ui:autofocus": true + }, + "model": { + "ui:help": "Defaults to Falcon.7B." 
+    },
+    "temperature": {
+      "ui:help": "Defaults to 1.0.",
+      "ui:widget": "range"
+    },
+    "max_new_tokens": {
+      "ui:help": "Defaults to unspecified (-1)"
+    },
+    "top_k": {
+      "ui:help": "Defaults to unspecified (-1)"
+    },
+    "top_p": {
+      "ui:help": "Defaults to unspecified (-1)",
+      "ui:widget": "range"
+    },
+    "repetition_penalty": {
+      "ui:help": "Defaults to unspecified (-1)",
+      "ui:widget": "range"
+    },
+    "do_sample": {
+      "ui:widget": "radio"
+    },
+    "use_cache": {
+      "ui:widget": "radio",
+      "ui:help": "Defaults to false in ChainForge. This differs from the HuggingFace docs, as CF's intended use case is evaluation, and for evaluation we want different responses each query."
+    }
+  },
+
+  postprocessors: {}
+};
+
 // A lookup table indexed by base_model.
 export const ModelSettings = {
   'gpt-3.5-turbo': ChatGPTSettings,
@@ -637,6 +772,7 @@ export const ModelSettings = {
   'palm2-bison': PaLM2Settings,
   'dalai': DalaiModelSettings,
   'azure-openai': AzureOpenAISettings,
+  'hf': HuggingFaceTextInferenceSettings,
 };
 
 export const getTemperatureSpecForModel = (modelName) => {
diff --git a/chainforge/react-server/src/backend/backend.ts b/chainforge/react-server/src/backend/backend.ts
index 8ac7669..8bde43b 100644
--- a/chainforge/react-server/src/backend/backend.ts
+++ b/chainforge/react-server/src/backend/backend.ts
@@ -921,7 +921,9 @@ export async function fetchExampleFlow(evalname: string): Promise<Dict> {
 
   // App is not running locally, but hosted on a site.
   // If this is the case, attempt to fetch the example flow from a relative site path:
-  return fetch(`examples/${evalname}.cforge`).then(response => response.json());
+  return fetch(`examples/${evalname}.cforge`)
+          .then(response => response.json())
+          .then(res => ({data: res}));
 }
 
 
@@ -950,5 +952,7 @@ export async function fetchOpenAIEval(evalname: string): Promise<Dict> {
   // App is not running locally, but hosted on a site.
   // If this is the case, attempt to fetch the example flow from relative path on the site:
   // > ALT: `https://raw.githubusercontent.com/ianarawjo/ChainForge/main/chainforge/oaievals/${_name}.cforge`
-  return fetch(`oaievals/${evalname}.cforge`).then(response => response.json());
+  return fetch(`oaievals/${evalname}.cforge`)
+          .then(response => response.json())
+          .then(res => ({data: res}));
 }
diff --git a/chainforge/react-server/src/backend/models.ts b/chainforge/react-server/src/backend/models.ts
index 42f9e14..40034d5 100644
--- a/chainforge/react-server/src/backend/models.ts
+++ b/chainforge/react-server/src/backend/models.ts
@@ -40,6 +40,19 @@ export enum LLM {
   // Google models
   PaLM2_Text_Bison = "text-bison-001",  // it's really models/text-bison-001, but that's confusing
   PaLM2_Chat_Bison = "chat-bison-001",
+
+  // HuggingFace Inference hosted models, suggested to users
+  HF_GPT2 = "gpt2",
+  HF_BLOOM_560M = "bigscience/bloom-560m",
+  HF_FALCON_7B_INSTRUCT = "tiiuae/falcon-7b-instruct",
+  HF_SANTACODER = "bigcode/santacoder",
+  HF_STARCODER = "bigcode/starcoder",
+  // HF_GPTJ_6B = "EleutherAI/gpt-j-6b",
+  // HF_LLAMA_7B = "decapoda-research/llama-7b-hf",
+
+  // A special flag for a user-defined HuggingFace model endpoint.
+  // The actual model name will be passed as a param to the LLM call function.
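+  // (See the 'custom_model' field in HuggingFaceTextInferenceSettings and call_huggingface() in utils.ts.)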
+  HF_OTHER = "Other (HuggingFace)",
 }
 
 /**
@@ -51,6 +64,7 @@ export enum LLMProvider {
   Dalai = "dalai",
   Anthropic = "anthropic",
   Google = "google",
+  HuggingFace = "hf",
 }
 
 /**
@@ -68,6 +82,8 @@ export function getProvider(llm: LLM): LLMProvider | undefined {
     return LLMProvider.Google;
   else if (llm_name?.startsWith('Dalai'))
     return LLMProvider.Dalai;
+  else if (llm_name?.startsWith('HF_'))
+    return LLMProvider.HuggingFace;
   else if (llm.toString().startsWith('claude'))
     return LLMProvider.Anthropic;
   return undefined;
diff --git a/chainforge/react-server/src/backend/utils.ts b/chainforge/react-server/src/backend/utils.ts
index d1996e3..8739f27 100644
--- a/chainforge/react-server/src/backend/utils.ts
+++ b/chainforge/react-server/src/backend/utils.ts
@@ -94,6 +94,7 @@ let ANTHROPIC_API_KEY = get_environ("ANTHROPIC_API_KEY");
 let GOOGLE_PALM_API_KEY = get_environ("PALM_API_KEY");
 let AZURE_OPENAI_KEY = get_environ("AZURE_OPENAI_KEY");
 let AZURE_OPENAI_ENDPOINT = get_environ("AZURE_OPENAI_ENDPOINT");
+let HUGGINGFACE_API_KEY = get_environ("HUGGINGFACE_API_KEY");
 
 /**
  * Sets the local API keys for the revelant LLM API(s).
@@ -104,6 +105,8 @@ export function set_api_keys(api_keys: StringDict): void {
   }
   if (key_is_present('OpenAI'))
     OPENAI_API_KEY= api_keys['OpenAI'];
+  if (key_is_present('HuggingFace'))
+    HUGGINGFACE_API_KEY = api_keys['HuggingFace'];
   if (key_is_present('Anthropic'))
     ANTHROPIC_API_KEY = api_keys['Anthropic'];
   if (key_is_present('Google'))
@@ -242,8 +245,7 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
     response = await openai_call(deployment_name, arg2, query);
   } catch (error) {
     if (error?.response) {
-      throw new Error("Could not authenticate to Azure OpenAI. Double-check that your API key is set in Settings or in your local environment.");
-      // throw new Error(error.response.status);
+      throw new Error(error.response.data?.error?.message);
     } else {
       console.log(error?.message || error);
       throw new Error(error?.message || error);
@@ -498,6 +500,68 @@ export async function call_dalai(prompt: string, model: LLM, n: number = 1, temp
   // return query, responses
 
+export async function call_huggingface(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
+  // Whether a given param in 'params' exists and should be passed on to the API call
+  const param_exists = (p: any) => (p !== undefined && !((typeof p === 'number' && p < 0) || (typeof p === 'string' && p.trim().length === 0)));
+  const set_param_if_exists = (name: string, query: Dict) => {
+    if (!params || Object.keys(params).length === 0) return;
+    const p = params[name];
+    if (param_exists(p)) {
+      // Set the param on the query dict
+      query[name] = p;
+    }
+  }
+
+  let query = {
+    temperature: temperature,
+    return_full_text: false,  // only return the newly generated tokens, not the prompt echoed back
+  };
+  set_param_if_exists('top_k', query);
+  set_param_if_exists('top_p', query);
+  set_param_if_exists('repetition_penalty', query);
+  set_param_if_exists('max_new_tokens', query);
+
+  let options = {
+    use_cache: false, // we want it generating fresh each time
+  };
+  set_param_if_exists('use_cache', options);
+  set_param_if_exists('do_sample', options);
+
+  const using_custom_model_endpoint: boolean = param_exists(params?.custom_model);
+
+  let headers: StringDict = {'Content-Type': 'application/json'};
+  // For HuggingFace, technically, the API key is optional.
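+  // Only attach an Authorization header when a key has been set (via Settings or the HUGGINGFACE_API_KEY env var).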
+  if (HUGGINGFACE_API_KEY !== undefined)
+    headers.Authorization = `Bearer ${HUGGINGFACE_API_KEY}`;
+
+  let responses: Array<Dict> = [];
+  while (responses.length < n) {
+    const response = await fetch(
+      `https://api-inference.huggingface.co/models/${using_custom_model_endpoint ? params.custom_model.trim() : model}`,
+      {
+        headers: headers,
+        method: "POST",
+        body: JSON.stringify({inputs: prompt, parameters: query, options: options}),
+      }
+    );
+    const result = await response.json();
+
+    // HuggingFace sometimes gives us an error, for instance if a model is loading.
+    // It returns this as an 'error' key in the response:
+    if (result?.error !== undefined)
+      throw new Error(result.error);
+    else if (!Array.isArray(result) || result.length !== 1)
+      throw new Error("Result of HuggingFace API call is in unexpected format: " + JSON.stringify(result));
+
+    // Store this response and continue querying until we have n responses
+    responses.push(result[0]);
+  }
+
+  return [query, responses];
+}
+
 /**
  * Switcher that routes the request to the appropriate API call function. If call doesn't exist, throws error.
  */
@@ -519,6 +583,8 @@ export async function call_llm(llm: LLM, prompt: string, n: number, temperature
     call_api = call_dalai;
   else if (llm_provider === LLMProvider.Anthropic)
     call_api = call_anthropic;
+  else if (llm_provider === LLMProvider.HuggingFace)
+    call_api = call_huggingface;
 
   return call_api(prompt, llm, n, temperature, params);
 }
@@ -588,6 +654,13 @@ function _extract_anthropic_responses(response: Array<Dict>): Array<string> {
   return response.map((r: Dict) => r.completion.trim());
 }
 
+/**
+ * Extracts the text part of a HuggingFace text completion.
+ */
+function _extract_huggingface_responses(response: Array<Dict>): Array<string> {
+  return response.map((r: Dict) => r.generated_text.trim());
+}
+
 /**
  * Given a LLM and a response object from its API, extract the
  * text response(s) part of the response object.
@@ -609,6 +682,8 @@ export function extract_responses(response: Array | Dict, llm: LL
       return [response.toString()];
     case LLMProvider.Anthropic:
       return _extract_anthropic_responses(response as Dict[]);
+    case LLMProvider.HuggingFace:
+      return _extract_huggingface_responses(response as Dict[]);
     default:
       throw new Error(`No method defined to extract responses for LLM ${llm}.`)
   }