Replace Dalai with Ollama (#209)

* Add basic Ollama support (#208)

* Remove trapFocus warning when no OpenAI key is set

* Ensure Ollama is only visible in the providers list when running locally.

* Remove Dalai.

* Fix Ollama support to include chat models and pass chat history correctly

* Fix bug with debounce on progress bar updates in Prompt/Chat nodes

* Rebuild app and update package version

---------

Co-authored-by: Laurent Huberdeau <16990250+laurenthuberdeau@users.noreply.github.com>
ianarawjo 2024-01-08 18:33:13 -05:00 committed by GitHub
parent 5acdfc0677
commit b92c03afb2
12 changed files with 309 additions and 130 deletions
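
For orientation before the file diffs: the new provider talks to a locally running Ollama server over its REST API. Text-mode models are posted to the /api/generate endpoint with a prompt string, while chat-mode models are posted to /api/chat with an OpenAI-style messages array, which is how chat history is passed through. Below is a minimal standalone sketch of that request shape (illustrative only, not part of the diff; it assumes Ollama is running on its default port with the mistral model already pulled):

async function ollamaOnce(prompt: string, chat: boolean = false): Promise<string> {
  const base = "http://localhost:11434/api/"; // same default as the new ollama_url setting
  const url = base + (chat ? "chat" : "generate");
  const body = chat
    ? { model: "mistral", stream: false, messages: [{ role: "user", content: prompt }] }
    : { model: "mistral", stream: false, prompt: prompt };
  const resp = await fetch(url, { method: "POST", body: JSON.stringify(body) });
  const json = await resp.json();
  // /api/chat returns the text under message.content; /api/generate returns it under response.
  return chat ? json.message.content : json.response;
}

// e.g. ollamaOnce("Why is the sky blue?").then(console.log);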

View File

@ -1,15 +1,15 @@
{
"files": {
"main.css": "/static/css/main.d7b7e6a1.css",
"main.js": "/static/js/main.546c3be8.js",
"main.js": "/static/js/main.16bd62df.js",
"static/js/787.4c72bb55.chunk.js": "/static/js/787.4c72bb55.chunk.js",
"index.html": "/index.html",
"main.d7b7e6a1.css.map": "/static/css/main.d7b7e6a1.css.map",
"main.546c3be8.js.map": "/static/js/main.546c3be8.js.map",
"main.16bd62df.js.map": "/static/js/main.16bd62df.js.map",
"787.4c72bb55.chunk.js.map": "/static/js/787.4c72bb55.chunk.js.map"
},
"entrypoints": [
"static/css/main.d7b7e6a1.css",
"static/js/main.546c3be8.js"
"static/js/main.16bd62df.js"
]
}

View File

@ -1 +1 @@
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.546c3be8.js"></script><link href="/static/css/main.d7b7e6a1.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.16bd62df.js"></script><link href="/static/css/main.d7b7e6a1.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>

View File

@ -1,10 +1,9 @@
import React, { useMemo, useRef } from 'react';
import { Stack, NumberInput, Button, Text, TextInput, Switch, Tabs, Popover, Badge, Textarea, Alert } from "@mantine/core"
import { Stack, NumberInput, Button, Text, Switch, Tabs, Popover, Badge, Textarea, Alert } from "@mantine/core"
import { useState } from 'react';
import { autofill, generateAndReplace, AIError } from './backend/ai';
import { IconSparkles, IconAlertCircle } from '@tabler/icons-react';
import AlertModal from './AlertModal';
import { useStore } from './store';
const zeroGap = {gap: "0rem"};
const popoverShadow ="rgb(38, 57, 77) 0px 10px 30px -14px";
@ -142,7 +141,7 @@ function AIPopover({
), [didGenerateAndReplaceError, generateAndReplacePrompt, setGenerateAndReplacePrompt, generateAndReplaceNumber, setGenerateAndReplaceNumber, generateAndReplaceIsUnconventional, setGenerateAndReplaceIsUnconventional, handleGenerateAndReplace, areValuesLoading]);
return (
<Popover position="right-start" withArrow shadow={popoverShadow} withinPortal keepMounted trapFocus>
<Popover position="right-start" withArrow shadow={popoverShadow} withinPortal keepMounted trapFocus={noOpenAIKeyMessage === undefined}>
<Popover.Target>
<button className="ai-button nodrag"><IconSparkles size={10} stroke={3}/></button>
</Popover.Target>

View File

@ -254,7 +254,7 @@ const ClaudeSettings = {
"type": "string",
"title": "Model Version",
"description": "Select a version of Claude to query. For more details on the differences, see the Anthropic API documentation.",
"enum": ["claude-2", "claude-2.0", "claude-2.1", "claude-instant-1", "claude-instant-1.1", "claude-instant-1.2", "claude-v1", "claude-v1-100k", "claude-instant-v1", "claude-instant-v1-100k", "claude-v1.3",
"enum": ["claude-2", "claude-2.0", "claude-2.1", "claude-instant-1", "claude-instant-1.1", "claude-instant-1.2", "claude-v1", "claude-v1-100k", "claude-instant-v1", "claude-instant-v1-100k", "claude-v1.3",
"claude-v1.3-100k", "claude-v1.2", "claude-v1.0", "claude-instant-v1.1", "claude-instant-v1.1-100k", "claude-instant-v1.0"],
"default": "claude-2"
},
@ -646,9 +646,9 @@ const HuggingFaceTextInferenceSettings = {
"enum": ["mistralai/Mistral-7B-Instruct-v0.1", "HuggingFaceH4/zephyr-7b-beta", "tiiuae/falcon-7b-instruct", "microsoft/DialoGPT-large", "bigscience/bloom-560m", "gpt2", "bigcode/santacoder", "bigcode/starcoder", "Other (HuggingFace)"],
"default": "tiiuae/falcon-7b-instruct",
"shortname_map": {
"mistralai/Mistral-7B-Instruct-v0.1": "Mistral-7B",
"HuggingFaceH4/zephyr-7b-beta": "Zephyr-7B",
"tiiuae/falcon-7b-instruct": "Falcon-7B",
"mistralai/Mistral-7B-Instruct-v0.1": "Mistral-7B",
"HuggingFaceH4/zephyr-7b-beta": "Zephyr-7B",
"tiiuae/falcon-7b-instruct": "Falcon-7B",
"microsoft/DialoGPT-large": "DialoGPT",
"bigscience/bloom-560m": "Bloom560M",
"gpt2": "GPT-2",
@ -742,7 +742,7 @@ const HuggingFaceTextInferenceSettings = {
"ui:autofocus": true
},
"model": {
"ui:help": "Defaults to Falcon.7B."
"ui:help": "Defaults to Falcon.7B."
},
"temperature": {
"ui:help": "Defaults to 1.0.",
@ -797,7 +797,7 @@ const AlephAlphaLuminousSettings = {
type: "string",
title: "Model",
description:
"Select a suggested Aleph Alpha model to query using the Aleph Alpha API. For more details, check outhttps://docs.aleph-alpha.com/api/available-models/",
"Select a suggested Aleph Alpha model to query using the Aleph Alpha API. For more details, check out https://docs.aleph-alpha.com/api/available-models/",
enum: [
"luminous-extended",
"luminous-extended-control",
@ -880,14 +880,7 @@ const AlephAlphaLuminousSettings = {
},
},
uiSchema: {
"ui:submitButtonOptions": {
props: {
disabled: false,
className: "mantine-UnstyledButton-root mantine-Button-root",
},
norender: false,
submitText: "Submit",
},
"ui:submitButtonOptions": UI_SUBMIT_BUTTON_SPEC,
shortname: {
"ui:autofocus": true,
},
@ -1000,6 +993,112 @@ const AlephAlphaLuminousSettings = {
},
};
const OllamaSettings = {
fullName: "Ollama",
schema: {
type: "object",
required: ["shortname"],
properties: {
shortname: {
type: "string",
title: "Nickname",
description:
"Unique identifier to appear in ChainForge. Keep it short.",
default: "Ollama",
},
ollamaModel: {
type: "string",
title: "Model",
description:
"Enter the model to query using Ollama's API. Make sure you've pulled the model before. For more details, check out https://ollama.ai/library",
default: "mistral",
},
ollama_url: {
type: "string",
title: "URL",
description: "URL of the Ollama server generate endpoint. Only enter the path up to /api, nothing else.",
default: "http://localhost:11434/api",
},
model_type: {
type: "string",
title: "Model Type (Text or Chat)",
description: "Select 'chat' to pass conversation history and use system messages on chat-enabled models, such as llama-2. Detected automatically when '-chat' or ':chat' is present in model name. You must select 'chat' if you want to use this model in Chat Turn nodes.",
enum: ["text", "chat"],
default: "text",
},
system_msg: {
type: "string",
title: "System Message (chat models only)",
description: "Enter your system message here. Note that the type of model must be set to 'chat' for this to be passed.",
default: "",
allow_empty_str: true,
},
temperature: {
type: "number",
title: "temperature",
description: "Amount of randomness injected into the response. Ranges from 0 to 1. Use temp closer to 0 for analytical / multiple choice, and temp closer to 1 for creative and generative tasks.",
default: 1.0,
minimum: 0,
maximum: 1.0,
multipleOf: 0.01,
},
raw: {
type: "boolean",
title: "raw",
description: "Whether to disable the templating done by Ollama. If checked, you'll need to insert annotations ([INST]) in your prompt.",
default: false,
},
top_k: {
type: "integer",
title: "top_k",
description: "Only sample from the top K options for each subsequent token. Used to remove \"long tail\" low probability responses. Defaults to -1, which disables it.",
minimum: -1,
default: -1
},
top_p: {
type: "number",
title: "top_p",
description: "Does nucleus sampling, in which we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. Defaults to -1, which disables it. Note that you should either alter temperature or top_p, but not both.",
default: -1,
minimum: -1,
maximum: 1,
multipleOf: 0.001,
},
seed: {
type: "integer",
title: "seed",
description: "If specified, the OpenAI API will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed.",
allow_empty_str: true,
},
stop_sequences: {
type: "string",
title: "stop_sequences",
description: "Sequences where the API will stop generating further tokens. Enclose stop sequences in double-quotes \"\" and use whitespace to separate them.",
default: ""
},
},
},
uiSchema: {
"ui:submitButtonOptions": UI_SUBMIT_BUTTON_SPEC,
shortname: {
"ui:autofocus": true,
},
temperature: {
"ui:help": "Defaults to 1.0.",
"ui:widget": "range",
},
raw: {
"ui:help": "Defaults to false.",
},
},
postprocessors: {
stop_sequences: (str) => {
if (str.trim().length === 0) return [];
return str.match(/"((?:[^"\\]|\\.)*)"/g).map(s => s.substring(1, s.length-1)); // split on double-quotes but exclude escaped double-quotes inside the group
},
},
};
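// Illustrative example, not part of the commit: the stop_sequences postprocessor above splits a
// whitespace-separated list of double-quoted strings into an array, e.g.
//   '"###" "Human:"'   ->  ["###", "Human:"]
//   "" (or whitespace) ->  []
// Escaped quotes such as \" are preserved inside a single sequence.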
// A lookup table indexed by base_model.
export let ModelSettings = {
'gpt-3.5-turbo': ChatGPTSettings,
@ -1009,16 +1108,17 @@ export let ModelSettings = {
'dalai': DalaiModelSettings,
'azure-openai': AzureOpenAISettings,
'hf': HuggingFaceTextInferenceSettings,
"luminous-base": AlephAlphaLuminousSettings
"luminous-base": AlephAlphaLuminousSettings,
"ollama": OllamaSettings,
};
/**
* Add new model provider to the AvailableLLMs list. Also adds the respective ModelSettings schema and rate limit.
* @param {*} name The name of the provider, to use in the dropdown menu and default name. Must be unique.
* @param {*} emoji The emoji to use for the provider. Optional.
* @param {*} emoji The emoji to use for the provider. Optional.
* @param {*} models A list of models the user can select from this provider. Optional.
* @param {*} rate_limit
* @param {*} settings_schema
* @param {*} rate_limit
* @param {*} settings_schema
*/
export const setCustomProvider = (name, emoji, models, rate_limit, settings_schema, llmProviderList) => {
if (typeof emoji === 'string' && (emoji.length === 0 || emoji.length > 2))
@ -1028,13 +1128,13 @@ export const setCustomProvider = (name, emoji, models, rate_limit, settings_sche
new_provider.emoji = emoji || '✨';
// Each LLM *model* must have a unique name. To avoid name collisions, for custom providers,
// the full LLM model name is a path, __custom/<provider_name>/<submodel name>
// the full LLM model name is a path, __custom/<provider_name>/<submodel name>
// If there's no submodel, it's just __custom/<provider_name>.
const base_model = `__custom/${name}/`;
new_provider.base_model = base_model;
new_provider.model = base_model + ((Array.isArray(models) && models.length > 0) ? `${models[0]}` : '');
// Build the settings form schema for this new custom provider
// Build the settings form schema for this new custom provider
let compiled_schema = {
fullName: `${name} (custom provider)`,
schema: {
@ -1069,7 +1169,7 @@ export const setCustomProvider = (name, emoji, models, rate_limit, settings_sche
"default": models[0],
};
compiled_schema.uiSchema["model"] = {
"ui:help": `Defaults to ${models[0]}`
"ui:help": `Defaults to ${models[0]}`
};
}
@ -1083,13 +1183,13 @@ export const setCustomProvider = (name, emoji, models, rate_limit, settings_sche
const default_temp = compiled_schema?.schema?.properties?.temperature?.default;
if (default_temp !== undefined)
new_provider.temp = default_temp;
// Add the built provider and its settings to the global lookups:
let AvailableLLMs = useStore.getState().AvailableLLMs;
const prev_provider_idx = AvailableLLMs.findIndex((d) => d.name === name);
if (prev_provider_idx > -1)
AvailableLLMs[prev_provider_idx] = new_provider;
else
else
AvailableLLMs.push(new_provider);
ModelSettings[base_model] = compiled_schema;
@ -1134,7 +1234,7 @@ export const postProcessFormData = (settingsSpec, formData) => {
else
new_data[key] = formData[key];
});
return new_data;
};

View File

@ -604,6 +604,11 @@ const PromptNode = ({ data, id, type: node_type }) => {
}
else if (json.responses && json.errors) {
// Remove progress bars
setProgress(undefined);
setProgressAnimated(false);
debounce(() => {}, 1)(); // erase any pending debounces
// Store and log responses (if any)
if (json.responses) {
setJSONResponses(json.responses);
@ -646,9 +651,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
setStatus('ready');
setContChatToggleDisabled(false);
// Remove progress bars
setProgress(undefined);
setProgressAnimated(true);
// Remove individual progress rings
llmListContainer?.current?.resetLLMItemsProgress();
// Save prompt text so we remember what prompt we have responses cache'd for:

View File

@ -1,5 +1,5 @@
/**
* A list of all model APIs natively supported by ChainForge.
/**
* A list of all model APIs natively supported by ChainForge.
*/
export type LLM = string | NativeLLM;
export enum NativeLLM {
@ -75,6 +75,7 @@ export enum NativeLLM {
// A special flag for a user-defined HuggingFace model endpoint.
// The actual model name will be passed as a param to the LLM call function.
HF_OTHER = "Other (HuggingFace)",
Ollama = "ollama",
}
/**
@ -88,6 +89,7 @@ export enum LLMProvider {
Google = "google",
HuggingFace = "hf",
Aleph_Alpha = "alephalpha",
Ollama = "ollama",
Custom = "__custom",
}
@ -112,9 +114,11 @@ export function getProvider(llm: LLM): LLMProvider | undefined {
return LLMProvider.Anthropic;
else if (llm_name?.startsWith('Aleph_Alpha'))
return LLMProvider.Aleph_Alpha;
else if (llm_name?.startsWith('Ollama'))
return LLMProvider.Ollama;
else if (llm.toString().startsWith('__custom/'))
return LLMProvider.Custom;
return undefined;
}
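// Illustrative, not part of the commit: the new branch above resolves the "ollama" base model to
// its provider, e.g. getProvider(NativeLLM.Ollama) -> LLMProvider.Ollama, since llm_name holds the
// enum key name "Ollama", which starts with "Ollama".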
@ -145,8 +149,8 @@ export let RATE_LIMITS: { [key in LLM]?: [number, number] } = {
/** Equivalent to a Python enum's .name property */
export function getEnumName(enumObject: any, enumValue: any): string | undefined {
for (const key in enumObject)
if (enumObject[key] === enumValue)
for (const key in enumObject)
if (enumObject[key] === enumValue)
return key;
return undefined;
}

View File

@ -47,7 +47,7 @@ export function APP_IS_RUNNING_LOCALLY(): boolean {
// @ts-ignore
_APP_IS_RUNNING_LOCALLY = location.hostname === "localhost" || location.hostname === "127.0.0.1" || location.hostname === "0.0.0.0" || location.hostname === "" || window.__CF_HOSTNAME !== undefined;
} catch (e) {
// ReferenceError --window or location does not exist.
// ReferenceError --window or location does not exist.
// We must not be running client-side in a browser, in this case (e.g., we are running a Node.js server)
_APP_IS_RUNNING_LOCALLY = false;
}
@ -56,7 +56,7 @@ export function APP_IS_RUNNING_LOCALLY(): boolean {
}
/**
* Equivalent to a 'fetch' call, but routes it to the backend Flask server in
* Equivalent to a 'fetch' call, but routes it to the backend Flask server in
* case we are running a local server and prefer to not deal with CORS issues making API calls client-side.
*/
async function route_fetch(url: string, method: string, headers: Dict, body: Dict) {
@ -89,6 +89,10 @@ function get_environ(key: string): string | undefined {
return undefined;
}
function appendEndSlashIfMissing(path: string) {
return path + ((path[path.length-1] === '/') ? "" : "/");
}
let DALAI_MODEL: string | undefined;
let DALAI_RESPONSE: Dict | undefined;
@ -127,18 +131,18 @@ export function set_api_keys(api_keys: StringDict): void {
/**
* Construct an OpenAI format chat history for sending off to an OpenAI API call.
* @param prompt The next prompt (user message) to append.
* @param chat_history The prior turns of the chat, ending with the AI assistants' turn.
* @param chat_history The prior turns of the chat, ending with the AI assistants' turn.
* @param system_msg Optional; the system message to use if none is present in chat_history. (Ignored if chat_history already has a sys message.)
*/
function construct_openai_chat_history(prompt: string, chat_history: ChatHistory | undefined, system_msg: string): ChatHistory {
const prompt_msg: ChatMessage = { role: 'user', content: prompt };
if (chat_history !== undefined && chat_history.length > 0) {
if (chat_history[0].role === 'system') {
// In this case, the system_msg is ignored because the prior history already contains one.
// In this case, the system_msg is ignored because the prior history already contains one.
return chat_history.concat([prompt_msg]);
} else {
// In this case, there's no system message that starts the prior history, so inject one:
// NOTE: We might reach this scenario if we chain output of a non-OpenAI chat model into an OpenAI model.
// NOTE: We might reach this scenario if we chain output of a non-OpenAI chat model into an OpenAI model.
return [{"role": "system", "content": system_msg}].concat(chat_history).concat([prompt_msg]);
}
} else return [
@ -148,8 +152,8 @@ function construct_openai_chat_history(prompt: string, chat_history: ChatHistory
}
/**
* Calls OpenAI models via OpenAI's API.
@returns raw query and response JSON dicts.
* Calls OpenAI models via OpenAI's API.
@returns raw query and response JSON dicts.
*/
export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
if (!OPENAI_API_KEY)
@ -197,7 +201,7 @@ export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, te
// Create call to text completions model
openai_call = openai.createCompletion.bind(openai);
query['prompt'] = prompt;
} else {
} else {
// Create call to chat model
openai_call = openai.createChatCompletion.bind(openai);
@ -218,14 +222,14 @@ export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, te
console.log(error?.message || error);
throw new Error(error?.message || error);
}
}
}
return [query, response];
}
/**
* Calls OpenAI models hosted on Microsoft Azure services.
* Returns raw query and response JSON dicts.
* Returns raw query and response JSON dicts.
*
* NOTE: It is recommended to set an environment variables AZURE_OPENAI_KEY and AZURE_OPENAI_ENDPOINT
*/
@ -234,7 +238,7 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
throw new Error("Could not find an Azure OpenAPI Key to use. Double-check that your key is set in Settings or in your local environment.");
if (!AZURE_OPENAI_ENDPOINT)
throw new Error("Could not find an Azure OpenAI Endpoint to use. Double-check that your endpoint is set in Settings or in your local environment.");
const deployment_name: string = params?.deployment_name;
const model_type: string = params?.model_type;
if (!deployment_name)
@ -258,7 +262,7 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
delete params?.system_msg;
delete params?.model_type;
delete params?.deployment_name;
// Setup the args for the query
let query: Dict = {
n: n,
@ -285,7 +289,7 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
throw new Error(error?.message || error);
}
}
return [query, response];
}
@ -294,7 +298,7 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
Returns raw query and response JSON dicts.
Unique parameters:
- custom_prompt_wrapper: Anthropic models expect prompts in form "\n\nHuman: ${prompt}\n\nAssistant". If you wish to
- custom_prompt_wrapper: Anthropic models expect prompts in form "\n\nHuman: ${prompt}\n\nAssistant". If you wish to
explore custom prompt wrappers that deviate, write a python Template that maps from 'prompt' to custom wrapper.
If set to None, defaults to Anthropic's suggested prompt wrapper.
- max_tokens_to_sample: A maximum number of tokens to generate before stopping.
@ -316,7 +320,7 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
if (params?.custom_prompt_wrapper !== undefined)
delete params.custom_prompt_wrapper;
// Required non-standard params
// Required non-standard params
const max_tokens_to_sample = params?.max_tokens_to_sample || 1024;
const stop_sequences = params?.stop_sequences || [ANTHROPIC_HUMAN_PROMPT];
@ -394,7 +398,7 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
}
/**
* Calls a Google PaLM/Gemini model, based on the model selection from the user.
* Calls a Google PaLM/Gemini model, based on the model selection from the user.
* Returns raw query and response JSON dicts.
*/
export async function call_google_ai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
@ -407,7 +411,7 @@ export async function call_google_ai(prompt: string, model: LLM, n: number = 1,
}
/**
* Calls a Google PaLM model.
* Calls a Google PaLM model.
* Returns raw query and response JSON dicts.
*/
export async function call_google_palm(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
@ -415,7 +419,7 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
throw new Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment.");
const is_chat_model = model.toString().includes('chat');
// Required non-standard params
// Required non-standard params
const max_output_tokens = params?.max_output_tokens || 800;
const chat_history = params?.chat_history;
delete params?.chat_history;
@ -438,7 +442,7 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
if (is_chat_model && query.stop_sequences !== undefined)
delete query.stop_sequences;
// For some reason Google needs to be special and have its API params be different names --camel or snake-case
// For some reason Google needs to be special and have its API params be different names --camel or snake-case
// --depending on if it's the Python or Node JS API. ChainForge needs a consistent name, so we must convert snake to camel:
const casemap = {
safety_settings: 'safetySettings',
@ -475,14 +479,14 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
query.prompt = palm_chat_context;
} else {
query.prompt = { messages: [{content: prompt}] };
}
}
} else {
// Text completions
query.prompt = { text: prompt };
}
console.log(`Calling Google PaLM model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`);
// Call the correct model client
const method = is_chat_model ? 'generateMessage' : 'generateText';
const url = `https://generativelanguage.googleapis.com/v1beta2/models/${model}:${method}?key=${GOOGLE_PALM_API_KEY}`;
@ -510,7 +514,7 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
completion.candidates = new Array(n).fill({'author': '1', 'content':block_error_msg});
}
// Weirdly, google ignores candidate_count if temperature is 0.
// Weirdly, google ignores candidate_count if temperature is 0.
// We have to check for this and manually append the n-1 responses:
if (n > 1 && completion.candidates?.length === 1) {
completion.candidates = new Array(n).fill(completion.candidates[0]);
@ -524,16 +528,14 @@ export async function call_google_gemini(prompt: string, model: LLM, n: number =
throw new Error("Could not find an API key for Google Gemini models. Double-check that your API key is set in Settings or in your local environment.");
// calling the correct model client
console.log('call_google_gemini: ');
model = NativeLLM.GEMINI_PRO;
const genAI = new GoogleGenerativeAI(GOOGLE_PALM_API_KEY);
const gemini_model = genAI.getGenerativeModel({model: model.toString()});
// removing chat for now. by default chat is supported
// Required non-standard params
// Required non-standard params
const max_output_tokens = params?.max_output_tokens || 1000;
const chat_history = params?.chat_history;
delete params?.chat_history;
@ -546,7 +548,7 @@ export async function call_google_gemini(prompt: string, model: LLM, n: number =
...params,
};
// For some reason Google needs to be special and have its API params be different names --camel or snake-case
// For some reason Google needs to be special and have its API params be different names --camel or snake-case
// --depending on if it's the Python or Node JS API. ChainForge needs a consistent name, so we must convert snake to camel:
const casemap = {
safety_settings: 'safetySettings',
@ -583,7 +585,7 @@ export async function call_google_gemini(prompt: string, model: LLM, n: number =
// Chat completions
if (chat_history !== undefined && chat_history.length > 0) {
// Carry over any chat history, converting OpenAI formatted chat history to Google PaLM:
let gemini_messages: GeminiChatMessage[] = [];
for (const chat_msg of chat_history) {
if (chat_msg.role === 'system') {
@ -608,7 +610,7 @@ export async function call_google_gemini(prompt: string, model: LLM, n: number =
generationConfig: gen_Config,
},
);
const chatResult = await chat.sendMessage(prompt);
const chatResponse = await chatResult.response;
const response = {
@ -692,14 +694,14 @@ export async function call_huggingface(prompt: string, model: LLM, n: number = 1
const using_custom_model_endpoint: boolean = param_exists(params?.custom_model);
let headers: StringDict = {'Content-Type': 'application/json'};
// For HuggingFace, technically, the API keys are optional.
// For HuggingFace, technically, the API keys are optional.
if (HUGGINGFACE_API_KEY !== undefined)
headers.Authorization = `Bearer ${HUGGINGFACE_API_KEY}`;
// Inference Endpoints for text completion models has the same call,
// except the endpoint is an entire URL. Detect this:
const url = (using_custom_model_endpoint && params.custom_model.startsWith('https:')) ?
params.custom_model :
// except the endpoint is an entire URL. Detect this:
const url = (using_custom_model_endpoint && params.custom_model.startsWith('https:')) ?
params.custom_model :
`https://api-inference.huggingface.co/models/${using_custom_model_endpoint ? params.custom_model.trim() : model}`;
let responses: Array<Dict> = [];
@ -710,8 +712,8 @@ export async function call_huggingface(prompt: string, model: LLM, n: number = 1
let curr_text = prompt;
while (curr_cont <= num_continuations) {
const inputs = (model_type === 'chat')
? ({ text: curr_text,
past_user_inputs: hf_chat_hist.past_user_inputs,
? ({ text: curr_text,
past_user_inputs: hf_chat_hist.past_user_inputs,
generated_responses: hf_chat_hist.generated_responses })
: curr_text;
@ -722,12 +724,12 @@ export async function call_huggingface(prompt: string, model: LLM, n: number = 1
body: JSON.stringify({inputs: inputs, parameters: query, options: options}),
});
const result = await response.json();
// HuggingFace sometimes gives us an error, for instance if a model is loading.
// It returns this as an 'error' key in the response:
if (result?.error !== undefined)
throw new Error(result.error);
else if ((model_type !== 'chat' && (!Array.isArray(result) || result.length !== 1)) ||
else if ((model_type !== 'chat' && (!Array.isArray(result) || result.length !== 1)) ||
(model_type === 'chat' && (Array.isArray(result) || !result || result?.generated_text === undefined)))
throw new Error("Result of HuggingFace API call is in unexpected format:" + JSON.stringify(result));
@ -750,12 +752,12 @@ export async function call_huggingface(prompt: string, model: LLM, n: number = 1
export async function call_alephalpha(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
if (!ALEPH_ALPHA_API_KEY)
throw Error("Could not find an API key for Aleph Alpha models. Double-check that your API key is set in Settings or in your local environment.");
const url: string = 'https://api.aleph-alpha.com/complete';
let headers: StringDict = {'Content-Type': 'application/json', 'Accept': 'application/json'};
if (ALEPH_ALPHA_API_KEY !== undefined)
headers.Authorization = `Bearer ${ALEPH_ALPHA_API_KEY}`;
let data = JSON.stringify({
"model": model.toString(),
"prompt": prompt,
@ -770,7 +772,7 @@ export async function call_alephalpha(prompt: string, model: LLM, n: number = 1,
temperature: temperature,
...params, // 'the rest' of the settings, passed from the front-end settings
};
const response = await fetch(url, {
headers: headers,
method: "POST",
@ -782,14 +784,71 @@ export async function call_alephalpha(prompt: string, model: LLM, n: number = 1,
return [query, responses];
}
export async function call_ollama_provider(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
let url: string = appendEndSlashIfMissing(params?.ollama_url);
const ollama_model: string = params?.ollamaModel.toString();
const model_type: string = params?.model_type ?? "text";
const system_msg: string = params?.system_msg ?? "";
const chat_history: ChatHistory | undefined = params?.chat_history;
// Cleanup
for (const name of ["ollamaModel", "ollama_url", "model_type", "system_msg", "chat_history"])
if (name in params) delete params[name];
// FIXME: Ollama doesn't support batch inference, but llama.cpp does so it will eventually
// For now, we send n requests and then wait for all of them to finish
let query: Dict = {
model: ollama_model,
stream: false,
temperature: temperature,
...params, // 'the rest' of the settings, passed from the front-end settings
};
// If the model type is explicitly or implicitly set to "chat", pass chat history instead:
if (model_type === 'chat' || /[-:](chat)/.test(ollama_model)) {
// Construct chat history and pass to query payload
query.messages = construct_openai_chat_history(prompt, chat_history, system_msg);
url += "chat";
} else {
// Text-only models
query.prompt = prompt;
url += "generate";
}
console.log(`Calling Ollama API at ${url} for model '${ollama_model}' with prompt '${prompt}' n=${n} times. Please be patient...`);
// Call Ollama API
let resps : Response[] = [];
for (let i = 0; i < n; i++) {
const response = await fetch(url, {
method: "POST",
body: JSON.stringify(query),
});
resps.push(response);
}
const parse_response = (body: string) => {
const json = JSON.parse(body);
if (json.message) // chat models
return {generated_text: json.message.content}
else // text-only models
return {generated_text: json.response};
};
const responses = await Promise.all(resps.map((resp) => resp.text())).then((responses) => {
return responses.map((response) => parse_response(response));
});
return [query, responses];
}
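// Illustrative, not part of the commit: chat mode is used when model_type is "chat" or the model
// name matches /[-:](chat)/ (e.g. "llama2:chat"); otherwise the request goes to /api/generate.
// For n = 2 in text mode, the return value is roughly
//   query     = { model: "mistral", stream: false, temperature: 1.0, prompt: "..." }
//   responses = [{ generated_text: "..." }, { generated_text: "..." }]
// which _extract_ollama_responses (further below) flattens into an array of trimmed strings.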
async function call_custom_provider(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
if (!APP_IS_RUNNING_LOCALLY())
throw new Error("The ChainForge app does not appear to be running locally. You can only call custom model providers if you are running ChainForge on your local machine, from a Flask app.")
// The model to call is in format:
// __custom/<provider_name>/<submodel name>
// It may also exclude the final tag.
// __custom/<provider_name>/<submodel name>
// It may also exclude the final tag.
// We extract the provider name (this is the name used in the Python backend's `ProviderRegistry`) and optionally, the submodel name
const provider_path = model.substring(9);
const provider_name = provider_path.substring(0, provider_path.indexOf('/'));
@ -798,9 +857,9 @@ async function call_custom_provider(prompt: string, model: LLM, n: number = 1, t
let responses = [];
const query = { prompt, model, temperature, ...params };
// Call the custom provider n times
// Call the custom provider n times
while (responses.length < n) {
let {response, error} = await call_flask_backend('callCustomProvider',
let {response, error} = await call_flask_backend('callCustomProvider',
{ 'name': provider_name,
'params': {
prompt, model: submodel_name, temperature, ...params
@ -823,7 +882,7 @@ export async function call_llm(llm: LLM, prompt: string, n: number, temperature:
// Get the correct API call for the given LLM:
let call_api: LLMAPICall | undefined;
let llm_provider: LLMProvider = getProvider(llm);
if (llm_provider === undefined)
throw new Error(`Language model ${llm} is not supported.`);
@ -841,20 +900,22 @@ export async function call_llm(llm: LLM, prompt: string, n: number, temperature:
call_api = call_huggingface;
else if (llm_provider === LLMProvider.Aleph_Alpha)
call_api = call_alephalpha;
else if (llm_provider === LLMProvider.Ollama)
call_api = call_ollama_provider;
else if (llm_provider === LLMProvider.Custom)
call_api = call_custom_provider;
return call_api(prompt, llm, n, temperature, params);
}
/**
* Extracts the relevant portion of a OpenAI chat response.
* Extracts the relevant portion of a OpenAI chat response.
* Note that chat choice objects can now include 'function_call' and a blank 'content' response.
* This method detects a 'function_call's presence, prepends [[FUNCTION]] and converts the function call into JS format.
* This method detects a 'function_call's presence, prepends [[FUNCTION]] and converts the function call into JS format.
*/
function _extract_openai_chat_choice_content(choice: Dict): string {
if (choice['finish_reason'] === 'function_call' ||
if (choice['finish_reason'] === 'function_call' ||
('function_call' in choice['message'] && choice['message']['function_call'].length > 0)) {
const func = choice['message']['function_call'];
return '[[FUNCTION]] ' + func['name'] + func['arguments'].toString();
@ -865,7 +926,7 @@ function _extract_openai_chat_choice_content(choice: Dict): string {
/**
* Extracts the text part of a response JSON from ChatGPT. If there is more
* than 1 response (e.g., asking the LLM to generate multiple responses),
* than 1 response (e.g., asking the LLM to generate multiple responses),
* this produces a list of all returned responses.
*/
function _extract_chatgpt_responses(response: Dict): Array<string> {
@ -874,7 +935,7 @@ function _extract_chatgpt_responses(response: Dict): Array<string> {
/**
* Extracts the text part of a response JSON from OpenAI completions models like Davinci. If there are more
* than 1 response (e.g., asking the LLM to generate multiple responses),
* than 1 response (e.g., asking the LLM to generate multiple responses),
* this produces a list of all returned responses.
*/
function _extract_openai_completion_responses(response: Dict): Array<string> {
@ -939,7 +1000,14 @@ function _extract_huggingface_responses(response: Array<Dict>): Array<string>{
* Extracts the text part of an Aleph Alpha text completion.
*/
function _extract_alephalpha_responses(response: Dict): Array<string> {
return response.map((r: string) => r.trim());
return response.map((r: string) => r.trim());
}
/**
* Extracts the text part of an Ollama text completion.
*/
function _extract_ollama_responses(response: Array<Dict>): Array<string> {
return response.map((r: Object) => r["generated_text"].trim());
}
/**
@ -965,20 +1033,22 @@ export function extract_responses(response: Array<string | Dict> | Dict, llm: LL
return _extract_anthropic_responses(response as Dict[]);
case LLMProvider.HuggingFace:
return _extract_huggingface_responses(response as Dict[]);
case LLMProvider.Aleph_Alpha:
return _extract_alephalpha_responses(response);
case LLMProvider.Aleph_Alpha:
return _extract_alephalpha_responses(response);
case LLMProvider.Ollama:
return _extract_ollama_responses(response as Dict[]);
default:
if (Array.isArray(response) && response.length > 0 && typeof response[0] === 'string')
return response as string[];
else
return response as string[];
else
throw new Error(`No method defined to extract responses for LLM ${llm}.`)
}
}
/**
* Merge the 'responses' and 'raw_response' properties of two LLMResponseObjects,
* keeping all the other params from the second argument (llm, query, etc).
*
* keeping all the other params from the second argument (llm, query, etc).
*
* If one object is undefined or null, returns the object that is defined, unaltered.
*/
export function merge_response_objs(resp_obj_A: LLMResponseObject | undefined, resp_obj_B: LLMResponseObject | undefined): LLMResponseObject | undefined {
@ -1102,7 +1172,7 @@ export const toStandardResponseFormat = (r) => {
// Check if the current browser window/tab is 'active' or not
export const browserTabIsActive = () => {
try {
const visible = document.visibilityState === 'visible';
const visible = document.visibilityState === 'visible';
return visible;
} catch(e) {
console.error(e);
@ -1136,7 +1206,7 @@ export const extractLLMLookup = (input_data) => {
export const removeLLMTagFromMetadata = (metavars) => {
if (!('__LLM_key' in metavars))
return metavars;
return metavars;
let mcopy = JSON.parse(JSON.stringify(metavars));
delete mcopy['__LLM_key'];
return mcopy;

View File

@ -17,7 +17,7 @@ const initialLLMColors = {};
/** The color palette used for displaying info about different LLMs. */
const llmColorPalette = ['#44d044', '#f1b933', '#e46161', '#8888f9', '#33bef0', '#bb55f9', '#f7ee45', '#f955cd', '#26e080', '#2654e0', '#7d8191', '#bea5d1'];
/** The color palette used for displaying variations of prompts and prompt variables (non-LLM differences).
/** The color palette used for displaying variations of prompts and prompt variables (non-LLM differences).
* Distinct from the LLM color palette in order to avoid confusion around what the data means.
* Palette adapted from https://lospec.com/palette-list/sness by Space Sandwich */
const varColorPalette = ['#0bdb52', '#e71861', '#7161de', '#f6d714', '#80bedb', '#ffa995', '#a9b399', '#dc6f0f', '#8d022e', '#138e7d', '#c6924f', '#885818', '#616b6d'];
@ -40,7 +40,10 @@ export let initLLMProviders = [
{ name: "Azure OpenAI", emoji: "🔷", model: "azure-openai", base_model: "azure-openai", temp: 1.0 },
];
if (APP_IS_RUNNING_LOCALLY()) {
initLLMProviders.push({ name: "Dalai (Alpaca.7B)", emoji: "🦙", model: "alpaca.7B", base_model: "dalai", temp: 0.5 });
initLLMProviders.push({ name: "Ollama", emoji: "🦙", model: "ollama", base_model: "ollama", temp: 1.0 });
// -- Deprecated provider --
// initLLMProviders.push({ name: "Dalai (Alpaca.7B)", emoji: "🦙", model: "alpaca.7B", base_model: "dalai", temp: 0.5 });
// -------------------------
}
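// Note, not part of the commit: because this push happens inside APP_IS_RUNNING_LOCALLY(), the
// Ollama entry only appears in the providers list when ChainForge is served locally, which is the
// behavior described in the commit message above.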
// A global store of variables, used for maintaining state
@ -79,7 +82,7 @@ const useStore = create((set, get) => ({
llmColors: initialLLMColors,
// Gets the color for the model named 'llm_name' in llmColors; returns undefined if not found.
getColorForLLM: (llm_name) => {
getColorForLLM: (llm_name) => {
const colors = get().llmColors;
if (llm_name in colors)
return colors[llm_name];
@ -87,7 +90,7 @@ const useStore = create((set, get) => ({
},
// Gets the color for the specified LLM. If not found, generates a new (ideally unique) color
// and saves it to the llmColors dict.
// and saves it to the llmColors dict.
getColorForLLMAndSetIfNotFound: (llm_name) => {
let color = get().getColorForLLM(llm_name);
if (color) return color;
@ -97,7 +100,7 @@ const useStore = create((set, get) => ({
},
// Generates a unique color not yet used in llmColors (unless # colors is large)
genUniqueLLMColor: () => {
genUniqueLLMColor: () => {
const used_colors = new Set(Object.values(get().llmColors));
const get_unused_color = (all_colors) => {
for (let i = 0; i < all_colors.length; i++) {
@ -107,7 +110,7 @@ const useStore = create((set, get) => ({
}
return undefined;
};
let unique_color = get_unused_color(llmColorPalette);
if (unique_color) return unique_color;
@ -116,7 +119,7 @@ const useStore = create((set, get) => ({
unique_color = get_unused_color(varColorPalette);
if (unique_color) return unique_color;
// If we've reached here, we've run out of all predefined colors.
// If we've reached here, we've run out of all predefined colors.
// Choose one to repeat, at random:
const all_colors = llmColorPalette.concat(varColorPalette);
return all_colors[Math.floor(Math.random() * all_colors.length)];
@ -133,7 +136,7 @@ const useStore = create((set, get) => ({
llmColors: []
});
},
inputEdgesForNode: (sourceNodeId) => {
return get().edges.filter(e => e.target == sourceNodeId);
},
@ -164,7 +167,7 @@ const useStore = create((set, get) => ({
const src_col = columns.find(c => c.header === sourceHandleKey);
if (src_col !== undefined) {
// Construct a lookup table from column key to header name,
// Construct a lookup table from column key to header name,
// as the 'metavars' dict should be keyed by column *header*, not internal key:
let col_header_lookup = {};
columns.forEach(c => {
@ -203,10 +206,10 @@ const useStore = create((set, get) => ({
if (Array.isArray(src_node.data["fields"]))
return src_node.data["fields"];
else {
// We have to filter over a special 'fields_visibility' prop, which
// We have to filter over a special 'fields_visibility' prop, which
// can select what fields get output:
if ("fields_visibility" in src_node.data)
return Object.values(filterDict(src_node.data["fields"],
return Object.values(filterDict(src_node.data["fields"],
fid => src_node.data["fields_visibility"][fid] !== false));
else // return all field values
return Object.values(src_node.data["fields"]);
@ -224,7 +227,7 @@ const useStore = create((set, get) => ({
// Get the types of nodes attached immediately as input to the given node
getImmediateInputNodeTypes: (_targetHandles, node_id) => {
const getNode = get().getNode;
const edges = get().edges;
const edges = get().edges;
let inputNodeTypes = [];
edges.forEach(e => {
if (e.target == node_id && _targetHandles.includes(e.targetHandle)) {
@ -242,9 +245,9 @@ const useStore = create((set, get) => ({
// Functions/data from the store:
const getNode = get().getNode;
const output = get().output;
const edges = get().edges;
const edges = get().edges;
// Helper function to store collected data in dict:
// Helper function to store collected data in dict:
const store_data = (_texts, _varname, _data) => {
if (_varname in _data)
_data[_varname] = _data[_varname].concat(_texts);
@ -257,12 +260,12 @@ const useStore = create((set, get) => ({
const get_outputs = (varnames, nodeId, var_history) => {
varnames.forEach(varname => {
// Check for duplicate variable names
if (var_history.has(String(varname).toLowerCase()))
if (var_history.has(String(varname).toLowerCase()))
throw new DuplicateVariableNameError(varname);
// Add to unique name tally
var_history.add(String(varname).toLowerCase());
// Find the relevant edge(s):
edges.forEach(e => {
if (e.target == nodeId && e.targetHandle == varname) {
@ -278,7 +281,7 @@ const useStore = create((set, get) => ({
// Save the list of strings from the pulled output under the var 'varname'
store_data(out, varname, pulled_data);
}
// Get any vars that the output depends on, and recursively collect those outputs as well:
const n_vars = getNode(e.source).data.vars;
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
@ -295,15 +298,15 @@ const useStore = create((set, get) => ({
/**
* Sets select 'data' properties for node 'id'. This updates global state, and forces re-renders. Use sparingly.
* @param {*} id The id of the node to set 'data' properties for.
* @param {*} data_props The properties to set on the node's 'data'.
* @param {*} data_props The properties to set on the node's 'data'.
*/
setDataPropsForNode: (id, data_props) => {
set({
nodes: (nds =>
nodes: (nds =>
nds.map(n => {
if (n.id === id) {
for (const key of Object.keys(data_props))
n.data[key] = data_props[key];
for (const key of Object.keys(data_props))
n.data[key] = data_props[key];
n.data = deepcopy(n.data);
}
return n;
@ -393,10 +396,10 @@ const useStore = create((set, get) => ({
});
},
onConnect: (connection) => {
// Get the target node information
const target = get().getNode(connection.target);
if (target.type === 'vis' || target.type === 'inspect' || target.type === 'simpleval') {
get().setDataPropsForNode(target.id, { input: connection.source });
}

View File

@ -6,7 +6,7 @@ def readme():
setup(
name='chainforge',
version='0.2.8.8',
version='0.2.8.9',
packages=find_packages(),
author="Ian Arawjo",
description="A Visual Programming Environment for Prompt Engineering",