Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-14 08:16:37 +00:00)
Replace Dalai with Ollama (#209)
* Add basic Ollama support (#208)
* Remove trapFocus warning when no OpenAI key set
* Ensure Ollama is only visible in providers list if running locally.
* Remove Dalai.
* Fix ollama support to include chat models and pass chat history correctly
* Fix bug with debounce on progress bar updates in Prompt/Chat nodes
* Rebuilt app and update package version

Co-authored-by: Laurent Huberdeau <16990250+laurenthuberdeau@users.noreply.github.com>
parent 5acdfc0677
commit b92c03afb2
@@ -1,15 +1,15 @@
 {
   "files": {
     "main.css": "/static/css/main.d7b7e6a1.css",
-    "main.js": "/static/js/main.546c3be8.js",
+    "main.js": "/static/js/main.16bd62df.js",
     "static/js/787.4c72bb55.chunk.js": "/static/js/787.4c72bb55.chunk.js",
     "index.html": "/index.html",
     "main.d7b7e6a1.css.map": "/static/css/main.d7b7e6a1.css.map",
-    "main.546c3be8.js.map": "/static/js/main.546c3be8.js.map",
+    "main.16bd62df.js.map": "/static/js/main.16bd62df.js.map",
     "787.4c72bb55.chunk.js.map": "/static/js/787.4c72bb55.chunk.js.map"
   },
   "entrypoints": [
     "static/css/main.d7b7e6a1.css",
-    "static/js/main.546c3be8.js"
+    "static/js/main.16bd62df.js"
   ]
 }
@@ -1 +1 @@
-<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.546c3be8.js"></script><link href="/static/css/main.d7b7e6a1.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
+<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.16bd62df.js"></script><link href="/static/css/main.d7b7e6a1.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
5 changes: chainforge/react-server/src/AiPopover.js (vendored)
@@ -1,10 +1,9 @@
 import React, { useMemo, useRef } from 'react';
-import { Stack, NumberInput, Button, Text, TextInput, Switch, Tabs, Popover, Badge, Textarea, Alert } from "@mantine/core"
+import { Stack, NumberInput, Button, Text, Switch, Tabs, Popover, Badge, Textarea, Alert } from "@mantine/core"
 import { useState } from 'react';
 import { autofill, generateAndReplace, AIError } from './backend/ai';
 import { IconSparkles, IconAlertCircle } from '@tabler/icons-react';
 import AlertModal from './AlertModal';
 import { useStore } from './store';

 const zeroGap = {gap: "0rem"};
 const popoverShadow ="rgb(38, 57, 77) 0px 10px 30px -14px";
@@ -142,7 +141,7 @@ function AIPopover({
 ), [didGenerateAndReplaceError, generateAndReplacePrompt, setGenerateAndReplacePrompt, generateAndReplaceNumber, setGenerateAndReplaceNumber, generateAndReplaceIsUnconventional, setGenerateAndReplaceIsUnconventional, handleGenerateAndReplace, areValuesLoading]);

 return (
-<Popover position="right-start" withArrow shadow={popoverShadow} withinPortal keepMounted trapFocus>
+<Popover position="right-start" withArrow shadow={popoverShadow} withinPortal keepMounted trapFocus={noOpenAIKeyMessage === undefined}>
 <Popover.Target>
 <button className="ai-button nodrag"><IconSparkles size={10} stroke={3}/></button>
 </Popover.Target>
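The trapFocus change above stops Mantine from warning when the popover opens with only the "no OpenAI key" notice inside, i.e. nothing focusable to trap. A minimal sketch of the same pattern, in a hypothetical component that is not part of this commit (Mantine's Popover, Text and TextInput are the only real APIs used):

import React from 'react';
import { Popover, Text, TextInput } from '@mantine/core';

// When only the "no key" notice is shown there is nothing focusable inside the
// dropdown, so the focus trap is switched off to avoid Mantine's warning.
function SketchPopover({ noOpenAIKeyMessage }: { noOpenAIKeyMessage?: string }) {
  const showNoticeOnly = noOpenAIKeyMessage !== undefined;
  return (
    <Popover withArrow keepMounted trapFocus={!showNoticeOnly}>
      <Popover.Target>
        <button>AI</button>
      </Popover.Target>
      <Popover.Dropdown>
        {showNoticeOnly ? <Text>{noOpenAIKeyMessage}</Text> : <TextInput label="Prompt" />}
      </Popover.Dropdown>
    </Popover>
  );
}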
148 changes: chainforge/react-server/src/ModelSettingSchemas.js (vendored)
@@ -254,7 +254,7 @@ const ClaudeSettings = {
 "type": "string",
 "title": "Model Version",
 "description": "Select a version of Claude to query. For more details on the differences, see the Anthropic API documentation.",
-"enum": ["claude-2", "claude-2.0", "claude-2.1", "claude-instant-1", "claude-instant-1.1", "claude-instant-1.2", "claude-v1", "claude-v1-100k", "claude-instant-v1", "claude-instant-v1-100k", "claude-v1.3",
+"enum": ["claude-2", "claude-2.0", "claude-2.1", "claude-instant-1", "claude-instant-1.1", "claude-instant-1.2", "claude-v1", "claude-v1-100k", "claude-instant-v1", "claude-instant-v1-100k", "claude-v1.3",
 "claude-v1.3-100k", "claude-v1.2", "claude-v1.0", "claude-instant-v1.1", "claude-instant-v1.1-100k", "claude-instant-v1.0"],
 "default": "claude-2"
 },
@@ -646,9 +646,9 @@ const HuggingFaceTextInferenceSettings = {
 "enum": ["mistralai/Mistral-7B-Instruct-v0.1", "HuggingFaceH4/zephyr-7b-beta", "tiiuae/falcon-7b-instruct", "microsoft/DialoGPT-large", "bigscience/bloom-560m", "gpt2", "bigcode/santacoder", "bigcode/starcoder", "Other (HuggingFace)"],
 "default": "tiiuae/falcon-7b-instruct",
 "shortname_map": {
-"mistralai/Mistral-7B-Instruct-v0.1": "Mistral-7B",
-"HuggingFaceH4/zephyr-7b-beta": "Zephyr-7B",
-"tiiuae/falcon-7b-instruct": "Falcon-7B",
+"mistralai/Mistral-7B-Instruct-v0.1": "Mistral-7B",
+"HuggingFaceH4/zephyr-7b-beta": "Zephyr-7B",
+"tiiuae/falcon-7b-instruct": "Falcon-7B",
 "microsoft/DialoGPT-large": "DialoGPT",
 "bigscience/bloom-560m": "Bloom560M",
 "gpt2": "GPT-2",
@@ -742,7 +742,7 @@ const HuggingFaceTextInferenceSettings = {
 "ui:autofocus": true
 },
 "model": {
-"ui:help": "Defaults to Falcon.7B."
+"ui:help": "Defaults to Falcon.7B."
 },
 "temperature": {
 "ui:help": "Defaults to 1.0.",
@@ -797,7 +797,7 @@ const AlephAlphaLuminousSettings = {
 type: "string",
 title: "Model",
 description:
-"Select a suggested Aleph Alpha model to query using the Aleph Alpha API. For more details, check outhttps://docs.aleph-alpha.com/api/available-models/",
+"Select a suggested Aleph Alpha model to query using the Aleph Alpha API. For more details, check out https://docs.aleph-alpha.com/api/available-models/",
 enum: [
 "luminous-extended",
 "luminous-extended-control",
@@ -880,14 +880,7 @@ const AlephAlphaLuminousSettings = {
 },
 },
 uiSchema: {
-"ui:submitButtonOptions": {
-props: {
-disabled: false,
-className: "mantine-UnstyledButton-root mantine-Button-root",
-},
-norender: false,
-submitText: "Submit",
-},
+"ui:submitButtonOptions": UI_SUBMIT_BUTTON_SPEC,
 shortname: {
 "ui:autofocus": true,
 },
@@ -1000,6 +993,112 @@ const AlephAlphaLuminousSettings = {
 },
 };

+const OllamaSettings = {
+fullName: "Ollama",
+schema: {
+type: "object",
+required: ["shortname"],
+properties: {
+shortname: {
+type: "string",
+title: "Nickname",
+description:
+"Unique identifier to appear in ChainForge. Keep it short.",
+default: "Ollama",
+},
+ollamaModel: {
+type: "string",
+title: "Model",
+description:
+"Enter the model to query using Ollama's API. Make sure you've pulled the model before. For more details, check out https://ollama.ai/library",
+default: "mistral",
+},
+ollama_url: {
+type: "string",
+title: "URL",
+description: "URL of the Ollama server generate endpoint. Only enter the path up to /api, nothing else.",
+default: "http://localhost:11434/api",
+},
+model_type: {
+type: "string",
+title: "Model Type (Text or Chat)",
+description: "Select 'chat' to pass conversation history and use system messages on chat-enabled models, such as llama-2. Detected automatically when '-chat' or ':chat' is present in model name. You must select 'chat' if you want to use this model in Chat Turn nodes.",
+enum: ["text", "chat"],
+default: "text",
+},
+system_msg: {
+type: "string",
+title: "System Message (chat models only)",
+description: "Enter your system message here. Note that the type of model must be set to 'chat' for this to be passed.",
+default: "",
+allow_empty_str: true,
+},
+temperature: {
+type: "number",
+title: "temperature",
+description: "Amount of randomness injected into the response. Ranges from 0 to 1. Use temp closer to 0 for analytical / multiple choice, and temp closer to 1 for creative and generative tasks.",
+default: 1.0,
+minimum: 0,
+maximum: 1.0,
+multipleOf: 0.01,
+},
+raw: {
+type: "boolean",
+title: "raw",
+description: "Whether to disable the templating done by Ollama. If checked, you'll need to insert annotations ([INST]) in your prompt.",
+default: false,
+},
+top_k: {
+type: "integer",
+title: "top_k",
+description: "Only sample from the top K options for each subsequent token. Used to remove \"long tail\" low probability responses. Defaults to -1, which disables it.",
+minimum: -1,
+default: -1
+},
+top_p: {
+type: "number",
+title: "top_p",
+description: "Does nucleus sampling, in which we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. Defaults to -1, which disables it. Note that you should either alter temperature or top_p, but not both.",
+default: -1,
+minimum: -1,
+maximum: 1,
+multipleOf: 0.001,
+},
+seed: {
+type: "integer",
+title: "seed",
+description: "If specified, the OpenAI API will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed.",
+allow_empty_str: true,
+},
+stop_sequences: {
+type: "string",
+title: "stop_sequences",
+description: "Sequences where the API will stop generating further tokens. Enclose stop sequences in double-quotes \"\" and use whitespace to separate them.",
+default: ""
+},
+},
+},
+uiSchema: {
+"ui:submitButtonOptions": UI_SUBMIT_BUTTON_SPEC,
+shortname: {
+"ui:autofocus": true,
+},
+temperature: {
+"ui:help": "Defaults to 1.0.",
+"ui:widget": "range",
+},
+raw: {
+"ui:help": "Defaults to false.",
+},
+},
+postprocessors: {
+stop_sequences: (str) => {
+if (str.trim().length === 0) return [];
+return str.match(/"((?:[^"\\]|\\.)*)"/g).map(s => s.substring(1, s.length-1)); // split on double-quotes but exclude escaped double-quotes inside the group
+},
+},
+};
+
 // A lookup table indexed by base_model.
 export let ModelSettings = {
 'gpt-3.5-turbo': ChatGPTSettings,
@@ -1009,16 +1108,17 @@ export let ModelSettings = {
 'dalai': DalaiModelSettings,
 'azure-openai': AzureOpenAISettings,
 'hf': HuggingFaceTextInferenceSettings,
-"luminous-base": AlephAlphaLuminousSettings
+"luminous-base": AlephAlphaLuminousSettings,
+"ollama": OllamaSettings,
 };

 /**
 * Add new model provider to the AvailableLLMs list. Also adds the respective ModelSettings schema and rate limit.
 * @param {*} name The name of the provider, to use in the dropdown menu and default name. Must be unique.
-* @param {*} emoji The emoji to use for the provider. Optional.
+* @param {*} emoji The emoji to use for the provider. Optional.
 * @param {*} models A list of models the user can select from this provider. Optional.
-* @param {*} rate_limit
-* @param {*} settings_schema
+* @param {*} rate_limit
+* @param {*} settings_schema
 */
 export const setCustomProvider = (name, emoji, models, rate_limit, settings_schema, llmProviderList) => {
 if (typeof emoji === 'string' && (emoji.length === 0 || emoji.length > 2))
@@ -1028,13 +1128,13 @@ export const setCustomProvider = (name, emoji, models, rate_limit, settings_sche
 new_provider.emoji = emoji || '✨';

 // Each LLM *model* must have a unique name. To avoid name collisions, for custom providers,
-// the full LLM model name is a path, __custom/<provider_name>/<submodel name>
+// the full LLM model name is a path, __custom/<provider_name>/<submodel name>
 // If there's no submodel, it's just __custom/<provider_name>.
 const base_model = `__custom/${name}/`;
 new_provider.base_model = base_model;
 new_provider.model = base_model + ((Array.isArray(models) && models.length > 0) ? `${models[0]}` : '');

-// Build the settings form schema for this new custom provider
+// Build the settings form schema for this new custom provider
 let compiled_schema = {
 fullName: `${name} (custom provider)`,
 schema: {
@@ -1069,7 +1169,7 @@ export const setCustomProvider = (name, emoji, models, rate_limit, settings_sche
 "default": models[0],
 };
 compiled_schema.uiSchema["model"] = {
-"ui:help": `Defaults to ${models[0]}`
+"ui:help": `Defaults to ${models[0]}`
 };
 }

@@ -1083,13 +1183,13 @@ export const setCustomProvider = (name, emoji, models, rate_limit, settings_sche
 const default_temp = compiled_schema?.schema?.properties?.temperature?.default;
 if (default_temp !== undefined)
 new_provider.temp = default_temp;

 // Add the built provider and its settings to the global lookups:
 let AvailableLLMs = useStore.getState().AvailableLLMs;
 const prev_provider_idx = AvailableLLMs.findIndex((d) => d.name === name);
 if (prev_provider_idx > -1)
 AvailableLLMs[prev_provider_idx] = new_provider;
-else
+else
 AvailableLLMs.push(new_provider);
 ModelSettings[base_model] = compiled_schema;

@@ -1134,7 +1234,7 @@ export const postProcessFormData = (settingsSpec, formData) => {
 else
 new_data[key] = formData[key];
 });

 return new_data;
 };
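The stop_sequences postprocessor added with OllamaSettings turns the whitespace-separated, double-quoted sequences from the settings form into an array before they are sent to the model. A standalone sketch of the same parsing (hypothetical helper name; a null-guard is added here that the inline version above omits):

// Hypothetical standalone version of the stop_sequences postprocessor.
function parseStopSequences(str: string): string[] {
  if (str.trim().length === 0) return [];
  // Match double-quoted chunks, allowing escaped quotes inside, then strip the outer quotes.
  const matches = str.match(/"((?:[^"\\]|\\.)*)"/g) ?? [];
  return matches.map((s) => s.substring(1, s.length - 1));
}

// parseStopSequences('"###" "Human:"')  -> ["###", "Human:"]
// parseStopSequences('')                -> []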
9 changes: chainforge/react-server/src/PromptNode.js (vendored)
@@ -604,6 +604,11 @@ const PromptNode = ({ data, id, type: node_type }) => {
 }
 else if (json.responses && json.errors) {

+// Remove progress bars
+setProgress(undefined);
+setProgressAnimated(false);
+debounce(() => {}, 1)(); // erase any pending debounces
+
 // Store and log responses (if any)
 if (json.responses) {
 setJSONResponses(json.responses);
@@ -646,9 +651,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
 setStatus('ready');
 setContChatToggleDisabled(false);

-// Remove progress bars
-setProgress(undefined);
-setProgressAnimated(true);
+// Remove individual progress rings
 llmListContainer?.current?.resetLLMItemsProgress();

 // Save prompt text so we remember what prompt we have responses cache'd for:
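The first hunk clears the progress bar in the error branch and then "erases any pending debounces" so a late, already-scheduled progress update cannot bring the bar back. A hedged sketch of the general idea with lodash.debounce (not the node's actual debounce helper): keeping a single debounced instance around is what makes cancelling pending updates possible.

import debounce from 'lodash/debounce';

// One shared debounced updater, so queued calls can be cancelled later.
const updateProgress = debounce((pct: number) => {
  console.log(`progress: ${pct}%`);
}, 100);

updateProgress(40);
updateProgress(80);       // scheduled, not yet fired

// On error/completion: hide the bar, then drop anything still queued.
updateProgress.cancel();  // the pending 80% update never fires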
@@ -1,5 +1,5 @@
-/**
-* A list of all model APIs natively supported by ChainForge.
+/**
+* A list of all model APIs natively supported by ChainForge.
 */
 export type LLM = string | NativeLLM;
 export enum NativeLLM {
@@ -75,6 +75,7 @@ export enum NativeLLM {
 // A special flag for a user-defined HuggingFace model endpoint.
 // The actual model name will be passed as a param to the LLM call function.
 HF_OTHER = "Other (HuggingFace)",
+Ollama = "ollama",
 }

 /**
@@ -88,6 +89,7 @@ export enum LLMProvider {
 Google = "google",
 HuggingFace = "hf",
 Aleph_Alpha = "alephalpha",
+Ollama = "ollama",
 Custom = "__custom",
 }

@@ -112,9 +114,11 @@ export function getProvider(llm: LLM): LLMProvider | undefined {
 return LLMProvider.Anthropic;
 else if (llm_name?.startsWith('Aleph_Alpha'))
 return LLMProvider.Aleph_Alpha;
+else if (llm_name?.startsWith('Ollama'))
+return LLMProvider.Ollama;
 else if (llm.toString().startsWith('__custom/'))
 return LLMProvider.Custom;

 return undefined;
 }

@@ -145,8 +149,8 @@ export let RATE_LIMITS: { [key in LLM]?: [number, number] } = {

 /** Equivalent to a Python enum's .name property */
 export function getEnumName(enumObject: any, enumValue: any): string | undefined {
-for (const key in enumObject)
-if (enumObject[key] === enumValue)
+for (const key in enumObject)
+if (enumObject[key] === enumValue)
 return key;
 return undefined;
 }
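With the new NativeLLM.Ollama member and the extra branch in getProvider, any enum key that starts with 'Ollama' now resolves to LLMProvider.Ollama. A small usage sketch (the relative import path is an assumption):

import { NativeLLM, LLMProvider, getProvider } from './models';

console.log(getProvider(NativeLLM.Ollama) === LLMProvider.Ollama);  // true
console.log(getProvider('__custom/my-provider/sub-model'));         // LLMProvider.Custom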
@@ -47,7 +47,7 @@ export function APP_IS_RUNNING_LOCALLY(): boolean {
 // @ts-ignore
 _APP_IS_RUNNING_LOCALLY = location.hostname === "localhost" || location.hostname === "127.0.0.1" || location.hostname === "0.0.0.0" || location.hostname === "" || window.__CF_HOSTNAME !== undefined;
 } catch (e) {
-// ReferenceError --window or location does not exist.
+// ReferenceError --window or location does not exist.
 // We must not be running client-side in a browser, in this case (e.g., we are running a Node.js server)
 _APP_IS_RUNNING_LOCALLY = false;
 }
@@ -56,7 +56,7 @@ export function APP_IS_RUNNING_LOCALLY(): boolean {
 }

 /**
-* Equivalent to a 'fetch' call, but routes it to the backend Flask server in
+* Equivalent to a 'fetch' call, but routes it to the backend Flask server in
 * case we are running a local server and prefer to not deal with CORS issues making API calls client-side.
 */
 async function route_fetch(url: string, method: string, headers: Dict, body: Dict) {
@@ -89,6 +89,10 @@ function get_environ(key: string): string | undefined {
 return undefined;
 }

+function appendEndSlashIfMissing(path: string) {
+return path + ((path[path.length-1] === '/') ? "" : "/");
+}
+
 let DALAI_MODEL: string | undefined;
 let DALAI_RESPONSE: Dict | undefined;

@@ -127,18 +131,18 @@ export function set_api_keys(api_keys: StringDict): void {
 /**
 * Construct an OpenAI format chat history for sending off to an OpenAI API call.
 * @param prompt The next prompt (user message) to append.
-* @param chat_history The prior turns of the chat, ending with the AI assistants' turn.
+* @param chat_history The prior turns of the chat, ending with the AI assistants' turn.
 * @param system_msg Optional; the system message to use if none is present in chat_history. (Ignored if chat_history already has a sys message.)
 */
 function construct_openai_chat_history(prompt: string, chat_history: ChatHistory | undefined, system_msg: string): ChatHistory {
 const prompt_msg: ChatMessage = { role: 'user', content: prompt };
 if (chat_history !== undefined && chat_history.length > 0) {
 if (chat_history[0].role === 'system') {
-// In this case, the system_msg is ignored because the prior history already contains one.
+// In this case, the system_msg is ignored because the prior history already contains one.
 return chat_history.concat([prompt_msg]);
 } else {
 // In this case, there's no system message that starts the prior history, so inject one:
-// NOTE: We might reach this scenario if we chain output of a non-OpenAI chat model into an OpenAI model.
+// NOTE: We might reach this scenario if we chain output of a non-OpenAI chat model into an OpenAI model.
 return [{"role": "system", "content": system_msg}].concat(chat_history).concat([prompt_msg]);
 }
 } else return [
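call_ollama_provider (added further below) reuses construct_openai_chat_history to build the messages array for Ollama's chat endpoint. A usage sketch of the two branches shown above, with illustrative values:

// Prior history without a system message: one gets injected.
construct_openai_chat_history(
  'What is 2+2?',
  [{ role: 'user', content: 'Hello' }, { role: 'assistant', content: 'Hi there!' }],
  'You are terse.');
// -> [ { role: 'system', content: 'You are terse.' },
//      { role: 'user', content: 'Hello' },
//      { role: 'assistant', content: 'Hi there!' },
//      { role: 'user', content: 'What is 2+2?' } ]

// Prior history that already starts with a system message: system_msg is ignored.
construct_openai_chat_history(
  'What is 2+2?',
  [{ role: 'system', content: 'You are verbose.' },
   { role: 'user', content: 'Hello' },
   { role: 'assistant', content: 'Hi there!' }],
  'You are terse.');
// -> the existing history with { role: 'user', content: 'What is 2+2?' } appended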
@@ -148,8 +152,8 @@ function construct_openai_chat_history(prompt: string, chat_history: ChatHistory
 }

 /**
-* Calls OpenAI models via OpenAI's API.
-@returns raw query and response JSON dicts.
+* Calls OpenAI models via OpenAI's API.
+@returns raw query and response JSON dicts.
 */
 export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
 if (!OPENAI_API_KEY)
@@ -197,7 +201,7 @@ export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, te
 // Create call to text completions model
 openai_call = openai.createCompletion.bind(openai);
 query['prompt'] = prompt;
-} else {
+} else {
 // Create call to chat model
 openai_call = openai.createChatCompletion.bind(openai);

@@ -218,14 +222,14 @@ export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, te
 console.log(error?.message || error);
 throw new Error(error?.message || error);
 }
-}
+}

 return [query, response];
 }

 /**
 * Calls OpenAI models hosted on Microsoft Azure services.
-* Returns raw query and response JSON dicts.
+* Returns raw query and response JSON dicts.
 *
 * NOTE: It is recommended to set an environment variables AZURE_OPENAI_KEY and AZURE_OPENAI_ENDPOINT
 */
@@ -234,7 +238,7 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
 throw new Error("Could not find an Azure OpenAPI Key to use. Double-check that your key is set in Settings or in your local environment.");
 if (!AZURE_OPENAI_ENDPOINT)
 throw new Error("Could not find an Azure OpenAI Endpoint to use. Double-check that your endpoint is set in Settings or in your local environment.");

 const deployment_name: string = params?.deployment_name;
 const model_type: string = params?.model_type;
 if (!deployment_name)
@@ -258,7 +262,7 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
 delete params?.system_msg;
 delete params?.model_type;
 delete params?.deployment_name;

 // Setup the args for the query
 let query: Dict = {
 n: n,
@@ -285,7 +289,7 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
 throw new Error(error?.message || error);
 }
 }

 return [query, response];
 }

@@ -294,7 +298,7 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
 Returns raw query and response JSON dicts.

 Unique parameters:
-- custom_prompt_wrapper: Anthropic models expect prompts in form "\n\nHuman: ${prompt}\n\nAssistant". If you wish to
+- custom_prompt_wrapper: Anthropic models expect prompts in form "\n\nHuman: ${prompt}\n\nAssistant". If you wish to
 explore custom prompt wrappers that deviate, write a python Template that maps from 'prompt' to custom wrapper.
 If set to None, defaults to Anthropic's suggested prompt wrapper.
 - max_tokens_to_sample: A maximum number of tokens to generate before stopping.
@@ -316,7 +320,7 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
 if (params?.custom_prompt_wrapper !== undefined)
 delete params.custom_prompt_wrapper;

-// Required non-standard params
+// Required non-standard params
 const max_tokens_to_sample = params?.max_tokens_to_sample || 1024;
 const stop_sequences = params?.stop_sequences || [ANTHROPIC_HUMAN_PROMPT];

@@ -394,7 +398,7 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
 }

 /**
-* Calls a Google PaLM/Gemini model, based on the model selection from the user.
+* Calls a Google PaLM/Gemini model, based on the model selection from the user.
 * Returns raw query and response JSON dicts.
 */
 export async function call_google_ai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
@@ -407,7 +411,7 @@ export async function call_google_ai(prompt: string, model: LLM, n: number = 1,
 }

 /**
-* Calls a Google PaLM model.
+* Calls a Google PaLM model.
 * Returns raw query and response JSON dicts.
 */
 export async function call_google_palm(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
@@ -415,7 +419,7 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
 throw new Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment.");
 const is_chat_model = model.toString().includes('chat');

-// Required non-standard params
+// Required non-standard params
 const max_output_tokens = params?.max_output_tokens || 800;
 const chat_history = params?.chat_history;
 delete params?.chat_history;
@@ -438,7 +442,7 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
 if (is_chat_model && query.stop_sequences !== undefined)
 delete query.stop_sequences;

-// For some reason Google needs to be special and have its API params be different names --camel or snake-case
+// For some reason Google needs to be special and have its API params be different names --camel or snake-case
 // --depending on if it's the Python or Node JS API. ChainForge needs a consistent name, so we must convert snake to camel:
 const casemap = {
 safety_settings: 'safetySettings',
@@ -475,14 +479,14 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
 query.prompt = palm_chat_context;
 } else {
 query.prompt = { messages: [{content: prompt}] };
-}
+}
 } else {
 // Text completions
 query.prompt = { text: prompt };
 }

 console.log(`Calling Google PaLM model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`);

 // Call the correct model client
 const method = is_chat_model ? 'generateMessage' : 'generateText';
 const url = `https://generativelanguage.googleapis.com/v1beta2/models/${model}:${method}?key=${GOOGLE_PALM_API_KEY}`;
@@ -510,7 +514,7 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
 completion.candidates = new Array(n).fill({'author': '1', 'content':block_error_msg});
 }

-// Weirdly, google ignores candidate_count if temperature is 0.
+// Weirdly, google ignores candidate_count if temperature is 0.
 // We have to check for this and manually append the n-1 responses:
 if (n > 1 && completion.candidates?.length === 1) {
 completion.candidates = new Array(n).fill(completion.candidates[0]);
@@ -524,16 +528,14 @@ export async function call_google_gemini(prompt: string, model: LLM, n: number =
 throw new Error("Could not find an API key for Google Gemini models. Double-check that your API key is set in Settings or in your local environment.");

 // calling the correct model client

 console.log('call_google_gemini: ');
 model = NativeLLM.GEMINI_PRO;

 const genAI = new GoogleGenerativeAI(GOOGLE_PALM_API_KEY);
 const gemini_model = genAI.getGenerativeModel({model: model.toString()});

 // removing chat for now. by default chat is supported

-// Required non-standard params
+// Required non-standard params
 const max_output_tokens = params?.max_output_tokens || 1000;
 const chat_history = params?.chat_history;
 delete params?.chat_history;
@@ -546,7 +548,7 @@ export async function call_google_gemini(prompt: string, model: LLM, n: number =
 ...params,
 };

-// For some reason Google needs to be special and have its API params be different names --camel or snake-case
+// For some reason Google needs to be special and have its API params be different names --camel or snake-case
 // --depending on if it's the Python or Node JS API. ChainForge needs a consistent name, so we must convert snake to camel:
 const casemap = {
 safety_settings: 'safetySettings',
@@ -583,7 +585,7 @@ export async function call_google_gemini(prompt: string, model: LLM, n: number =
 // Chat completions
 if (chat_history !== undefined && chat_history.length > 0) {
 // Carry over any chat history, converting OpenAI formatted chat history to Google PaLM:

 let gemini_messages: GeminiChatMessage[] = [];
 for (const chat_msg of chat_history) {
 if (chat_msg.role === 'system') {
@@ -608,7 +610,7 @@ export async function call_google_gemini(prompt: string, model: LLM, n: number =
 generationConfig: gen_Config,
 },
 );

 const chatResult = await chat.sendMessage(prompt);
 const chatResponse = await chatResult.response;
 const response = {
@@ -692,14 +694,14 @@ export async function call_huggingface(prompt: string, model: LLM, n: number = 1
 const using_custom_model_endpoint: boolean = param_exists(params?.custom_model);

 let headers: StringDict = {'Content-Type': 'application/json'};
-// For HuggingFace, technically, the API keys are optional.
+// For HuggingFace, technically, the API keys are optional.
 if (HUGGINGFACE_API_KEY !== undefined)
 headers.Authorization = `Bearer ${HUGGINGFACE_API_KEY}`;

 // Inference Endpoints for text completion models has the same call,
-// except the endpoint is an entire URL. Detect this:
-const url = (using_custom_model_endpoint && params.custom_model.startsWith('https:')) ?
-params.custom_model :
+// except the endpoint is an entire URL. Detect this:
+const url = (using_custom_model_endpoint && params.custom_model.startsWith('https:')) ?
+params.custom_model :
 `https://api-inference.huggingface.co/models/${using_custom_model_endpoint ? params.custom_model.trim() : model}`;

 let responses: Array<Dict> = [];
@@ -710,8 +712,8 @@ export async function call_huggingface(prompt: string, model: LLM, n: number = 1
 let curr_text = prompt;
 while (curr_cont <= num_continuations) {
 const inputs = (model_type === 'chat')
-? ({ text: curr_text,
-past_user_inputs: hf_chat_hist.past_user_inputs,
+? ({ text: curr_text,
+past_user_inputs: hf_chat_hist.past_user_inputs,
 generated_responses: hf_chat_hist.generated_responses })
 : curr_text;

@@ -722,12 +724,12 @@ export async function call_huggingface(prompt: string, model: LLM, n: number = 1
 body: JSON.stringify({inputs: inputs, parameters: query, options: options}),
 });
 const result = await response.json();

 // HuggingFace sometimes gives us an error, for instance if a model is loading.
 // It returns this as an 'error' key in the response:
 if (result?.error !== undefined)
 throw new Error(result.error);
-else if ((model_type !== 'chat' && (!Array.isArray(result) || result.length !== 1)) ||
+else if ((model_type !== 'chat' && (!Array.isArray(result) || result.length !== 1)) ||
 (model_type === 'chat' && (Array.isArray(result) || !result || result?.generated_text === undefined)))
 throw new Error("Result of HuggingFace API call is in unexpected format:" + JSON.stringify(result));

@@ -750,12 +752,12 @@ export async function call_huggingface(prompt: string, model: LLM, n: number = 1
 export async function call_alephalpha(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
 if (!ALEPH_ALPHA_API_KEY)
 throw Error("Could not find an API key for Aleph Alpha models. Double-check that your API key is set in Settings or in your local environment.");

 const url: string = 'https://api.aleph-alpha.com/complete';
 let headers: StringDict = {'Content-Type': 'application/json', 'Accept': 'application/json'};
 if (ALEPH_ALPHA_API_KEY !== undefined)
 headers.Authorization = `Bearer ${ALEPH_ALPHA_API_KEY}`;

 let data = JSON.stringify({
 "model": model.toString(),
 "prompt": prompt,
@@ -770,7 +772,7 @@ export async function call_alephalpha(prompt: string, model: LLM, n: number = 1,
 temperature: temperature,
 ...params, // 'the rest' of the settings, passed from the front-end settings
 };

 const response = await fetch(url, {
 headers: headers,
 method: "POST",
@@ -782,14 +784,71 @@ export async function call_alephalpha(prompt: string, model: LLM, n: number = 1,
 return [query, responses];
 }

+export async function call_ollama_provider(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
+let url: string = appendEndSlashIfMissing(params?.ollama_url);
+const ollama_model: string = params?.ollamaModel.toString();
+const model_type: string = params?.model_type ?? "text";
+const system_msg: string = params?.system_msg ?? "";
+const chat_history: ChatHistory | undefined = params?.chat_history;
+
+// Cleanup
+for (const name of ["ollamaModel", "ollama_url", "model_type", "system_msg", "chat_history"])
+if (name in params) delete params[name];
+
+// FIXME: Ollama doesn't support batch inference, but llama.cpp does so it will eventually
+// For now, we send n requests and then wait for all of them to finish
+let query: Dict = {
+model: ollama_model,
+stream: false,
+temperature: temperature,
+...params, // 'the rest' of the settings, passed from the front-end settings
+};
+
+// If the model type is explicitly or implicitly set to "chat", pass chat history instead:
+if (model_type === 'chat' || /[-:](chat)/.test(ollama_model)) {
+// Construct chat history and pass to query payload
+query.messages = construct_openai_chat_history(prompt, chat_history, system_msg);
+url += "chat";
+} else {
+// Text-only models
+query.prompt = prompt;
+url += "generate";
+}
+
+console.log(`Calling Ollama API at ${url} for model '${ollama_model}' with prompt '${prompt}' n=${n} times. Please be patient...`);
+
+// Call Ollama API
+let resps : Response[] = [];
+for (let i = 0; i < n; i++) {
+const response = await fetch(url, {
+method: "POST",
+body: JSON.stringify(query),
+});
+resps.push(response);
+}
+
+const parse_response = (body: string) => {
+const json = JSON.parse(body);
+if (json.message) // chat models
+return {generated_text: json.message.content}
+else // text-only models
+return {generated_text: json.response};
+};
+
+const responses = await Promise.all(resps.map((resp) => resp.text())).then((responses) => {
+return responses.map((response) => parse_response(response));
+});
+
+return [query, responses];
+}
+
 async function call_custom_provider(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
 if (!APP_IS_RUNNING_LOCALLY())
 throw new Error("The ChainForge app does not appear to be running locally. You can only call custom model providers if you are running ChainForge on your local machine, from a Flask app.")

 // The model to call is in format:
-// __custom/<provider_name>/<submodel name>
-// It may also exclude the final tag.
+// __custom/<provider_name>/<submodel name>
+// It may also exclude the final tag.
 // We extract the provider name (this is the name used in the Python backend's `ProviderRegistry`) and optionally, the submodel name
 const provider_path = model.substring(9);
 const provider_name = provider_path.substring(0, provider_path.indexOf('/'));
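call_ollama_provider, added above, talks to a locally running Ollama server: it POSTs to <ollama_url>/generate for plain text models (reading json.response) or to <ollama_url>/chat with an OpenAI-style messages array for chat models (reading json.message.content), and it issues n separate requests because Ollama has no batch endpoint. A hedged standalone sketch of those two request shapes, assuming a local server at the default URL and an already-pulled 'mistral' model:

// Sketch only; mirrors the payloads built by call_ollama_provider above.
async function ollamaDemo() {
  const base = "http://localhost:11434/api/";

  // Text completion: POST /api/generate, answer arrives in json.response
  const gen = await fetch(base + "generate", {
    method: "POST",
    body: JSON.stringify({ model: "mistral", stream: false, prompt: "Say hi in five words." }),
  });
  console.log((await gen.json()).response);

  // Chat completion: POST /api/chat with OpenAI-style messages, answer in json.message.content
  const chat = await fetch(base + "chat", {
    method: "POST",
    body: JSON.stringify({
      model: "mistral",
      stream: false,
      messages: [
        { role: "system", content: "You are a helpful assistant." },
        { role: "user", content: "Say hi in five words." },
      ],
    }),
  });
  console.log((await chat.json()).message.content);
}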
@@ -798,9 +857,9 @@ async function call_custom_provider(prompt: string, model: LLM, n: number = 1, t
 let responses = [];
 const query = { prompt, model, temperature, ...params };

-// Call the custom provider n times
+// Call the custom provider n times
 while (responses.length < n) {
-let {response, error} = await call_flask_backend('callCustomProvider',
+let {response, error} = await call_flask_backend('callCustomProvider',
 { 'name': provider_name,
 'params': {
 prompt, model: submodel_name, temperature, ...params
@@ -823,7 +882,7 @@ export async function call_llm(llm: LLM, prompt: string, n: number, temperature:
 // Get the correct API call for the given LLM:
 let call_api: LLMAPICall | undefined;
 let llm_provider: LLMProvider = getProvider(llm);

 if (llm_provider === undefined)
 throw new Error(`Language model ${llm} is not supported.`);

@@ -841,20 +900,22 @@ export async function call_llm(llm: LLM, prompt: string, n: number, temperature:
 call_api = call_huggingface;
 else if (llm_provider === LLMProvider.Aleph_Alpha)
 call_api = call_alephalpha;
+else if (llm_provider === LLMProvider.Ollama)
+call_api = call_ollama_provider;
 else if (llm_provider === LLMProvider.Custom)
 call_api = call_custom_provider;

 return call_api(prompt, llm, n, temperature, params);
 }

 /**
-* Extracts the relevant portion of a OpenAI chat response.
+* Extracts the relevant portion of a OpenAI chat response.
 * Note that chat choice objects can now include 'function_call' and a blank 'content' response.
-* This method detects a 'function_call's presence, prepends [[FUNCTION]] and converts the function call into JS format.
+* This method detects a 'function_call's presence, prepends [[FUNCTION]] and converts the function call into JS format.
 */
 function _extract_openai_chat_choice_content(choice: Dict): string {
-if (choice['finish_reason'] === 'function_call' ||
+if (choice['finish_reason'] === 'function_call' ||
 ('function_call' in choice['message'] && choice['message']['function_call'].length > 0)) {
 const func = choice['message']['function_call'];
 return '[[FUNCTION]] ' + func['name'] + func['arguments'].toString();
@@ -865,7 +926,7 @@ function _extract_openai_chat_choice_content(choice: Dict): string {

 /**
 * Extracts the text part of a response JSON from ChatGPT. If there is more
-* than 1 response (e.g., asking the LLM to generate multiple responses),
+* than 1 response (e.g., asking the LLM to generate multiple responses),
 * this produces a list of all returned responses.
 */
 function _extract_chatgpt_responses(response: Dict): Array<string> {
@@ -874,7 +935,7 @@ function _extract_chatgpt_responses(response: Dict): Array<string> {

 /**
 * Extracts the text part of a response JSON from OpenAI completions models like Davinci. If there are more
-* than 1 response (e.g., asking the LLM to generate multiple responses),
+* than 1 response (e.g., asking the LLM to generate multiple responses),
 * this produces a list of all returned responses.
 */
 function _extract_openai_completion_responses(response: Dict): Array<string> {
@@ -939,7 +1000,14 @@ function _extract_huggingface_responses(response: Array<Dict>): Array<string>{
 * Extracts the text part of a Aleph Alpha text completion.
 */
 function _extract_alephalpha_responses(response: Dict): Array<string> {
-return response.map((r: string) => r.trim());
+return response.map((r: string) => r.trim());
 }

+/**
+* Extracts the text part of a Ollama text completion.
+*/
+function _extract_ollama_responses(response: Array<Dict>): Array<string> {
+return response.map((r: Object) => r["generated_text"].trim());
+}
+
 /**
@@ -965,20 +1033,22 @@ export function extract_responses(response: Array<string | Dict> | Dict, llm: LL
 return _extract_anthropic_responses(response as Dict[]);
 case LLMProvider.HuggingFace:
 return _extract_huggingface_responses(response as Dict[]);
-case LLMProvider.Aleph_Alpha:
-return _extract_alephalpha_responses(response);
+case LLMProvider.Aleph_Alpha:
+return _extract_alephalpha_responses(response);
+case LLMProvider.Ollama:
+return _extract_ollama_responses(response as Dict[]);
 default:
 if (Array.isArray(response) && response.length > 0 && typeof response[0] === 'string')
-return response as string[];
-else
+return response as string[];
+else
 throw new Error(`No method defined to extract responses for LLM ${llm}.`)
 }
 }

 /**
 * Marge the 'responses' and 'raw_response' properties of two LLMResponseObjects,
-* keeping all the other params from the second argument (llm, query, etc).
-*
+* keeping all the other params from the second argument (llm, query, etc).
+*
 * If one object is undefined or null, returns the object that is defined, unaltered.
 */
 export function merge_response_objs(resp_obj_A: LLMResponseObject | undefined, resp_obj_B: LLMResponseObject | undefined): LLMResponseObject | undefined {
@@ -1102,7 +1172,7 @@ export const toStandardResponseFormat = (r) => {
 // Check if the current browser window/tab is 'active' or not
 export const browserTabIsActive = () => {
 try {
-const visible = document.visibilityState === 'visible';
+const visible = document.visibilityState === 'visible';
 return visible;
 } catch(e) {
 console.error(e);
@@ -1136,7 +1206,7 @@ export const extractLLMLookup = (input_data) => {

 export const removeLLMTagFromMetadata = (metavars) => {
 if (!('__LLM_key' in metavars))
-return metavars;
+return metavars;
 let mcopy = JSON.parse(JSON.stringify(metavars));
 delete mcopy['__LLM_key'];
 return mcopy;
51 changes: chainforge/react-server/src/store.js (vendored)
@@ -17,7 +17,7 @@ const initialLLMColors = {};
 /** The color palette used for displaying info about different LLMs. */
 const llmColorPalette = ['#44d044', '#f1b933', '#e46161', '#8888f9', '#33bef0', '#bb55f9', '#f7ee45', '#f955cd', '#26e080', '#2654e0', '#7d8191', '#bea5d1'];

-/** The color palette used for displaying variations of prompts and prompt variables (non-LLM differences).
+/** The color palette used for displaying variations of prompts and prompt variables (non-LLM differences).
 * Distinct from the LLM color palette in order to avoid confusion around what the data means.
 * Palette adapted from https://lospec.com/palette-list/sness by Space Sandwich */
 const varColorPalette = ['#0bdb52', '#e71861', '#7161de', '#f6d714', '#80bedb', '#ffa995', '#a9b399', '#dc6f0f', '#8d022e', '#138e7d', '#c6924f', '#885818', '#616b6d'];
@@ -40,7 +40,10 @@ export let initLLMProviders = [
 { name: "Azure OpenAI", emoji: "🔷", model: "azure-openai", base_model: "azure-openai", temp: 1.0 },
 ];
 if (APP_IS_RUNNING_LOCALLY()) {
-initLLMProviders.push({ name: "Dalai (Alpaca.7B)", emoji: "🦙", model: "alpaca.7B", base_model: "dalai", temp: 0.5 });
+initLLMProviders.push({ name: "Ollama", emoji: "🦙", model: "ollama", base_model: "ollama", temp: 1.0 });
+// -- Deprecated provider --
+// initLLMProviders.push({ name: "Dalai (Alpaca.7B)", emoji: "🦙", model: "alpaca.7B", base_model: "dalai", temp: 0.5 });
+// -------------------------
 }

 // A global store of variables, used for maintaining state
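Because the push above only runs when APP_IS_RUNNING_LOCALLY() is true, the Ollama entry never appears in a hosted build's provider list; it only shows up when ChainForge runs on the user's machine next to an Ollama server. A tiny, self-contained sketch of that gating (hypothetical names, not the store's actual code):

// Hypothetical illustration of the local-only provider gating used above.
type ProviderEntry = { name: string; emoji: string; model: string; base_model: string; temp: number };

function buildProviderList(runningLocally: boolean): ProviderEntry[] {
  const providers: ProviderEntry[] = [
    { name: "Azure OpenAI", emoji: "🔷", model: "azure-openai", base_model: "azure-openai", temp: 1.0 },
  ];
  if (runningLocally)
    providers.push({ name: "Ollama", emoji: "🦙", model: "ollama", base_model: "ollama", temp: 1.0 });
  return providers;
}

// buildProviderList(false) -> Ollama is absent (hosted build)
// buildProviderList(true)  -> Ollama is listed (local build)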
@@ -79,7 +82,7 @@ const useStore = create((set, get) => ({
 llmColors: initialLLMColors,

 // Gets the color for the model named 'llm_name' in llmColors; returns undefined if not found.
-getColorForLLM: (llm_name) => {
+getColorForLLM: (llm_name) => {
 const colors = get().llmColors;
 if (llm_name in colors)
 return colors[llm_name];
@@ -87,7 +90,7 @@ const useStore = create((set, get) => ({
 },

 // Gets the color for the specified LLM. If not found, generates a new (ideally unique) color
-// and saves it to the llmColors dict.
+// and saves it to the llmColors dict.
 getColorForLLMAndSetIfNotFound: (llm_name) => {
 let color = get().getColorForLLM(llm_name);
 if (color) return color;
@@ -97,7 +100,7 @@ const useStore = create((set, get) => ({
 },

 // Generates a unique color not yet used in llmColors (unless # colors is large)
-genUniqueLLMColor: () => {
+genUniqueLLMColor: () => {
 const used_colors = new Set(Object.values(get().llmColors));
 const get_unused_color = (all_colors) => {
 for (let i = 0; i < all_colors.length; i++) {
@@ -107,7 +110,7 @@ const useStore = create((set, get) => ({
 }
 return undefined;
 };

 let unique_color = get_unused_color(llmColorPalette);
 if (unique_color) return unique_color;

@@ -116,7 +119,7 @@ const useStore = create((set, get) => ({
 unique_color = get_unused_color(varColorPalette);
 if (unique_color) return unique_color;

-// If we've reached here, we've run out of all predefined colors.
+// If we've reached here, we've run out of all predefined colors.
 // Choose one to repeat, at random:
 const all_colors = llmColorPalette.concat(varColorPalette);
 return all_colors[Math.floor(Math.random() * all_colors.length)];
@@ -133,7 +136,7 @@ const useStore = create((set, get) => ({
 llmColors: []
 });
 },

 inputEdgesForNode: (sourceNodeId) => {
 return get().edges.filter(e => e.target == sourceNodeId);
 },
@@ -164,7 +167,7 @@ const useStore = create((set, get) => ({
 const src_col = columns.find(c => c.header === sourceHandleKey);
 if (src_col !== undefined) {

-// Construct a lookup table from column key to header name,
+// Construct a lookup table from column key to header name,
 // as the 'metavars' dict should be keyed by column *header*, not internal key:
 let col_header_lookup = {};
 columns.forEach(c => {
@@ -203,10 +206,10 @@ const useStore = create((set, get) => ({
 if (Array.isArray(src_node.data["fields"]))
 return src_node.data["fields"];
 else {
-// We have to filter over a special 'fields_visibility' prop, which
+// We have to filter over a special 'fields_visibility' prop, which
 // can select what fields get output:
 if ("fields_visibility" in src_node.data)
-return Object.values(filterDict(src_node.data["fields"],
+return Object.values(filterDict(src_node.data["fields"],
 fid => src_node.data["fields_visibility"][fid] !== false));
 else // return all field values
 return Object.values(src_node.data["fields"]);
@@ -224,7 +227,7 @@ const useStore = create((set, get) => ({
 // Get the types of nodes attached immediately as input to the given node
 getImmediateInputNodeTypes: (_targetHandles, node_id) => {
 const getNode = get().getNode;
-const edges = get().edges;
+const edges = get().edges;
 let inputNodeTypes = [];
 edges.forEach(e => {
 if (e.target == node_id && _targetHandles.includes(e.targetHandle)) {
@@ -242,9 +245,9 @@ const useStore = create((set, get) => ({
 // Functions/data from the store:
 const getNode = get().getNode;
 const output = get().output;
-const edges = get().edges;
+const edges = get().edges;

-// Helper function to store collected data in dict:
+// Helper function to store collected data in dict:
 const store_data = (_texts, _varname, _data) => {
 if (_varname in _data)
 _data[_varname] = _data[_varname].concat(_texts);
@@ -257,12 +260,12 @@ const useStore = create((set, get) => ({
 const get_outputs = (varnames, nodeId, var_history) => {
 varnames.forEach(varname => {
 // Check for duplicate variable names
-if (var_history.has(String(varname).toLowerCase()))
+if (var_history.has(String(varname).toLowerCase()))
 throw new DuplicateVariableNameError(varname);

 // Add to unique name tally
 var_history.add(String(varname).toLowerCase());

 // Find the relevant edge(s):
 edges.forEach(e => {
 if (e.target == nodeId && e.targetHandle == varname) {
@@ -278,7 +281,7 @@ const useStore = create((set, get) => ({
 // Save the list of strings from the pulled output under the var 'varname'
 store_data(out, varname, pulled_data);
 }

 // Get any vars that the output depends on, and recursively collect those outputs as well:
 const n_vars = getNode(e.source).data.vars;
 if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
@@ -295,15 +298,15 @@ const useStore = create((set, get) => ({
 /**
 * Sets select 'data' properties for node 'id'. This updates global state, and forces re-renders. Use sparingly.
 * @param {*} id The id of the node to set 'data' properties for.
-* @param {*} data_props The properties to set on the node's 'data'.
+* @param {*} data_props The properties to set on the node's 'data'.
 */
 setDataPropsForNode: (id, data_props) => {
 set({
-nodes: (nds =>
+nodes: (nds =>
 nds.map(n => {
 if (n.id === id) {
-for (const key of Object.keys(data_props))
-n.data[key] = data_props[key];
+for (const key of Object.keys(data_props))
+n.data[key] = data_props[key];
 n.data = deepcopy(n.data);
 }
 return n;
@@ -393,10 +396,10 @@ const useStore = create((set, get) => ({
 });
 },
 onConnect: (connection) => {

 // Get the target node information
 const target = get().getNode(connection.target);

 if (target.type === 'vis' || target.type === 'inspect' || target.type === 'simpleval') {
 get().setDataPropsForNode(target.id, { input: connection.source });
 }