mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 08:16:37 +00:00
Gemini model support and raise error when detecting duplicate var names (v0.2.8.1) (#195)
* Raise error after detecting duplicate variable names (#190) * Raise error for duplicate variable name * Created base error class * Simplified error classes. Made just one `DuplicateVariableNameError` that takes in variable name to have a hard-coded error message --------- Co-authored-by: Kayla Zethelyn <kaylazethelyn@college.harvard.edu> * Adding support for Google's Gemini-Pro model. (#194) * Refined duplicate var error check code * Tidy up duplicate var name alerts and error handling, and err message * Rebuild react and update package version --------- Co-authored-by: Kayla Z <77540029+kamazet@users.noreply.github.com> Co-authored-by: Kayla Zethelyn <kaylazethelyn@college.harvard.edu> Co-authored-by: Priyan Vaithilingam <priyanmuthu@gmail.com>
This commit is contained in:
parent
ce583a216c
commit
d6e850e724
@ -1,15 +1,15 @@
|
||||
{
|
||||
"files": {
|
||||
"main.css": "/static/css/main.847ce933.css",
|
||||
"main.js": "/static/js/main.72d00c7c.js",
|
||||
"main.js": "/static/js/main.84c4d13c.js",
|
||||
"static/js/787.4c72bb55.chunk.js": "/static/js/787.4c72bb55.chunk.js",
|
||||
"index.html": "/index.html",
|
||||
"main.847ce933.css.map": "/static/css/main.847ce933.css.map",
|
||||
"main.72d00c7c.js.map": "/static/js/main.72d00c7c.js.map",
|
||||
"main.84c4d13c.js.map": "/static/js/main.84c4d13c.js.map",
|
||||
"787.4c72bb55.chunk.js.map": "/static/js/787.4c72bb55.chunk.js.map"
|
||||
},
|
||||
"entrypoints": [
|
||||
"static/css/main.847ce933.css",
|
||||
"static/js/main.72d00c7c.js"
|
||||
"static/js/main.84c4d13c.js"
|
||||
]
|
||||
}
|
@ -1 +1 @@
|
||||
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.72d00c7c.js"></script><link href="/static/css/main.847ce933.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
|
||||
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.84c4d13c.js"></script><link href="/static/css/main.847ce933.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
|
File diff suppressed because one or more lines are too long
@ -115,6 +115,23 @@ License: MIT
|
||||
|
||||
/*! sheetjs (C) 2013-present SheetJS -- http://sheetjs.com */
|
||||
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2023 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @license
|
||||
* Lodash <https://lodash.com/>
|
File diff suppressed because one or more lines are too long
11
chainforge/react-server/package-lock.json
generated
11
chainforge/react-server/package-lock.json
generated
@ -15,6 +15,7 @@
|
||||
"@emoji-mart/data": "^1.1.2",
|
||||
"@emoji-mart/react": "^1.1.1",
|
||||
"@google-ai/generativelanguage": "^0.2.0",
|
||||
"@google/generative-ai": "^0.1.3",
|
||||
"@mantine/core": "^6.0.9",
|
||||
"@mantine/dates": "^6.0.13",
|
||||
"@mantine/dropzone": "^6.0.19",
|
||||
@ -66,7 +67,7 @@
|
||||
"net": "^1.0.2",
|
||||
"net-browserify": "^0.2.4",
|
||||
"node-fetch": "^2.6.11",
|
||||
"openai": "^3.3.0",
|
||||
"openai": "~3.3.0",
|
||||
"os-browserify": "^0.3.0",
|
||||
"papaparse": "^5.4.1",
|
||||
"path-browserify": "^1.0.1",
|
||||
@ -3154,6 +3155,14 @@
|
||||
"node": ">=12.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@google/generative-ai": {
|
||||
"version": "0.1.3",
|
||||
"resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.1.3.tgz",
|
||||
"integrity": "sha512-Cm4uJX1sKarpm1mje/MiOIinM7zdUUrQp/5/qGPAgznbdd/B9zup5ehT6c1qGqycFcSopTA1J1HpqHS5kJR8hQ==",
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@grpc/grpc-js": {
|
||||
"version": "1.8.21",
|
||||
"resolved": "https://registry.npmjs.org/@grpc/grpc-js/-/grpc-js-1.8.21.tgz",
|
||||
|
@ -13,6 +13,7 @@
|
||||
"@emoji-mart/data": "^1.1.2",
|
||||
"@emoji-mart/react": "^1.1.1",
|
||||
"@google-ai/generativelanguage": "^0.2.0",
|
||||
"@google/generative-ai": "^0.1.3",
|
||||
"@mantine/core": "^6.0.9",
|
||||
"@mantine/dates": "^6.0.13",
|
||||
"@mantine/dropzone": "^6.0.19",
|
||||
|
@ -245,8 +245,8 @@ return (
|
||||
|
||||
<br />
|
||||
<TextInput
|
||||
label="Google PaLM API Key"
|
||||
placeholder="Paste your Google PaLM API key here"
|
||||
label="Google AI API Key (PaLM/GEMINI)"
|
||||
placeholder="Paste your Google PaLM/GEMINI API key here"
|
||||
{...form.getInputProps('Google')}
|
||||
/>
|
||||
<br />
|
||||
|
@ -346,7 +346,7 @@ const ClaudeSettings = {
|
||||
};
|
||||
|
||||
const PaLM2Settings = {
|
||||
fullName: "PaLM (Google)",
|
||||
fullName: "Google AI Models (Gemini & PaLM)",
|
||||
schema: {
|
||||
"type": "object",
|
||||
"required": [
|
||||
@ -357,17 +357,18 @@ const PaLM2Settings = {
|
||||
"type": "string",
|
||||
"title": "Nickname",
|
||||
"description": "Unique identifier to appear in ChainForge. Keep it short.",
|
||||
"default": "chat-bison"
|
||||
"default": "Gemini"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"title": "Model",
|
||||
"description": "Select a PaLM model to query. For more details on the differences, see the Google PaLM API documentation.",
|
||||
"enum": ["text-bison-001", "chat-bison-001"],
|
||||
"default": "chat-bison-001",
|
||||
"enum": ["gemini-pro", "text-bison-001", "chat-bison-001"],
|
||||
"default": "gemini-pro",
|
||||
"shortname_map": {
|
||||
"text-bison-001": "PaLM2-text",
|
||||
"chat-bison-001": "PaLM2-chat",
|
||||
"gemini-pro": "Gemini",
|
||||
}
|
||||
},
|
||||
"temperature": {
|
||||
@ -417,7 +418,7 @@ const PaLM2Settings = {
|
||||
"ui:autofocus": true
|
||||
},
|
||||
"model": {
|
||||
"ui:help": "Defaults to chat-bison."
|
||||
"ui:help": "Defaults to gemini-pro."
|
||||
},
|
||||
"temperature": {
|
||||
"ui:help": "Defaults to 0.5.",
|
||||
@ -446,6 +447,7 @@ const PaLM2Settings = {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
const DalaiModelSettings = {
|
||||
fullName: "Dalai-hosted local model (Alpaca, Llama)",
|
||||
schema: {
|
||||
|
65
chainforge/react-server/src/PromptNode.js
vendored
65
chainforge/react-server/src/PromptNode.js
vendored
@ -14,8 +14,8 @@ import { escapeBraces } from './backend/template';
|
||||
import ChatHistoryView from './ChatHistoryView';
|
||||
import InspectFooter from './InspectFooter';
|
||||
import { countNumLLMs, setsAreEqual, getLLMsInPulledInputData } from './backend/utils';
|
||||
import LLMResponseInspector from './LLMResponseInspector';
|
||||
import LLMResponseInspectorDrawer from './LLMResponseInspectorDrawer';
|
||||
import { DuplicateVariableNameError } from './backend/errors';
|
||||
|
||||
const getUniqueLLMMetavarKey = (responses) => {
|
||||
const metakeys = new Set(responses.map(resp_obj => Object.keys(resp_obj.metavars)).flat());
|
||||
@ -134,7 +134,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
|
||||
const showResponseInspector = useCallback(() => {
|
||||
if (inspectModal && inspectModal.current && jsonResponses) {
|
||||
inspectModal.current.trigger();
|
||||
inspectModal.current?.trigger();
|
||||
setUninspectedResponses(false);
|
||||
}
|
||||
}, [inspectModal, jsonResponses]);
|
||||
@ -175,7 +175,13 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
const handleOnConnect = useCallback(() => {
|
||||
if (node_type === 'chat') return; // always show when chat node
|
||||
// Re-pull data and update show cont toggle:
|
||||
updateShowContToggle(pullInputData(templateVars, id));
|
||||
try {
|
||||
const pulled_data = pullInputData(templateVars, id);
|
||||
updateShowContToggle(pulled_data);
|
||||
} catch (err) {
|
||||
// alertModal.current?.trigger(err.message);
|
||||
console.error(err);
|
||||
}
|
||||
}, [templateVars, id, pullInputData, updateShowContToggle]);
|
||||
|
||||
const refreshTemplateHooks = useCallback((text) => {
|
||||
@ -183,7 +189,13 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
const found_template_vars = new Set(extractBracketedSubstrings(text)); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this}
|
||||
|
||||
if (!setsAreEqual(found_template_vars, new Set(templateVars))) {
|
||||
if (node_type !== 'chat') updateShowContToggle(pullInputData(found_template_vars, id));
|
||||
if (node_type !== 'chat') {
|
||||
try {
|
||||
updateShowContToggle(pullInputData(found_template_vars, id));
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
}
|
||||
}
|
||||
setTemplateVars(Array.from(found_template_vars));
|
||||
}
|
||||
}, [setTemplateVars, templateVars, pullInputData, id]);
|
||||
@ -295,17 +307,23 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
const [promptPreviews, setPromptPreviews] = useState([]);
|
||||
const handlePreviewHover = () => {
|
||||
// Pull input data and prompt
|
||||
const pulled_vars = pullInputData(templateVars, id);
|
||||
updateShowContToggle(pulled_vars);
|
||||
try {
|
||||
const pulled_vars = pullInputData(templateVars, id);
|
||||
updateShowContToggle(pulled_vars);
|
||||
|
||||
fetch_from_backend('generatePrompts', {
|
||||
prompt: promptText,
|
||||
vars: pulled_vars,
|
||||
}).then(prompts => {
|
||||
setPromptPreviews(prompts.map(p => (new PromptInfo(p.toString()))));
|
||||
});
|
||||
fetch_from_backend('generatePrompts', {
|
||||
prompt: promptText,
|
||||
vars: pulled_vars,
|
||||
}).then(prompts => {
|
||||
setPromptPreviews(prompts.map(p => (new PromptInfo(p.toString()))));
|
||||
});
|
||||
|
||||
pullInputChats();
|
||||
pullInputChats();
|
||||
} catch (err) {
|
||||
// soft fail
|
||||
console.error(err);
|
||||
setPromptPreviews([]);
|
||||
}
|
||||
};
|
||||
|
||||
// On hover over the 'Run' button, request how many responses are required and update the tooltip. Soft fails.
|
||||
@ -330,7 +348,15 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
}
|
||||
|
||||
// Pull the input data
|
||||
const pulled_vars = pullInputData(templateVars, id);
|
||||
let pulled_vars = {};
|
||||
try {
|
||||
pulled_vars = pullInputData(templateVars, id);
|
||||
} catch (err) {
|
||||
setRunTooltip('Error: Duplicate variables detected.');
|
||||
console.error(err);
|
||||
return; // early exit
|
||||
}
|
||||
|
||||
updateShowContToggle(pulled_vars);
|
||||
|
||||
// Whether to continue with only the prior LLMs, for each value in vars dict
|
||||
@ -455,7 +481,16 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
}
|
||||
|
||||
// Pull the data to fill in template input variables, if any
|
||||
const pulled_data = pullInputData(templateVars, id);
|
||||
let pulled_data = {};
|
||||
try {
|
||||
// Try to pull inputs
|
||||
pulled_data = pullInputData(templateVars, id);
|
||||
} catch (err) {
|
||||
alertModal.current?.trigger(err.message);
|
||||
console.error(err);
|
||||
return; // early exit
|
||||
}
|
||||
|
||||
const prompt_template = promptText;
|
||||
|
||||
// Whether to continue with only the prior LLMs, for each value in vars dict
|
||||
|
9
chainforge/react-server/src/backend/errors.ts
Normal file
9
chainforge/react-server/src/backend/errors.ts
Normal file
@ -0,0 +1,9 @@
|
||||
export class DuplicateVariableNameError extends Error {
|
||||
variable: string;
|
||||
|
||||
constructor(variable: string) {
|
||||
super();
|
||||
this.name = "DuplicateVariableNameError";
|
||||
this.message = "You have multiple template variables with the same name, {" + variable + "}. Duplicate names in the same chain is not allowed. To fix, ensure that all template variable names are unique across a chain.";
|
||||
}
|
||||
}
|
@ -50,6 +50,7 @@ export enum NativeLLM {
|
||||
// Google models
|
||||
PaLM2_Text_Bison = "text-bison-001", // it's really models/text-bison-001, but that's confusing
|
||||
PaLM2_Chat_Bison = "chat-bison-001",
|
||||
GEMINI_PRO = "gemini-pro",
|
||||
|
||||
// Aleph Alpha
|
||||
Aleph_Alpha_Luminous_Extended = "luminous-extended",
|
||||
@ -101,7 +102,7 @@ export function getProvider(llm: LLM): LLMProvider | undefined {
|
||||
return LLMProvider.OpenAI;
|
||||
else if (llm_name?.startsWith('Azure'))
|
||||
return LLMProvider.Azure_OpenAI;
|
||||
else if (llm_name?.startsWith('PaLM2'))
|
||||
else if (llm_name?.startsWith('PaLM2') || llm_name?.startsWith('GEMINI'))
|
||||
return LLMProvider.Google;
|
||||
else if (llm_name?.startsWith('Dalai'))
|
||||
return LLMProvider.Dalai;
|
||||
|
@ -46,6 +46,16 @@ export interface PaLMChatContext {
|
||||
examples?: Dict[],
|
||||
}
|
||||
|
||||
export interface GeminiChatMessage {
|
||||
role: string,
|
||||
parts: string,
|
||||
}
|
||||
|
||||
export interface GeminiChatContext {
|
||||
history: GeminiChatMessage[],
|
||||
}
|
||||
|
||||
|
||||
/** HuggingFace conversation models format */
|
||||
export interface HuggingFaceChatHistory {
|
||||
past_user_inputs: string[],
|
||||
|
@ -3,14 +3,15 @@
|
||||
// from string import Template
|
||||
|
||||
// from chainforge.promptengine.models import LLM
|
||||
import { LLM, LLMProvider, getProvider } from './models';
|
||||
import { Dict, StringDict, LLMAPICall, LLMResponseObject, ChatHistory, ChatMessage, PaLMChatMessage, PaLMChatContext, HuggingFaceChatHistory } from './typing';
|
||||
import { LLM, LLMProvider, NativeLLM, getProvider } from './models';
|
||||
import { Dict, StringDict, LLMAPICall, LLMResponseObject, ChatHistory, ChatMessage, PaLMChatMessage, PaLMChatContext, HuggingFaceChatHistory, GeminiChatContext, GeminiChatMessage } from './typing';
|
||||
import { env as process_env } from 'process';
|
||||
import { StringTemplate } from './template';
|
||||
|
||||
/* LLM API SDKs */
|
||||
import { Configuration as OpenAIConfig, OpenAIApi } from "openai";
|
||||
import { OpenAIClient as AzureOpenAIClient, AzureKeyCredential } from "@azure/openai";
|
||||
import { GoogleGenerativeAI } from "@google/generative-ai";
|
||||
|
||||
const ANTHROPIC_HUMAN_PROMPT = "\n\nHuman:";
|
||||
const ANTHROPIC_AI_PROMPT = "\n\nAssistant:";
|
||||
@ -392,14 +393,26 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
|
||||
return [query, responses];
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls a Google PaLM/Gemini model, based on the model selection from the user.
|
||||
* Returns raw query and response JSON dicts.
|
||||
*/
|
||||
export async function call_google_ai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
|
||||
switch(model) {
|
||||
case NativeLLM.GEMINI_PRO:
|
||||
return call_google_gemini(prompt, model, n, temperature, params);
|
||||
default:
|
||||
return call_google_palm(prompt, model, n, temperature, params);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls a Google PaLM model.
|
||||
Returns raw query and response JSON dicts.
|
||||
* Returns raw query and response JSON dicts.
|
||||
*/
|
||||
export async function call_google_palm(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
|
||||
if (!GOOGLE_PALM_API_KEY)
|
||||
throw new Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment.");
|
||||
|
||||
const is_chat_model = model.toString().includes('chat');
|
||||
|
||||
// Required non-standard params
|
||||
@ -506,6 +519,109 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
|
||||
return [query, completion];
|
||||
}
|
||||
|
||||
export async function call_google_gemini(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
|
||||
if (!GOOGLE_PALM_API_KEY)
|
||||
throw new Error("Could not find an API key for Google Gemini models. Double-check that your API key is set in Settings or in your local environment.");
|
||||
|
||||
// calling the correct model client
|
||||
|
||||
console.log('call_google_gemini: ');
|
||||
model = NativeLLM.GEMINI_PRO;
|
||||
|
||||
const genAI = new GoogleGenerativeAI(GOOGLE_PALM_API_KEY);
|
||||
const gemini_model = genAI.getGenerativeModel({model: model.toString()});
|
||||
|
||||
// removing chat for now. by default chat is supported
|
||||
|
||||
// Required non-standard params
|
||||
const max_output_tokens = params?.max_output_tokens || 1000;
|
||||
const chat_history = params?.chat_history;
|
||||
delete params?.chat_history;
|
||||
|
||||
let query: Dict = {
|
||||
model: `models/${model}`,
|
||||
candidate_count: n,
|
||||
temperature: temperature,
|
||||
max_output_tokens: max_output_tokens,
|
||||
...params,
|
||||
};
|
||||
|
||||
// For some reason Google needs to be special and have its API params be different names --camel or snake-case
|
||||
// --depending on if it's the Python or Node JS API. ChainForge needs a consistent name, so we must convert snake to camel:
|
||||
const casemap = {
|
||||
safety_settings: 'safetySettings',
|
||||
stop_sequences: 'stopSequences',
|
||||
candidate_count: 'candidateCount',
|
||||
max_output_tokens: 'maxOutputTokens',
|
||||
top_p: 'topP',
|
||||
top_k: 'topK',
|
||||
};
|
||||
|
||||
let gen_Config = {};
|
||||
|
||||
Object.entries(casemap).forEach(([key, val]) => {
|
||||
if (key in query) {
|
||||
gen_Config[val] = query[key];
|
||||
query[val] = query[key];
|
||||
delete query[key];
|
||||
}
|
||||
});
|
||||
|
||||
// Gemini only supports candidate_count of 1
|
||||
gen_Config['candidateCount'] = 1;
|
||||
|
||||
// By default for topK is none, and topP is 1.0
|
||||
if ('topK' in gen_Config && gen_Config['topK'] === -1) {
|
||||
delete gen_Config['topK'];
|
||||
}
|
||||
if ('topP' in gen_Config && gen_Config['topP'] === -1) {
|
||||
gen_Config['topP'] = 1.0;
|
||||
}
|
||||
|
||||
let gemini_chat_context: GeminiChatContext = { history: [] };
|
||||
|
||||
// Chat completions
|
||||
if (chat_history !== undefined && chat_history.length > 0) {
|
||||
// Carry over any chat history, converting OpenAI formatted chat history to Google PaLM:
|
||||
|
||||
let gemini_messages: GeminiChatMessage[] = [];
|
||||
for (const chat_msg of chat_history) {
|
||||
if (chat_msg.role === 'system') {
|
||||
// Carry the system message over as PaLM's chat 'context':
|
||||
gemini_messages.push({ role: 'model', parts: chat_msg.content });
|
||||
} else if (chat_msg.role === 'user') {
|
||||
gemini_messages.push({ role: 'user', parts: chat_msg.content });
|
||||
} else
|
||||
gemini_messages.push({ role: 'model', parts: chat_msg.content });
|
||||
}
|
||||
gemini_chat_context.history = gemini_messages;
|
||||
}
|
||||
|
||||
console.log(`Calling Google Gemini model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`);
|
||||
|
||||
let responses: Array<Dict> = [];
|
||||
|
||||
while(responses.length < n) {
|
||||
const chat = gemini_model.startChat(
|
||||
{
|
||||
history: gemini_chat_context.history,
|
||||
generationConfig: gen_Config,
|
||||
},
|
||||
);
|
||||
|
||||
const chatResult = await chat.sendMessage(prompt);
|
||||
const chatResponse = await chatResult.response;
|
||||
const response = {
|
||||
text: chatResponse.text(),
|
||||
candidates: chatResponse.candidates,
|
||||
promptFeedback: chatResponse.promptFeedback,
|
||||
}
|
||||
responses.push(response);
|
||||
}
|
||||
|
||||
return [query, responses];
|
||||
}
|
||||
|
||||
export async function call_dalai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
|
||||
if (APP_IS_RUNNING_LOCALLY()) {
|
||||
// Try to call Dalai server, through Flask:
|
||||
@ -716,7 +832,7 @@ export async function call_llm(llm: LLM, prompt: string, n: number, temperature:
|
||||
else if (llm_provider === LLMProvider.Azure_OpenAI)
|
||||
call_api = call_azure_openai;
|
||||
else if (llm_provider === LLMProvider.Google)
|
||||
call_api = call_google_palm;
|
||||
call_api = call_google_ai;
|
||||
else if (llm_provider === LLMProvider.Dalai)
|
||||
call_api = call_dalai;
|
||||
else if (llm_provider === LLMProvider.Anthropic)
|
||||
@ -778,6 +894,16 @@ function _extract_openai_responses(response: Dict): Array<string> {
|
||||
return _extract_openai_completion_responses(response);
|
||||
}
|
||||
|
||||
|
||||
function _extract_google_ai_responses(response: Dict, llm: LLM | string): Array<string> {
|
||||
switch(llm) {
|
||||
case NativeLLM.GEMINI_PRO:
|
||||
return _extract_gemini_responses(response as Array<Dict>);
|
||||
default:
|
||||
return _extract_palm_responses(response);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the text part of a 'Completion' object from Google PaLM2 `generate_text` or `chat`.
|
||||
*
|
||||
@ -788,6 +914,13 @@ function _extract_palm_responses(completion: Dict): Array<string> {
|
||||
return completion['candidates'].map((c: Dict) => c.output || c.content);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the text part of a 'EnhancedGenerateContentResponse' object from Google Gemini `sendChat` or `chat`.
|
||||
*/
|
||||
function _extract_gemini_responses(completions: Array<Dict>): Array<string> {
|
||||
return completions.map((c: Dict) => c.text);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the text part of an Anthropic text completion.
|
||||
*/
|
||||
@ -825,7 +958,7 @@ export function extract_responses(response: Array<string | Dict> | Dict, llm: LL
|
||||
case LLMProvider.Azure_OpenAI:
|
||||
return _extract_openai_responses(response);
|
||||
case LLMProvider.Google:
|
||||
return _extract_palm_responses(response);
|
||||
return _extract_google_ai_responses(response as Dict, llm);
|
||||
case LLMProvider.Dalai:
|
||||
return [response.toString()];
|
||||
case LLMProvider.Anthropic:
|
||||
|
16
chainforge/react-server/src/store.js
vendored
16
chainforge/react-server/src/store.js
vendored
@ -7,6 +7,7 @@ import {
|
||||
import { escapeBraces } from './backend/template';
|
||||
import { filterDict } from './backend/utils';
|
||||
import { APP_IS_RUNNING_LOCALLY } from './backend/utils';
|
||||
import { DuplicateVariableNameError } from './backend/errors';
|
||||
|
||||
// Initial project settings
|
||||
const initialAPIKeys = {};
|
||||
@ -33,7 +34,7 @@ export let initLLMProviders = [
|
||||
{ name: "GPT3.5", emoji: "🤖", model: "gpt-3.5-turbo", base_model: "gpt-3.5-turbo", temp: 1.0 }, // The base_model designates what settings form will be used, and must be unique.
|
||||
{ name: "GPT4", emoji: "🥵", model: "gpt-4", base_model: "gpt-4", temp: 1.0 },
|
||||
{ name: "Claude", emoji: "📚", model: "claude-2", base_model: "claude-v1", temp: 0.5 },
|
||||
{ name: "PaLM2", emoji: "🦬", model: "chat-bison-001", base_model: "palm2-bison", temp: 0.7 },
|
||||
{ name: "Gemini", emoji: "♊", model: "gemini-pro", base_model: "palm2-bison", temp: 0.7 },
|
||||
{ name: "HuggingFace", emoji: "🤗", model: "tiiuae/falcon-7b-instruct", base_model: "hf", temp: 1.0 },
|
||||
{ name: "Aleph Alpha", emoji: "💡", model: "luminous-base", base_model: "luminous-base", temp: 0.0 },
|
||||
{ name: "Azure OpenAI", emoji: "🔷", model: "azure-openai", base_model: "azure-openai", temp: 1.0 },
|
||||
@ -253,8 +254,15 @@ const useStore = create((set, get) => ({
|
||||
|
||||
// Pull data from each source recursively:
|
||||
const pulled_data = {};
|
||||
const get_outputs = (varnames, nodeId) => {
|
||||
const get_outputs = (varnames, nodeId, var_history) => {
|
||||
varnames.forEach(varname => {
|
||||
// Check for duplicate variable names
|
||||
if (var_history.has(String(varname).toLowerCase()))
|
||||
throw new DuplicateVariableNameError(varname);
|
||||
|
||||
// Add to unique name tally
|
||||
var_history.add(String(varname).toLowerCase());
|
||||
|
||||
// Find the relevant edge(s):
|
||||
edges.forEach(e => {
|
||||
if (e.target == nodeId && e.targetHandle == varname) {
|
||||
@ -274,12 +282,12 @@ const useStore = create((set, get) => ({
|
||||
// Get any vars that the output depends on, and recursively collect those outputs as well:
|
||||
const n_vars = getNode(e.source).data.vars;
|
||||
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
|
||||
get_outputs(n_vars, e.source);
|
||||
get_outputs(n_vars, e.source, var_history);
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
get_outputs(_targetHandles, node_id);
|
||||
get_outputs(_targetHandles, node_id, new Set());
|
||||
|
||||
return pulled_data;
|
||||
},
|
||||
|
Loading…
x
Reference in New Issue
Block a user