Add Gemini model support and raise an error on duplicate variable names (v0.2.8.1) (#195)

* Raise error after detecting duplicate variable names (#190)

* Raise error for duplicate variable name

* Created base error class

* Simplified error classes: now a single `DuplicateVariableNameError` that takes the variable name and builds a hard-coded error message

---------

Co-authored-by: Kayla Zethelyn <kaylazethelyn@college.harvard.edu>

* Add support for Google's Gemini Pro model (#194)

* Refined the duplicate variable name check

* Tidy up duplicate variable name alerts, error handling, and error message

* Rebuild React and update package version

---------

Co-authored-by: Kayla Z <77540029+kamazet@users.noreply.github.com>
Co-authored-by: Kayla Zethelyn <kaylazethelyn@college.harvard.edu>
Co-authored-by: Priyan Vaithilingam <priyanmuthu@gmail.com>
ianarawjo 2023-12-19 16:14:34 -05:00 committed by GitHub
parent ce583a216c
commit d6e850e724
16 changed files with 268 additions and 43 deletions

View File

@@ -1,15 +1,15 @@
{
"files": {
"main.css": "/static/css/main.847ce933.css",
"main.js": "/static/js/main.72d00c7c.js",
"main.js": "/static/js/main.84c4d13c.js",
"static/js/787.4c72bb55.chunk.js": "/static/js/787.4c72bb55.chunk.js",
"index.html": "/index.html",
"main.847ce933.css.map": "/static/css/main.847ce933.css.map",
"main.72d00c7c.js.map": "/static/js/main.72d00c7c.js.map",
"main.84c4d13c.js.map": "/static/js/main.84c4d13c.js.map",
"787.4c72bb55.chunk.js.map": "/static/js/787.4c72bb55.chunk.js.map"
},
"entrypoints": [
"static/css/main.847ce933.css",
"static/js/main.72d00c7c.js"
"static/js/main.84c4d13c.js"
]
}

View File

@@ -1 +1 @@
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.72d00c7c.js"></script><link href="/static/css/main.847ce933.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.84c4d13c.js"></script><link href="/static/css/main.847ce933.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>

View File

@@ -115,6 +115,23 @@ License: MIT
/*! sheetjs (C) 2013-present SheetJS -- http://sheetjs.com */
/**
* @license
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* @license
* Lodash <https://lodash.com/>

View File

@@ -15,6 +15,7 @@
"@emoji-mart/data": "^1.1.2",
"@emoji-mart/react": "^1.1.1",
"@google-ai/generativelanguage": "^0.2.0",
"@google/generative-ai": "^0.1.3",
"@mantine/core": "^6.0.9",
"@mantine/dates": "^6.0.13",
"@mantine/dropzone": "^6.0.19",
@@ -66,7 +67,7 @@
"net": "^1.0.2",
"net-browserify": "^0.2.4",
"node-fetch": "^2.6.11",
"openai": "^3.3.0",
"openai": "~3.3.0",
"os-browserify": "^0.3.0",
"papaparse": "^5.4.1",
"path-browserify": "^1.0.1",
@@ -3154,6 +3155,14 @@
"node": ">=12.0.0"
}
},
"node_modules/@google/generative-ai": {
"version": "0.1.3",
"resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.1.3.tgz",
"integrity": "sha512-Cm4uJX1sKarpm1mje/MiOIinM7zdUUrQp/5/qGPAgznbdd/B9zup5ehT6c1qGqycFcSopTA1J1HpqHS5kJR8hQ==",
"engines": {
"node": ">=18.0.0"
}
},
"node_modules/@grpc/grpc-js": {
"version": "1.8.21",
"resolved": "https://registry.npmjs.org/@grpc/grpc-js/-/grpc-js-1.8.21.tgz",

View File

@@ -13,6 +13,7 @@
"@emoji-mart/data": "^1.1.2",
"@emoji-mart/react": "^1.1.1",
"@google-ai/generativelanguage": "^0.2.0",
"@google/generative-ai": "^0.1.3",
"@mantine/core": "^6.0.9",
"@mantine/dates": "^6.0.13",
"@mantine/dropzone": "^6.0.19",

View File

@@ -245,8 +245,8 @@ return (
<br />
<TextInput
label="Google PaLM API Key"
placeholder="Paste your Google PaLM API key here"
label="Google AI API Key (PaLM/GEMINI)"
placeholder="Paste your Google PaLM/GEMINI API key here"
{...form.getInputProps('Google')}
/>
<br />

View File

@@ -346,7 +346,7 @@ const ClaudeSettings = {
};
const PaLM2Settings = {
fullName: "PaLM (Google)",
fullName: "Google AI Models (Gemini & PaLM)",
schema: {
"type": "object",
"required": [
@@ -357,17 +357,18 @@ const PaLM2Settings = {
"type": "string",
"title": "Nickname",
"description": "Unique identifier to appear in ChainForge. Keep it short.",
"default": "chat-bison"
"default": "Gemini"
},
"model": {
"type": "string",
"title": "Model",
"description": "Select a PaLM model to query. For more details on the differences, see the Google PaLM API documentation.",
"enum": ["text-bison-001", "chat-bison-001"],
"default": "chat-bison-001",
"enum": ["gemini-pro", "text-bison-001", "chat-bison-001"],
"default": "gemini-pro",
"shortname_map": {
"text-bison-001": "PaLM2-text",
"chat-bison-001": "PaLM2-chat",
"gemini-pro": "Gemini",
}
},
"temperature": {
@@ -417,7 +418,7 @@ const PaLM2Settings = {
"ui:autofocus": true
},
"model": {
"ui:help": "Defaults to chat-bison."
"ui:help": "Defaults to gemini-pro."
},
"temperature": {
"ui:help": "Defaults to 0.5.",
@@ -446,6 +447,7 @@ const PaLM2Settings = {
}
};
const DalaiModelSettings = {
fullName: "Dalai-hosted local model (Alpaca, Llama)",
schema: {

View File

@@ -14,8 +14,8 @@ import { escapeBraces } from './backend/template';
import ChatHistoryView from './ChatHistoryView';
import InspectFooter from './InspectFooter';
import { countNumLLMs, setsAreEqual, getLLMsInPulledInputData } from './backend/utils';
import LLMResponseInspector from './LLMResponseInspector';
import LLMResponseInspectorDrawer from './LLMResponseInspectorDrawer';
import { DuplicateVariableNameError } from './backend/errors';
const getUniqueLLMMetavarKey = (responses) => {
const metakeys = new Set(responses.map(resp_obj => Object.keys(resp_obj.metavars)).flat());
@@ -134,7 +134,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
const showResponseInspector = useCallback(() => {
if (inspectModal && inspectModal.current && jsonResponses) {
inspectModal.current.trigger();
inspectModal.current?.trigger();
setUninspectedResponses(false);
}
}, [inspectModal, jsonResponses]);
@@ -175,7 +175,13 @@ const PromptNode = ({ data, id, type: node_type }) => {
const handleOnConnect = useCallback(() => {
if (node_type === 'chat') return; // always show when chat node
// Re-pull data and update show cont toggle:
updateShowContToggle(pullInputData(templateVars, id));
try {
const pulled_data = pullInputData(templateVars, id);
updateShowContToggle(pulled_data);
} catch (err) {
// alertModal.current?.trigger(err.message);
console.error(err);
}
}, [templateVars, id, pullInputData, updateShowContToggle]);
const refreshTemplateHooks = useCallback((text) => {
@@ -183,7 +189,13 @@ const PromptNode = ({ data, id, type: node_type }) => {
const found_template_vars = new Set(extractBracketedSubstrings(text)); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this}
if (!setsAreEqual(found_template_vars, new Set(templateVars))) {
if (node_type !== 'chat') updateShowContToggle(pullInputData(found_template_vars, id));
if (node_type !== 'chat') {
try {
updateShowContToggle(pullInputData(found_template_vars, id));
} catch (err) {
console.error(err);
}
}
setTemplateVars(Array.from(found_template_vars));
}
}, [setTemplateVars, templateVars, pullInputData, id]);
@@ -295,17 +307,23 @@ const PromptNode = ({ data, id, type: node_type }) => {
const [promptPreviews, setPromptPreviews] = useState([]);
const handlePreviewHover = () => {
// Pull input data and prompt
const pulled_vars = pullInputData(templateVars, id);
updateShowContToggle(pulled_vars);
try {
const pulled_vars = pullInputData(templateVars, id);
updateShowContToggle(pulled_vars);
fetch_from_backend('generatePrompts', {
prompt: promptText,
vars: pulled_vars,
}).then(prompts => {
setPromptPreviews(prompts.map(p => (new PromptInfo(p.toString()))));
});
fetch_from_backend('generatePrompts', {
prompt: promptText,
vars: pulled_vars,
}).then(prompts => {
setPromptPreviews(prompts.map(p => (new PromptInfo(p.toString()))));
});
pullInputChats();
pullInputChats();
} catch (err) {
// soft fail
console.error(err);
setPromptPreviews([]);
}
};
// On hover over the 'Run' button, request how many responses are required and update the tooltip. Soft fails.
@@ -330,7 +348,15 @@ const PromptNode = ({ data, id, type: node_type }) => {
}
// Pull the input data
const pulled_vars = pullInputData(templateVars, id);
let pulled_vars = {};
try {
pulled_vars = pullInputData(templateVars, id);
} catch (err) {
setRunTooltip('Error: Duplicate variables detected.');
console.error(err);
return; // early exit
}
updateShowContToggle(pulled_vars);
// Whether to continue with only the prior LLMs, for each value in vars dict
@@ -455,7 +481,16 @@ const PromptNode = ({ data, id, type: node_type }) => {
}
// Pull the data to fill in template input variables, if any
const pulled_data = pullInputData(templateVars, id);
let pulled_data = {};
try {
// Try to pull inputs
pulled_data = pullInputData(templateVars, id);
} catch (err) {
alertModal.current?.trigger(err.message);
console.error(err);
return; // early exit
}
const prompt_template = promptText;
// Whether to continue with only the prior LLMs, for each value in vars dict

View File

@@ -0,0 +1,9 @@
export class DuplicateVariableNameError extends Error {
variable: string;
constructor(variable: string) {
super();
this.name = "DuplicateVariableNameError";
this.variable = variable; // store the offending name so callers can inspect it
this.message = "You have multiple template variables with the same name, {" + variable + "}. Duplicate names in the same chain are not allowed. To fix, ensure that all template variable names are unique across a chain.";
}
}
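
For reference, a minimal sketch of how a caller might consume this error (the wrapper below is illustrative, not part of this commit):

```typescript
import { DuplicateVariableNameError } from './backend/errors';

// Hypothetical helper: surface duplicate-name errors to the user
// instead of letting them crash a node silently.
function tryPullInputData(pull: () => Record<string, unknown>) {
  try {
    return pull();
  } catch (err) {
    if (err instanceof DuplicateVariableNameError)
      console.error(`Duplicate template variable: {${err.variable}}`);
    throw err; // re-throw so the caller can early-exit, as PromptNode does
  }
}
```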

View File

@@ -50,6 +50,7 @@ export enum NativeLLM {
// Google models
PaLM2_Text_Bison = "text-bison-001", // it's really models/text-bison-001, but that's confusing
PaLM2_Chat_Bison = "chat-bison-001",
GEMINI_PRO = "gemini-pro",
// Aleph Alpha
Aleph_Alpha_Luminous_Extended = "luminous-extended",
@@ -101,7 +102,7 @@ export function getProvider(llm: LLM): LLMProvider | undefined {
return LLMProvider.OpenAI;
else if (llm_name?.startsWith('Azure'))
return LLMProvider.Azure_OpenAI;
else if (llm_name?.startsWith('PaLM2'))
else if (llm_name?.startsWith('PaLM2') || llm_name?.startsWith('GEMINI'))
return LLMProvider.Google;
else if (llm_name?.startsWith('Dalai'))
return LLMProvider.Dalai;
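
A quick sketch of the routing this enables (assuming, as in the surrounding file, that `getProvider` matches on the enum member's name):

```typescript
import { NativeLLM, LLMProvider, getProvider } from './backend/models';

// The new GEMINI_PRO member routes to the Google provider,
// alongside the existing PaLM2 models.
console.assert(getProvider(NativeLLM.GEMINI_PRO) === LLMProvider.Google);
console.assert(getProvider(NativeLLM.PaLM2_Chat_Bison) === LLMProvider.Google);
```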

View File

@@ -46,6 +46,16 @@ export interface PaLMChatContext {
examples?: Dict[],
}
export interface GeminiChatMessage {
role: string,
parts: string,
}
export interface GeminiChatContext {
history: GeminiChatMessage[],
}
/** HuggingFace conversation models format */
export interface HuggingFaceChatHistory {
past_user_inputs: string[],
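
A sketch of building a `GeminiChatContext` from OpenAI-style history, mirroring the conversion `call_google_gemini` performs below (sample messages and import path are illustrative):

```typescript
import { GeminiChatContext, GeminiChatMessage } from './backend/typing';

// Gemini only has 'user' and 'model' roles, so system and assistant
// turns are both mapped onto 'model' turns.
const openai_history = [
  { role: 'system', content: 'You are terse.' },
  { role: 'user', content: 'Hello!' },
  { role: 'assistant', content: 'Hi.' },
];
const history: GeminiChatMessage[] = openai_history.map((m) => ({
  role: m.role === 'user' ? 'user' : 'model',
  parts: m.content,
}));
const context: GeminiChatContext = { history };
```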

View File

@@ -3,14 +3,15 @@
// from string import Template
// from chainforge.promptengine.models import LLM
import { LLM, LLMProvider, getProvider } from './models';
import { Dict, StringDict, LLMAPICall, LLMResponseObject, ChatHistory, ChatMessage, PaLMChatMessage, PaLMChatContext, HuggingFaceChatHistory } from './typing';
import { LLM, LLMProvider, NativeLLM, getProvider } from './models';
import { Dict, StringDict, LLMAPICall, LLMResponseObject, ChatHistory, ChatMessage, PaLMChatMessage, PaLMChatContext, HuggingFaceChatHistory, GeminiChatContext, GeminiChatMessage } from './typing';
import { env as process_env } from 'process';
import { StringTemplate } from './template';
/* LLM API SDKs */
import { Configuration as OpenAIConfig, OpenAIApi } from "openai";
import { OpenAIClient as AzureOpenAIClient, AzureKeyCredential } from "@azure/openai";
import { GoogleGenerativeAI } from "@google/generative-ai";
const ANTHROPIC_HUMAN_PROMPT = "\n\nHuman:";
const ANTHROPIC_AI_PROMPT = "\n\nAssistant:";
@@ -392,14 +393,26 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
return [query, responses];
}
/**
* Calls a Google PaLM/Gemini model, based on the model selection from the user.
* Returns raw query and response JSON dicts.
*/
export async function call_google_ai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
switch(model) {
case NativeLLM.GEMINI_PRO:
return call_google_gemini(prompt, model, n, temperature, params);
default:
return call_google_palm(prompt, model, n, temperature, params);
}
}
/**
* Calls a Google PaLM model.
Returns raw query and response JSON dicts.
* Returns raw query and response JSON dicts.
*/
export async function call_google_palm(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
if (!GOOGLE_PALM_API_KEY)
throw new Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment.");
const is_chat_model = model.toString().includes('chat');
// Required non-standard params
@@ -506,6 +519,109 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
return [query, completion];
}
export async function call_google_gemini(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
if (!GOOGLE_PALM_API_KEY)
throw new Error("Could not find an API key for Google Gemini models. Double-check that your API key is set in Settings or in your local environment.");
// Select the Gemini model client
model = NativeLLM.GEMINI_PRO;
const genAI = new GoogleGenerativeAI(GOOGLE_PALM_API_KEY);
const gemini_model = genAI.getGenerativeModel({model: model.toString()});
// No separate completions path is needed: Gemini supports chat by default
// Required non-standard params
const max_output_tokens = params?.max_output_tokens || 1000;
const chat_history = params?.chat_history;
delete params?.chat_history;
let query: Dict = {
model: `models/${model}`,
candidate_count: n,
temperature: temperature,
max_output_tokens: max_output_tokens,
...params,
};
// Google's API parameter names differ between SDKs: snake_case in the Python API, camelCase in the Node.js API.
// ChainForge needs a consistent name, so we must convert snake_case to camelCase:
const casemap = {
safety_settings: 'safetySettings',
stop_sequences: 'stopSequences',
candidate_count: 'candidateCount',
max_output_tokens: 'maxOutputTokens',
top_p: 'topP',
top_k: 'topK',
};
let gen_Config = {};
Object.entries(casemap).forEach(([key, val]) => {
if (key in query) {
gen_Config[val] = query[key];
query[val] = query[key];
delete query[key];
}
});
// Gemini only supports candidate_count of 1
gen_Config['candidateCount'] = 1;
// By default, topK is none and topP is 1.0
if ('topK' in gen_Config && gen_Config['topK'] === -1) {
delete gen_Config['topK'];
}
if ('topP' in gen_Config && gen_Config['topP'] === -1) {
gen_Config['topP'] = 1.0;
}
let gemini_chat_context: GeminiChatContext = { history: [] };
// Chat completions
if (chat_history !== undefined && chat_history.length > 0) {
// Carry over any chat history, converting OpenAI-formatted chat history to Gemini's format:
let gemini_messages: GeminiChatMessage[] = [];
for (const chat_msg of chat_history) {
if (chat_msg.role === 'system') {
// Gemini has no system role; carry the system message over as a 'model' turn:
gemini_messages.push({ role: 'model', parts: chat_msg.content });
} else if (chat_msg.role === 'user') {
gemini_messages.push({ role: 'user', parts: chat_msg.content });
} else
gemini_messages.push({ role: 'model', parts: chat_msg.content });
}
gemini_chat_context.history = gemini_messages;
}
console.log(`Calling Google Gemini model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`);
let responses: Array<Dict> = [];
while(responses.length < n) {
const chat = gemini_model.startChat(
{
history: gemini_chat_context.history,
generationConfig: gen_Config,
},
);
const chatResult = await chat.sendMessage(prompt);
const chatResponse = await chatResult.response;
const response = {
text: chatResponse.text(),
candidates: chatResponse.candidates,
promptFeedback: chatResponse.promptFeedback,
}
responses.push(response);
}
return [query, responses];
}
export async function call_dalai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
if (APP_IS_RUNNING_LOCALLY()) {
// Try to call Dalai server, through Flask:
@@ -716,7 +832,7 @@ export async function call_llm(llm: LLM, prompt: string, n: number, temperature:
else if (llm_provider === LLMProvider.Azure_OpenAI)
call_api = call_azure_openai;
else if (llm_provider === LLMProvider.Google)
call_api = call_google_palm;
call_api = call_google_ai;
else if (llm_provider === LLMProvider.Dalai)
call_api = call_dalai;
else if (llm_provider === LLMProvider.Anthropic)
@@ -778,6 +894,16 @@ function _extract_openai_responses(response: Dict): Array<string> {
return _extract_openai_completion_responses(response);
}
function _extract_google_ai_responses(response: Dict, llm: LLM | string): Array<string> {
switch(llm) {
case NativeLLM.GEMINI_PRO:
return _extract_gemini_responses(response as Array<Dict>);
default:
return _extract_palm_responses(response);
}
}
/**
* Extracts the text part of a 'Completion' object from Google PaLM2 `generate_text` or `chat`.
*
@@ -788,6 +914,13 @@ function _extract_palm_responses(completion: Dict): Array<string> {
return completion['candidates'].map((c: Dict) => c.output || c.content);
}
/**
* Extracts the text part of an 'EnhancedGenerateContentResponse' object returned by Google Gemini's `sendMessage`.
*/
function _extract_gemini_responses(completions: Array<Dict>): Array<string> {
return completions.map((c: Dict) => c.text);
}
/**
* Extracts the text part of an Anthropic text completion.
*/
@@ -825,7 +958,7 @@ export function extract_responses(response: Array<string | Dict> | Dict, llm: LL
case LLMProvider.Azure_OpenAI:
return _extract_openai_responses(response);
case LLMProvider.Google:
return _extract_palm_responses(response);
return _extract_google_ai_responses(response as Dict, llm);
case LLMProvider.Dalai:
return [response.toString()];
case LLMProvider.Anthropic:
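
Stripped of ChainForge plumbing, the `@google/generative-ai` (v0.1.x) flow that `call_google_gemini` wraps looks roughly like this (a sketch; the env-var lookup and prompt handling are illustrative):

```typescript
import { GoogleGenerativeAI } from "@google/generative-ai";
import { env } from "process";

// Minimal single-turn chat against gemini-pro.
async function geminiOnce(prompt: string): Promise<string> {
  const genAI = new GoogleGenerativeAI(env.GOOGLE_PALM_API_KEY ?? "");
  const model = genAI.getGenerativeModel({ model: "gemini-pro" });
  const chat = model.startChat({ generationConfig: { candidateCount: 1 } });
  const result = await chat.sendMessage(prompt);
  return result.response.text(); // EnhancedGenerateContentResponse.text()
}
```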

View File

@@ -7,6 +7,7 @@ import {
import { escapeBraces } from './backend/template';
import { filterDict } from './backend/utils';
import { APP_IS_RUNNING_LOCALLY } from './backend/utils';
import { DuplicateVariableNameError } from './backend/errors';
// Initial project settings
const initialAPIKeys = {};
@@ -33,7 +34,7 @@ export let initLLMProviders = [
{ name: "GPT3.5", emoji: "🤖", model: "gpt-3.5-turbo", base_model: "gpt-3.5-turbo", temp: 1.0 }, // The base_model designates what settings form will be used, and must be unique.
{ name: "GPT4", emoji: "🥵", model: "gpt-4", base_model: "gpt-4", temp: 1.0 },
{ name: "Claude", emoji: "📚", model: "claude-2", base_model: "claude-v1", temp: 0.5 },
{ name: "PaLM2", emoji: "🦬", model: "chat-bison-001", base_model: "palm2-bison", temp: 0.7 },
{ name: "Gemini", emoji: "♊", model: "gemini-pro", base_model: "palm2-bison", temp: 0.7 },
{ name: "HuggingFace", emoji: "🤗", model: "tiiuae/falcon-7b-instruct", base_model: "hf", temp: 1.0 },
{ name: "Aleph Alpha", emoji: "💡", model: "luminous-base", base_model: "luminous-base", temp: 0.0 },
{ name: "Azure OpenAI", emoji: "🔷", model: "azure-openai", base_model: "azure-openai", temp: 1.0 },
@@ -253,8 +254,15 @@ const useStore = create((set, get) => ({
// Pull data from each source recursively:
const pulled_data = {};
const get_outputs = (varnames, nodeId) => {
const get_outputs = (varnames, nodeId, var_history) => {
varnames.forEach(varname => {
// Check for duplicate variable names
if (var_history.has(String(varname).toLowerCase()))
throw new DuplicateVariableNameError(varname);
// Add to unique name tally
var_history.add(String(varname).toLowerCase());
// Find the relevant edge(s):
edges.forEach(e => {
if (e.target == nodeId && e.targetHandle == varname) {
@@ -274,12 +282,12 @@
// Get any vars that the output depends on, and recursively collect those outputs as well:
const n_vars = getNode(e.source).data.vars;
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
get_outputs(n_vars, e.source);
get_outputs(n_vars, e.source, var_history);
}
});
});
};
get_outputs(_targetHandles, node_id);
get_outputs(_targetHandles, node_id, new Set());
return pulled_data;
},
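
Note that the duplicate check is case-insensitive. A toy walkthrough (variable names invented):

```typescript
import { DuplicateVariableNameError } from './backend/errors';

// Mirrors the check in get_outputs: names are lowercased before being
// tallied, so {Topic} collides with an earlier {topic}.
const var_history = new Set<string>();
for (const varname of ['topic', 'audience', 'Topic']) {
  const key = String(varname).toLowerCase();
  if (var_history.has(key))
    throw new DuplicateVariableNameError(varname); // throws on 'Topic'
  var_history.add(key);
}
```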

View File

@@ -6,7 +6,7 @@ def readme():
setup(
name='chainforge',
version='0.2.8.0',
version='0.2.8.1',
packages=find_packages(),
author="Ian Arawjo",
description="A Visual Programming Environment for Prompt Engineering",