Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-14 08:16:37 +00:00)
Add "Continue w prior LLMs" toggle to the base Prompt Node (#168)
* Add support for "continue w prior LLM" toggle on base Prompt Node
* Fix anthropic chat bug
* Detect immediate prompt chaining, and show cont LLM toggle in that case
* Update react build and package
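The gist of the new toggle: when "Continue with prior LLM(s)" is on, a base Prompt Node no longer queries its own LLM list, but only the LLMs that produced the responses pulled from its inputs. The sketch below is a minimal, illustrative version of that behavior in plain JavaScript; it mirrors the getLLMsInPulledInputData helper added to PromptNode.js in this commit, and the sample input data and model names in the usage portion are hypothetical, not taken from the repo.

// Minimal sketch (assumed shapes, not the exact ChainForge source): collect the
// distinct LLMs found in the `llm` metadata of pulled input response objects.
const getLLMsInPulledInputData = (pulled_data) => {
  const found_llms = {};
  Object.values(pulled_data).forEach((vals) => {
    const vs = Array.isArray(vals) ? vals : [vals];
    vs.forEach((v) => {
      // Response objects carry the model that produced them under `llm`.
      if (v?.llm !== undefined && !(v.llm.key in found_llms))
        found_llms[v.llm.key] = v.llm;
    });
  });
  return Object.values(found_llms);
};

// Hypothetical usage: two upstream responses from different models means the
// downstream Prompt Node, with the toggle on, queries exactly those two LLMs.
const pulled = {
  answer: [
    { text: "42", llm: { key: "gpt-3.5-turbo", name: "GPT3.5" } },
    { text: "43", llm: { key: "claude-2", name: "Claude 2" } },
  ],
};
console.log(getLLMsInPulledInputData(pulled).map((llm) => llm.name)); // ["GPT3.5", "Claude 2"]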
parent f7be853554
commit 1eae5edf89
@@ -1,15 +1,15 @@
 {
   "files": {
     "main.css": "/static/css/main.8665fcca.css",
-    "main.js": "/static/js/main.358435c9.js",
+    "main.js": "/static/js/main.50d66017.js",
     "static/js/787.4c72bb55.chunk.js": "/static/js/787.4c72bb55.chunk.js",
     "index.html": "/index.html",
     "main.8665fcca.css.map": "/static/css/main.8665fcca.css.map",
-    "main.358435c9.js.map": "/static/js/main.358435c9.js.map",
+    "main.50d66017.js.map": "/static/js/main.50d66017.js.map",
     "787.4c72bb55.chunk.js.map": "/static/js/787.4c72bb55.chunk.js.map"
   },
   "entrypoints": [
     "static/css/main.8665fcca.css",
-    "static/js/main.358435c9.js"
+    "static/js/main.50d66017.js"
   ]
 }
@@ -1 +1 @@
-<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.358435c9.js"></script><link href="/static/css/main.8665fcca.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
+<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.50d66017.js"></script><link href="/static/css/main.8665fcca.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
6 chainforge/react-server/src/JoinNode.js (vendored)
@@ -8,6 +8,7 @@ import { IconArrowMerge, IconList } from '@tabler/icons-react';
 import { Divider, NativeSelect, Text, Popover, Tooltip, Center, Modal, Box } from '@mantine/core';
 import { useDisclosure } from '@mantine/hooks';
 import { escapeBraces } from './backend/template';
+import { countNumLLMs } from './backend/utils';
 
 const formattingOptions = [
   {value: "\n\n", label:"double newline \\n\\n"},
@@ -53,11 +54,6 @@ const getVarsAndMetavars = (input_data) => {
   };
 };
 
-const countNumLLMs = (resp_objs_or_dict) => {
-  const resp_objs = Array.isArray(resp_objs_or_dict) ? resp_objs_or_dict : Object.values(resp_objs_or_dict).flat();
-  return (new Set(resp_objs.filter(r => typeof r !== "string" && r.llm !== undefined).map(r => r.llm?.key || r.llm))).size;
-};
-
 const tagMetadataWithLLM = (input_data) => {
   let new_data = {};
   Object.entries(input_data).forEach(([varname, resp_objs]) => {
102 chainforge/react-server/src/PromptNode.js (vendored)
@@ -13,6 +13,7 @@ import fetch_from_backend from './fetch_from_backend';
 import { escapeBraces } from './backend/template';
 import ChatHistoryView from './ChatHistoryView';
 import InspectFooter from './InspectFooter';
+import { countNumLLMs, setsAreEqual } from './backend/utils';
 
 const getUniqueLLMMetavarKey = (responses) => {
   const metakeys = new Set(responses.map(resp_obj => Object.keys(resp_obj.metavars)).flat());
@@ -31,6 +32,17 @@ const bucketChatHistoryInfosByLLM = (chat_hist_infos) => {
   });
   return chats_by_llm;
 }
+const getLLMsInPulledInputData = (pulled_data) => {
+  let found_llms = {};
+  Object.values(pulled_data).filter(_vs => {
+    let vs = Array.isArray(_vs) ? _vs : [_vs];
+    vs.forEach(v => {
+      if (v?.llm !== undefined && !(v.llm.key in found_llms))
+        found_llms[v.llm.key] = v.llm;
+    });
+  });
+  return Object.values(found_llms);
+};
 
 class PromptInfo {
   prompt; // string
@@ -80,6 +92,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
   // Get state from the Zustand store:
   const edges = useStore((state) => state.edges);
   const pullInputData = useStore((state) => state.pullInputData);
+  const getImmediateInputNodeTypes = useStore((state) => state.getImmediateInputNodeTypes);
   const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
   const pingOutputNodes = useStore((state) => state.pingOutputNodes);
 
@@ -106,9 +119,10 @@ const PromptNode = ({ data, id, type: node_type }) => {
   const [uninspectedResponses, setUninspectedResponses] = useState(false);
   const [responsesWillChange, setResponsesWillChange] = useState(false);
 
-  // Chat node specific
-  const [contChatWithPriorLLMs, setContChatWithPriorLLMs] = useState(data.contChat !== undefined ? data.contChat : true);
-  const [contChatToggleDisabled, setContChatToggleDisabled] = useState(false);
+  // For continuing with prior LLMs toggle
+  const [contWithPriorLLMs, setContWithPriorLLMs] = useState(data.contChat !== undefined ? data.contChat : (node_type === 'chat' ? true : false));
+  const [showContToggle, setShowContToggle] = useState(node_type === 'chat');
+  const [contToggleDisabled, setContChatToggleDisabled] = useState(false);
 
   // For an info pop-up that shows all the prompts that will be sent off
   // NOTE: This is the 'full' version of the PromptListPopover that activates on hover.
@@ -122,7 +136,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
   const triggerAlert = useCallback((msg) => {
     setProgress(undefined);
     llmListContainer?.current?.resetLLMItemsProgress();
-    alertModal.current.trigger(msg);
+    alertModal?.current?.trigger(msg);
   }, [llmListContainer, alertModal]);
 
   const showResponseInspector = useCallback(() => {
@@ -159,12 +173,27 @@ const PromptNode = ({ data, id, type: node_type }) => {
     }
   }, [setDataPropsForNode, signalDirty]);
 
-  const refreshTemplateHooks = (text) => {
+  const updateShowContToggle = useCallback((pulled_data) => {
+    if (node_type === 'chat') return; // always show when chat node
+    const hasPromptInput = getImmediateInputNodeTypes(templateVars, id).some(t => ['prompt', 'chat'].includes(t));
+    setShowContToggle(hasPromptInput || (pulled_data && countNumLLMs(pulled_data) > 0));
+  }, [setShowContToggle, countNumLLMs, getImmediateInputNodeTypes, templateVars, id]);
+
+  const handleOnConnect = useCallback(() => {
+    if (node_type === 'chat') return; // always show when chat node
+    // Re-pull data and update show cont toggle:
+    updateShowContToggle(pullInputData(templateVars, id));
+  }, [templateVars, id, pullInputData, updateShowContToggle]);
+
+  const refreshTemplateHooks = useCallback((text) => {
     // Update template var fields + handles
-    const found_template_vars = Array.from(
-      new Set(extractBracketedSubstrings(text))); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this}
-    setTemplateVars(found_template_vars);
-  };
+    const found_template_vars = new Set(extractBracketedSubstrings(text)); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this}
+
+    if (!setsAreEqual(found_template_vars, new Set(templateVars))) {
+      if (node_type !== 'chat') updateShowContToggle(pullInputData(found_template_vars, id));
+      setTemplateVars(Array.from(found_template_vars));
+    }
+  }, [setTemplateVars, templateVars, pullInputData, id]);
 
   const handleInputChange = (event) => {
     const value = event.target.value;
@@ -202,6 +231,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
     if (data.refresh === true) {
       setDataPropsForNode(id, { refresh: false });
       setStatus('warning');
+      handleOnConnect();
     } else if (data.refreshLLMList === true) {
       llmListContainer?.current?.refreshLLMProviderList();
       setDataPropsForNode(id, { refreshLLMList: false });
@@ -259,6 +289,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
       id: id,
       chat_histories: chat_histories,
       n: numGenerations,
+      cont_only_w_prior_llms: node_type !== 'chat' ? (showContToggle && contWithPriorLLMs) : undefined,
     }, rejected).then(function(json) {
       if (!json || !json.counts) {
         throw new Error('There was no response from the server.');
@@ -272,6 +303,8 @@ const PromptNode = ({ data, id, type: node_type }) => {
   const handlePreviewHover = () => {
     // Pull input data and prompt
     const pulled_vars = pullInputData(templateVars, id);
+    updateShowContToggle(pulled_vars);
+
     fetch_from_backend('generatePrompts', {
       prompt: promptText,
       vars: pulled_vars,
@@ -294,7 +327,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
 
     // If this is a chat node, we also need to pull chat histories:
     let [past_chat_llms, pulled_chats] = node_type === 'chat' ? pullInputChats() : [undefined, undefined];
-    if (node_type === 'chat' && contChatWithPriorLLMs) {
+    if (node_type === 'chat' && contWithPriorLLMs) {
      if (past_chat_llms === undefined) {
        setRunTooltip('Attach an input to past conversations first.');
        return;
@@ -303,14 +336,21 @@ const PromptNode = ({ data, id, type: node_type }) => {
      pulled_chats = bucketChatHistoryInfosByLLM(pulled_chats);
     }
 
+    // Pull the input data
+    const pulled_vars = pullInputData(templateVars, id);
+    updateShowContToggle(pulled_vars);
+
+    // Whether to continue with only the prior LLMs, for each value in vars dict
+    if (node_type !== 'chat' && showContToggle && contWithPriorLLMs) {
+      // We need to draw the LLMs to query from the input responses
+      _llmItemsCurrState = getLLMsInPulledInputData(pulled_vars);
+    }
+
     // Check if there's at least one model in the list; if not, nothing to run on.
     if (!_llmItemsCurrState || _llmItemsCurrState.length == 0) {
       setRunTooltip('No LLMs to query.');
       return;
     }
 
-    // Pull the input data
-    const pulled_vars = pullInputData(templateVars, id);
-
     const llms = _llmItemsCurrState.map(item => item.model);
     const num_llms = llms.length;
@@ -391,7 +431,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
    // in the input variables (if any). If there's keys present w/o LLMs (for instance a text node),
    // we need to pop-up an error message.
    let _llmItemsCurrState = llmItemsCurrState;
-    if (node_type === 'chat' && contChatWithPriorLLMs) {
+    if (node_type === 'chat' && contWithPriorLLMs) {
      // If there's nothing attached to past conversations, we can't continue the chat:
      if (past_chat_llms === undefined) {
        triggerAlert('You need to attach an input to the Past Conversation message first. For instance, you might query \
@@ -421,6 +461,16 @@ const PromptNode = ({ data, id, type: node_type }) => {
      pulled_chats = bucketChatHistoryInfosByLLM(pulled_chats);
    }
 
+    // Pull the data to fill in template input variables, if any
+    const pulled_data = pullInputData(templateVars, id);
+    const prompt_template = promptText;
+
+    // Whether to continue with only the prior LLMs, for each value in vars dict
+    if (node_type !== 'chat' && showContToggle && contWithPriorLLMs) {
+      // We need to draw the LLMs to query from the input responses
+      _llmItemsCurrState = getLLMsInPulledInputData(pulled_data);
+    }
+
    // Check that there is at least one LLM selected:
    if (_llmItemsCurrState.length === 0) {
      alert('Please select at least one LLM to prompt.')
@@ -433,10 +483,6 @@ const PromptNode = ({ data, id, type: node_type }) => {
    setJSONResponses([]);
    setProgressAnimated(true);
 
-    // Pull the data to fill in template input variables, if any
-    const pulled_data = pullInputData(templateVars, id);
-    const prompt_template = promptText;
-
    const rejected = (err) => {
      setStatus('error');
      setContChatToggleDisabled(false);
@@ -501,6 +547,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
      api_keys: (apiKeys ? apiKeys : {}),
      no_cache: false,
      progress_listener: onProgressChange,
+      cont_only_w_prior_llms: node_type !== 'chat' ? (showContToggle && contWithPriorLLMs) : undefined,
    }, rejected).then(function(json) {
      if (!json) {
        rejected('Request was sent and received by backend server, but there was no response.');
@@ -536,7 +583,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
        combined_err_msg += item?.name + ': ' + JSON.stringify(json.errors[llm_key][0]) + '\n';
      });
      // We trigger the alert directly (don't use triggerAlert) here because we want to keep the progress bar:
-      alertModal.current.trigger('Errors collecting responses. Re-run prompt node to retry.\n\n'+combined_err_msg);
+      alertModal?.current?.trigger('Errors collecting responses. Re-run prompt node to retry.\n\n'+combined_err_msg);
 
      return;
    }
@@ -703,26 +750,27 @@ const PromptNode = ({ data, id, type: node_type }) => {
        <input id="num-generations" name="num-generations" type="number" min={1} max={50} defaultValue={data.n || 1} onChange={handleNumGenChange} className="nodrag"></input>
      </div>
 
-      {node_type === 'chat' ? (
+      { showContToggle ?
        <div>
          <Switch
-            label={contChatWithPriorLLMs ? "Continue chat with prior LLM(s)" : "Continue chat with new LLMs:"}
+            label={contWithPriorLLMs ? "Continue with prior LLM(s)" : "Continue with new LLMs:"}
            defaultChecked={true}
-            checked={contChatWithPriorLLMs}
-            disabled={contChatToggleDisabled}
+            checked={contWithPriorLLMs}
+            disabled={contToggleDisabled}
            onChange={(event) => {
              setStatus('warning');
-              setContChatWithPriorLLMs(event.currentTarget.checked);
+              setContWithPriorLLMs(event.currentTarget.checked);
              setDataPropsForNode(id, { contChat: event.currentTarget.checked });
            }}
            color='cyan'
            size='xs'
-            mb={contChatWithPriorLLMs ? '4px' : '10px'}
+            mb={contWithPriorLLMs ? '4px' : '10px'}
          />
        </div>
-      ) : <></>}
+        : <></>
+      }
 
-      {node_type !== 'chat' || !contChatWithPriorLLMs ? (
+      {(!contWithPriorLLMs || !showContToggle) ? (
        <LLMListContainer
          ref={llmListContainer}
          initLLMItems={data.llms}

18 chainforge/react-server/src/SplitNode.js (vendored)
@@ -8,7 +8,7 @@ import { IconArrowMerge, IconArrowsSplit, IconList } from '@tabler/icons-react';
 import { Divider, NativeSelect, Text, Popover, Tooltip, Center, Modal, Box } from '@mantine/core';
 import { useDisclosure } from '@mantine/hooks';
 import { escapeBraces } from './backend/template';
-import { processCSV } from "./backend/utils";
+import { processCSV, deepcopy, deepcopy_and_modify, dict_excluding_key } from "./backend/utils";
 
 import { fromMarkdown } from "mdast-util-from-markdown";
 
@@ -21,20 +21,6 @@ const formattingOptions = [
   {value: "paragraph", label:"paragraphs (md)"},
 ];
 
-const deepcopy = (v) => JSON.parse(JSON.stringify(v));
-const deepcopy_and_modify = (v, new_val_dict) => {
-  let new_v = deepcopy(v);
-  Object.entries(new_val_dict).forEach(([key, val]) => {
-    new_v[key] = val;
-  });
-  return new_v;
-};
-const excluding_key = (d, key) => {
-  if (!(key in d)) return d;
-  const copy_d = {...d};
-  delete copy_d[key];
-  return copy_d;
-};
 const truncStr = (s, maxLen) => {
   if (s.length > maxLen) // Cut the name short if it's long
     return s.substring(0, maxLen) + '...'
@@ -220,7 +206,7 @@ const SplitNode = ({ data, id }) => {
     // Convert the templates into response objects
     let resp_objs = promptTemplates.map(p => ({
       text: p.toString(),
-      fill_history: excluding_key(p.fill_history, "__input"),
+      fill_history: dict_excluding_key(p.fill_history, "__input"),
       llm: "__LLM_key" in p.metavars ? llm_lookup[p.metavars['__LLM_key']] : undefined,
       metavars: removeLLMTagFromMetadata(p.metavars),
     }));

12 chainforge/react-server/src/TextFieldsNode.js (vendored)
@@ -6,6 +6,7 @@ import useStore from './store';
 import NodeLabel from './NodeLabelComponent';
 import TemplateHooks, { extractBracketedSubstrings } from './TemplateHooksComponent';
 import BaseNode from './BaseNode';
+import { setsAreEqual } from './backend/utils';
 
 // Helper funcs
 const union = (setA, setB) => {
@@ -15,17 +16,6 @@ const union = (setA, setB) => {
   }
   return _union;
 }
-const setsAreEqual = (setA, setB) => {
-  if (setA.size !== setB.size) return false;
-  let equal = true;
-  for (const item of setA) {
-    if (!setB.has(item)) {
-      equal = false;
-      break;
-    }
-  }
-  return equal;
-}
 
 const delButtonId = 'del-';
 const visibleButtonId = 'eye-';
@@ -2,7 +2,7 @@ import markdownIt from "markdown-it";
 
 import { Dict, StringDict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse, ChatHistoryInfo, isEqualChatHistory } from "./typing";
 import { LLM, NativeLLM, getEnumName } from "./models";
-import { APP_IS_RUNNING_LOCALLY, set_api_keys, FLASK_BASE_URL, call_flask_backend } from "./utils";
+import { APP_IS_RUNNING_LOCALLY, set_api_keys, FLASK_BASE_URL, call_flask_backend, filterDict, deepcopy } from "./utils";
 import StorageCache from "./cache";
 import { PromptPipeline } from "./query";
 import { PromptPermutationGenerator, PromptTemplate } from "./template";
@@ -209,6 +209,15 @@ function extract_llm_params(llm_spec: Dict | string): Dict {
   return {};
 }
 
+function filterVarsByLLM(vars: Dict, llm_key: string): Dict {
+  let _vars = {};
+  Object.entries(vars).forEach(([key, val]) => {
+    const vs = Array.isArray(val) ? val : [val];
+    _vars[key] = vs.filter((v) => (typeof v === 'string' || v?.llm === undefined || (v?.llm?.key === llm_key)));
+  });
+  return _vars;
+}
+
 /**
  * Test equality akin to Python's list equality.
  */
@@ -395,14 +404,24 @@ export async function countQueries(prompt: string,
     llms: Array<Dict | string>,
     n: number,
     chat_histories?: ChatHistoryInfo[] | {[key: string]: ChatHistoryInfo[]},
-    id?: string): Promise<Dict> {
+    id?: string,
+    cont_only_w_prior_llms?: boolean): Promise<Dict> {
   if (chat_histories === undefined) chat_histories = [ undefined ];
 
   let gen_prompts: PromptPermutationGenerator;
-  let all_prompt_permutations: Array<PromptTemplate>;
+  let all_prompt_permutations: Array<PromptTemplate> | Dict;
   try {
     gen_prompts = new PromptPermutationGenerator(prompt);
-    all_prompt_permutations = Array.from(gen_prompts.generate(vars));
+    if (cont_only_w_prior_llms && Array.isArray(llms)) {
+      all_prompt_permutations = {};
+      llms.forEach(llm_spec => {
+        const llm_key = extract_llm_key(llm_spec);
+        all_prompt_permutations[llm_key] = Array.from(gen_prompts.generate(filterVarsByLLM(vars, llm_key)));
+      });
+    } else {
+      all_prompt_permutations = Array.from(gen_prompts.generate(vars));
+    }
+
   } catch (err) {
     return {error: err.message};
   }
@@ -432,6 +451,9 @@ export async function countQueries(prompt: string,
   llms.forEach(llm_spec => {
     const llm_key = extract_llm_key(llm_spec);
 
+    // Get only the relevant prompt permutations
+    let _all_prompt_perms = cont_only_w_prior_llms ? all_prompt_permutations[llm_key] : all_prompt_permutations;
+
     // Get the relevant chat histories for this LLM:
     const chat_hists = (!Array.isArray(chat_histories)
       ? chat_histories[extract_llm_nickname(llm_spec)]
@@ -448,7 +470,7 @@ export async function countQueries(prompt: string,
     const cache_llm_responses = load_from_cache(cache_filename);
 
     // Iterate through all prompt permutations and check if how many responses there are in the cache with that prompt
-    all_prompt_permutations.forEach(prompt => {
+    _all_prompt_perms.forEach(prompt => {
       let prompt_str = prompt.toString();
 
       add_to_num_responses_req(llm_key, n * chat_hists.length);
@@ -490,7 +512,7 @@ export async function countQueries(prompt: string,
     }
 
     if (!found_cache) {
-      all_prompt_permutations.forEach(perm => {
+      _all_prompt_perms.forEach(perm => {
        add_to_num_responses_req(llm_key, n * chat_hists.length);
        add_to_missing_queries(llm_key, perm.toString(), n * chat_hists.length);
      });
@@ -537,7 +559,8 @@ export async function queryLLM(id: string,
     chat_histories?: ChatHistoryInfo[] | {[key: string]: ChatHistoryInfo[]},
     api_keys?: Dict,
     no_cache?: boolean,
-    progress_listener?: (progress: {[key: symbol]: any}) => void): Promise<Dict> {
+    progress_listener?: (progress: {[key: symbol]: any}) => void,
+    cont_only_w_prior_llms?: boolean): Promise<Dict> {
   // Verify the integrity of the params
   if (typeof id !== 'string' || id.trim().length === 0)
     return {'error': 'id is improper format (length 0 or not a string)'};
@@ -552,9 +575,6 @@ export async function queryLLM(id: string,
   llm = llm as (Array<string> | Array<Dict>);
 
   await setAPIKeys(api_keys);
 
-  // if 'no_cache' in data and data['no_cache'] is True:
-  //   remove_cached_responses(data['id'])
-
   // Get the storage keys of any cache files for specific models + settings
   const llms = llm;
@@ -623,6 +643,12 @@ export async function queryLLM(id: string,
     let llm_params = extract_llm_params(llm_spec);
     let llm_key = extract_llm_key(llm_spec);
     let temperature: number = llm_params?.temperature !== undefined ? llm_params.temperature : 1.0;
+    let _vars = vars;
+
+    if (cont_only_w_prior_llms) {
+      // Filter vars so that only the var values with the matching LLM are used, or otherwise values with no LLM metadata
+      _vars = filterVarsByLLM(vars, llm_key);
+    }
 
     let chat_hists = ((chat_histories !== undefined && !Array.isArray(chat_histories))
       ? chat_histories[llm_nickname]
@@ -642,7 +668,7 @@ export async function queryLLM(id: string,
     console.log(`Querying ${llm_str}...`)
 
     // Yield responses for 'llm' for each prompt generated from the root template 'prompt' and template variables in 'properties':
-    for await (const response of prompter.gen_responses(vars, llm_str as LLM, num_generations, temperature, llm_params, chat_hists)) {
+    for await (const response of prompter.gen_responses(_vars, llm_str as LLM, num_generations, temperature, llm_params, chat_hists)) {
 
      // Check for selective failure
      if (response instanceof LLMResponseError) { // The request failed
@@ -333,6 +333,7 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
       anthr_chat_context += ' ' + chat_msg.content;
     }
     wrapped_prompt = anthr_chat_context + wrapped_prompt; // prepend the chat context
+    delete params.chat_history;
   }
 
   // Format query
@@ -363,6 +364,7 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
     'User-Agent': "Anthropic/JS 0.5.0",
     'X-Api-Key': ANTHROPIC_API_KEY,
   };
+  console.log(query);
   const resp = await route_fetch(url, 'POST', headers, query);
   responses.push(resp);
 
@@ -902,4 +904,34 @@ export const processCSV = (csv: string): string[] => {
     if (matches[n] == ',') matches[n] = '';
   }
   return matches.map(e => e.trim()).filter(e => e.length > 0);
-}
+}
+
+export const countNumLLMs = (resp_objs_or_dict: LLMResponseObject[] | Dict): number => {
+  const resp_objs = Array.isArray(resp_objs_or_dict) ? resp_objs_or_dict : Object.values(resp_objs_or_dict).flat();
+  return (new Set(resp_objs.filter(r => typeof r !== "string" && r.llm !== undefined).map(r => r.llm?.key || r.llm))).size;
+};
+
+export const setsAreEqual = (setA: Set<any>, setB: Set<any>): boolean => {
+  if (setA.size !== setB.size) return false;
+  let equal = true;
+  for (const item of setA) {
+    if (!setB.has(item))
+      return false;
+  }
+  return equal;
+}
+
+export const deepcopy = (v) => JSON.parse(JSON.stringify(v));
+export const deepcopy_and_modify = (v, new_val_dict) => {
+  let new_v = deepcopy(v);
+  Object.entries(new_val_dict).forEach(([key, val]) => {
+    new_v[key] = val;
+  });
+  return new_v;
+};
+export const dict_excluding_key = (d, key) => {
+  if (!(key in d)) return d;
+  const copy_d = {...d};
+  delete copy_d[key];
+  return copy_d;
+};
@@ -11,11 +11,11 @@ async function _route_to_js_backend(route, params) {
     case 'grabResponses':
       return grabResponses(params.responses);
     case 'countQueriesRequired':
-      return countQueries(params.prompt, clone(params.vars), clone(params.llms), params.n, params.chat_histories, params.id);
+      return countQueries(params.prompt, clone(params.vars), clone(params.llms), params.n, params.chat_histories, params.id, params.cont_only_w_prior_llms);
     case 'generatePrompts':
       return generatePrompts(params.prompt, clone(params.vars));
     case 'queryllm':
-      return queryLLM(params.id, clone(params.llm), params.n, params.prompt, clone(params.vars), params.chat_histories, params.api_keys, params.no_cache, params.progress_listener);
+      return queryLLM(params.id, clone(params.llm), params.n, params.prompt, clone(params.vars), params.chat_histories, params.api_keys, params.no_cache, params.progress_listener, params.cont_only_w_prior_llms);
     case 'executejs':
       return executejs(params.id, params.code, params.responses, params.scope);
     case 'executepy':

15 chainforge/react-server/src/store.js (vendored)
@@ -205,6 +205,21 @@ const useStore = create((set, get) => ({
     }
   },
 
+  // Get the types of nodes attached immediately as input to the given node
+  getImmediateInputNodeTypes: (_targetHandles, node_id) => {
+    const getNode = get().getNode;
+    const edges = get().edges;
+    let inputNodeTypes = [];
+    edges.forEach(e => {
+      if (e.target == node_id && _targetHandles.includes(e.targetHandle)) {
+        const src_node = getNode(e.source);
+        if (src_node && src_node.type !== undefined)
+          inputNodeTypes.push(src_node.type);
+      }
+    });
+    return inputNodeTypes;
+  },
+
   // Pull all inputs needed to request responses.
   // Returns [prompt, vars dict]
   pullInputData: (_targetHandles, node_id) => {