Add stop button to cancel pending queries (#211)

* Add Stop button

* Replaced QueryTracker stop checks in _prompt_llm in query.ts. Modified _prompt_llm and *gen_responses to take in node id for checking purposes. Added new css class for stopping status.

* Used callback function instead of passing id to the backend, renamed QueryStopper and some of its functions, made custom error

* Added semicolons and one more UserForcedPrematureExit check

* Revise canceler to never clear id, and use unique id Date.now instead

* Make cancel go into call_llm funcs

* Cleanup console logs

* Rebuild app and update package version

---------

Co-authored-by: Kayla Zethelyn <kaylazethelyn@college.harvard.edu>
Co-authored-by: Ian Arawjo <fatso784@gmail.com>
This commit is contained in:
Kayla Z 2024-01-13 18:22:08 -05:00 committed by GitHub
parent 83c49ffe0b
commit 3d15bc9d17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 282 additions and 134 deletions

View File

@ -1,15 +1,15 @@
{
"files": {
"main.css": "/static/css/main.d7b7e6a1.css",
"main.js": "/static/js/main.16bd62df.js",
"main.css": "/static/css/main.01603dff.css",
"main.js": "/static/js/main.dd40466e.js",
"static/js/787.4c72bb55.chunk.js": "/static/js/787.4c72bb55.chunk.js",
"index.html": "/index.html",
"main.d7b7e6a1.css.map": "/static/css/main.d7b7e6a1.css.map",
"main.16bd62df.js.map": "/static/js/main.16bd62df.js.map",
"main.01603dff.css.map": "/static/css/main.01603dff.css.map",
"main.dd40466e.js.map": "/static/js/main.dd40466e.js.map",
"787.4c72bb55.chunk.js.map": "/static/js/787.4c72bb55.chunk.js.map"
},
"entrypoints": [
"static/css/main.d7b7e6a1.css",
"static/js/main.16bd62df.js"
"static/css/main.01603dff.css",
"static/js/main.dd40466e.js"
]
}

View File

@ -1 +1 @@
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.16bd62df.js"></script><link href="/static/css/main.d7b7e6a1.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.dd40466e.js"></script><link href="/static/css/main.01603dff.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,4 @@
import { useRef } from 'react';
import { useMemo, useRef } from 'react';
import useStore from './store';
import { EditText } from 'react-edit-text';
import 'react-edit-text/dist/index.css';
@ -10,8 +10,7 @@ import { Tooltip, Popover, Badge, Stack } from '@mantine/core';
import { IconSparkles } from '@tabler/icons-react';
export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, customButtons, handleRunClick, handleRunHover, runButtonTooltip }) {
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, isRunning, alertModal, customButtons, handleRunClick, handleStopClick, handleRunHover, runButtonTooltip }) { const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const [statusIndicator, setStatusIndicator] = useState('none');
const [runButton, setRunButton] = useState('none');
const removeNode = useStore((state) => state.removeNode);
@ -21,6 +20,11 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl
const [deleteConfirmProps, setDeleteConfirmProps] = useState({
title: 'Delete node', message: 'Are you sure?', onConfirm: () => {}
});
const stopButton = useMemo(() =>
<button className="AmitSahoo45-button-3 nodrag" style={{padding: '0px 10px 0px 9px'}} onClick={() => handleStopClick(nodeId)}>
&#9724;
</button>
, [handleStopClick, nodeId]);
const handleNodeLabelChange = (evt) => {
const { value } = evt;
@ -47,9 +51,8 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl
if (runButtonTooltip)
setRunButton(
<Tooltip label={runButtonTooltip} withArrow arrowSize={6} arrowRadius={2} zIndex={1001} withinPortal={true} >
{run_btn}
</Tooltip>
);
{run_btn}
</Tooltip>);
else
setRunButton(run_btn);
}
@ -58,6 +61,8 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl
}
}, [handleRunClick, runButtonTooltip]);
useEffect(() => {}, [stopButton])
const handleCloseButtonClick = useCallback(() => {
setDeleteConfirmProps({
title: 'Delete node',
@ -86,7 +91,7 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl
<AlertModal ref={alertModal} />
<div className="node-header-btns-container">
{customButtons ? customButtons : <></>}
{runButton}
{isRunning ? stopButton : runButton}
<button className="close-button nodrag" onClick={handleCloseButtonClick}>&#x2715;</button>
<br/>
</div>

View File

@ -15,6 +15,8 @@ import ChatHistoryView from './ChatHistoryView';
import InspectFooter from './InspectFooter';
import { countNumLLMs, setsAreEqual, getLLMsInPulledInputData } from './backend/utils';
import LLMResponseInspectorDrawer from './LLMResponseInspectorDrawer';
import CancelTracker from './backend/canceler';
import { UserForcedPrematureExit } from './backend/errors';
const getUniqueLLMMetavarKey = (responses) => {
const metakeys = new Set(responses.map(resp_obj => Object.keys(resp_obj.metavars)).flat());
@ -125,6 +127,10 @@ const PromptNode = ({ data, id, type: node_type }) => {
const [progressAnimated, setProgressAnimated] = useState(true);
const [runTooltip, setRunTooltip] = useState(null);
// Cancelation of pending queries
const [cancelId, setCancelId] = useState(Date.now());
const refreshCancelId = () => setCancelId(Date.now());
// Debounce helpers
const debounceTimeoutRef = useRef(null);
const debounce = (func, delay) => {
@ -223,9 +229,8 @@ const PromptNode = ({ data, id, type: node_type }) => {
data['prompt'] = value;
// Update status icon, if need be:
if (promptTextOnLastRun !== null && status !== 'warning' && value !== promptTextOnLastRun) {
if (promptTextOnLastRun !== null && status !== 'warning' && value !== promptTextOnLastRun)
setStatus('warning');
}
refreshTemplateHooks(value);
};
@ -529,9 +534,14 @@ const PromptNode = ({ data, id, type: node_type }) => {
setProgressAnimated(true);
const rejected = (err) => {
setStatus('error');
setContChatToggleDisabled(false);
triggerAlert(err.message || err);
if (err instanceof UserForcedPrematureExit || CancelTracker.has(cancelId)) {
// Handle a premature cancelation
console.log("Canceled.");
} else {
setStatus('error');
setContChatToggleDisabled(false);
triggerAlert(err.message || err);
}
};
// Fetch info about the number of queries we'll need to make
@ -550,7 +560,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
const max_responses = Object.keys(total_num_responses).reduce((acc, llm) => acc + total_num_responses[llm], 0);
onProgressChange = (progress_by_llm_key) => {
if (!progress_by_llm_key) return;
if (!progress_by_llm_key || CancelTracker.has(cancelId)) return;
// Update individual progress bars
const num_llms = _llmItemsCurrState.length;
@ -584,7 +594,6 @@ const PromptNode = ({ data, id, type: node_type }) => {
};
};
// Run all prompt permutations through the LLM to generate + cache responses:
const query_llms = () => {
return fetch_from_backend('queryllm', {
@ -598,65 +607,23 @@ const PromptNode = ({ data, id, type: node_type }) => {
no_cache: false,
progress_listener: onProgressChange,
cont_only_w_prior_llms: node_type !== 'chat' ? (showContToggle && contWithPriorLLMs) : undefined,
cancel_id: cancelId,
}, rejected).then(function(json) {
if (!json) {
rejected('Request was sent and received by backend server, but there was no response.');
}
else if (json.responses && json.errors) {
// We have to early exit explicitly because we will still enter this function even if 'rejected' is called
if (!json && CancelTracker.has(cancelId))
return;
// Remove progress bars
setProgress(undefined);
setProgressAnimated(false);
debounce(() => {}, 1)(); // erase any pending debounces
// Remove progress bars
setProgress(undefined);
setProgressAnimated(false);
debounce(() => {}, 1)(); // erase any pending debounces
// Store and log responses (if any)
if (json.responses) {
setJSONResponses(json.responses);
// Store and log responses (if any)
if (json?.responses) {
setJSONResponses(json.responses);
// Log responses for debugging:
console.log(json.responses);
}
// If there was at least one error collecting a response...
const llms_w_errors = Object.keys(json.errors);
if (llms_w_errors.length > 0) {
// Remove the total progress bar
setProgress(undefined);
// Ensure there's a sliver of error displayed in the progress bar
// of every LLM item that has an error:
llmListContainer?.current?.ensureLLMItemsErrorProgress(llms_w_errors);
// Set error status
setStatus('error');
setContChatToggleDisabled(false);
// Trigger alert and display one error message per LLM of all collected errors:
let combined_err_msg = "";
llms_w_errors.forEach(llm_key => {
const item = _llmItemsCurrState.find((item) => item.key === llm_key);
combined_err_msg += item?.name + ': ' + JSON.stringify(json.errors[llm_key][0]) + '\n';
});
// We trigger the alert directly (don't use triggerAlert) here because we want to keep the progress bar:
alertModal?.current?.trigger('Errors collecting responses. Re-run prompt node to retry.\n\n'+combined_err_msg);
return;
}
if (responsesWillChange && !showDrawer)
setUninspectedResponses(true);
setResponsesWillChange(false);
// All responses collected! Change status to 'ready':
setStatus('ready');
setContChatToggleDisabled(false);
// Remove individual progress rings
llmListContainer?.current?.resetLLMItemsProgress();
// Save prompt text so we remember what prompt we have responses cache'd for:
setPromptTextOnLastRun(promptText);
setNumGenerationsLastRun(numGenerations);
// Log responses for debugging:
console.log(json.responses);
// Save response texts as 'fields' of data, for any prompt nodes pulling the outputs
// We also need to store a unique metavar for the LLM *set* (set of LLM nicknames) that produced these responses,
@ -668,9 +635,9 @@ const PromptNode = ({ data, id, type: node_type }) => {
r => {
// Carry over the response text, prompt, prompt fill history (vars), and llm nickname:
let o = { text: escapeBraces(r),
prompt: resp_obj['prompt'],
fill_history: resp_obj['vars'],
llm: _llmItemsCurrState.find((item) => item.name === resp_obj.llm) };
prompt: resp_obj['prompt'],
fill_history: resp_obj['vars'],
llm: _llmItemsCurrState.find((item) => item.name === resp_obj.llm) };
// Carry over any metavars
o.metavars = resp_obj['metavars'] || {};
@ -685,12 +652,58 @@ const PromptNode = ({ data, id, type: node_type }) => {
}
)).flat()
});
// Ping any inspect nodes attached to this node to refresh their contents:
pingOutputNodes(id);
} else {
rejected(json.error || 'Unknown error when querying LLM');
}
// If there was at least one error collecting a response...
const llms_w_errors = json?.errors ? Object.keys(json.errors) : [];
if (llms_w_errors.length > 0) {
// Remove the total progress bar
setProgress(undefined);
// Ensure there's a sliver of error displayed in the progress bar
// of every LLM item that has an error:
llmListContainer?.current?.ensureLLMItemsErrorProgress(llms_w_errors);
// Set error status
setStatus('error');
setContChatToggleDisabled(false);
// Trigger alert and display one error message per LLM of all collected errors:
let combined_err_msg = "";
llms_w_errors.forEach(llm_key => {
const item = _llmItemsCurrState.find((item) => item.key === llm_key);
combined_err_msg += item?.name + ': ' + JSON.stringify(json.errors[llm_key][0]) + '\n';
});
// We trigger the alert directly (don't use triggerAlert) here because we want to keep the progress bar:
alertModal?.current?.trigger('Errors collecting responses. Re-run prompt node to retry.\n\n'+combined_err_msg);
return;
}
if (responsesWillChange && !showDrawer)
setUninspectedResponses(true);
setResponsesWillChange(false);
setContChatToggleDisabled(false);
// Remove individual progress rings
llmListContainer?.current?.resetLLMItemsProgress();
if (json?.error || !json) {
rejected(json?.error ?? 'Request was sent and received by backend server, but there was no response.');
return;
}
// Save prompt text so we remember what prompt we have responses cache'd for:
setPromptTextOnLastRun(promptText);
setNumGenerationsLastRun(numGenerations);
// All responses collected! Change status to 'ready':
setStatus('ready');
// Ping any inspect nodes attached to this node to refresh their contents:
pingOutputNodes(id);
}, rejected);
};
@ -701,6 +714,23 @@ const PromptNode = ({ data, id, type: node_type }) => {
.catch(rejected);
};
const handleStopClick = useCallback(() => {
CancelTracker.add(cancelId);
refreshCancelId();
// Update UI to seem like it's been immediately canceled, even
// though we cannot fully cancel the queryLLMs Promise.
// Remove progress bars
setProgress(undefined);
setProgressAnimated(false);
debounce(() => {}, 1)(); // erase any pending debounces
// Set error status
setStatus('none');
setContChatToggleDisabled(false);
llmListContainer?.current?.resetLLMItemsProgress();
}, [cancelId, refreshCancelId]);
const handleNumGenChange = useCallback((event) => {
let n = event.target.value;
if (!isNaN(n) && n.length > 0 && /^\d+$/.test(n)) {
@ -714,7 +744,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
}, [numGenerationsLastRun, status]);
const hideStatusIndicator = () => {
if (status !== 'none') { setStatus('none'); }
if (status !== 'none') setStatus('none');
};
// Dynamically update the textareas and position of the template hooks
@ -748,8 +778,10 @@ const PromptNode = ({ data, id, type: node_type }) => {
onEdit={hideStatusIndicator}
icon={node_icon}
status={status}
isRunning = {status === 'loading'}
alertModal={alertModal}
handleRunClick={handleRunClick}
handleStopClick={handleStopClick}
handleRunHover={handleRunHover}
runButtonTooltip={runTooltip}
customButtons={[

View File

@ -32,7 +32,6 @@ const ScriptNode = ({ data, id }) => {
if (Object.keys(new_data.scriptFiles).length === 0) {
new_data.scriptFiles[get_id()] = '';
}
console.log(new_data);
setDataPropsForNode(id, new_data);
}, [data, id, setDataPropsForNode]);

View File

@ -73,7 +73,7 @@ const TabularDataNode = ({ data, id }) => {
}));
return;
}
console.log('handleSaveCell', rowIdx, columnKey, value);
// console.log('handleSaveCell', rowIdx, columnKey, value);
tableData[rowIdx][columnKey] = value;
setTableData([...tableData]);
}, [tableData, tableColumns, pingOutputNodes]);

View File

@ -128,9 +128,7 @@ function decode(mdText: string): Row[] {
export async function autofill(input: Row[], n: number, apiKeys?: Dict): Promise<Row[]> {
// hash the arguments to get a unique id
let id = JSON.stringify([input, n]);
let encoded = encode(input);
let templateVariables = [...new Set(new StringTemplate(input.join('\n')).get_vars())];
console.log("System message: ", autofillSystemMessage(n, templateVariables));

View File

@ -6,6 +6,8 @@ import { APP_IS_RUNNING_LOCALLY, set_api_keys, FLASK_BASE_URL, call_flask_backen
import StorageCache from "./cache";
import { PromptPipeline } from "./query";
import { PromptPermutationGenerator, PromptTemplate } from "./template";
import { UserForcedPrematureExit } from "./errors";
import CancelTracker from "./canceler";
// """ =================
// SETUP AND GLOBALS
@ -547,6 +549,9 @@ export async function fetchEnvironAPIKeys(): Promise<Dict> {
* @param chat_histories Either an array of `ChatHistory` (to use across all LLMs), or a dict indexed by LLM nicknames of `ChatHistory` arrays to use per LLM.
* @param api_keys (optional) a dict of {api_name: api_key} pairs. Supported key names: OpenAI, Anthropic, Google
* @param no_cache (optional) if true, deletes any cache'd responses for 'id' (always calls the LLMs fresh)
* @param progress_listener (optional) a callback whenever an LLM response is collected, on the current progress
* @param cont_only_w_prior_llms (optional) whether we are continuing using prior LLMs
* @param cancel_id (optional) the id that would appear in CancelTracker if the user cancels the querying (NOT the same as 'id' --must be unique!)
* @returns a dictionary in format `{responses: StandardizedLLMResponse[], errors: string[]}`
*/
export async function queryLLM(id: string,
@ -558,7 +563,8 @@ export async function queryLLM(id: string,
api_keys?: Dict,
no_cache?: boolean,
progress_listener?: (progress: {[key: symbol]: any}) => void,
cont_only_w_prior_llms?: boolean): Promise<Dict> {
cont_only_w_prior_llms?: boolean,
cancel_id?: string): Promise<Dict> {
// Verify the integrity of the params
if (typeof id !== 'string' || id.trim().length === 0)
return {'error': 'id is improper format (length 0 or not a string)'};
@ -635,6 +641,9 @@ export async function queryLLM(id: string,
}
});
// Helper function to check whether this process has been canceled
const should_cancel = () => CancelTracker.has(cancel_id);
// For each LLM, generate and cache responses:
let responses: {[key: string]: Array<LLMResponseObject>} = {};
let all_errors = {};
@ -667,11 +676,12 @@ export async function queryLLM(id: string,
let errors: Array<string> = [];
let num_resps = 0;
let num_errors = 0;
try {
console.log(`Querying ${llm_str}...`)
// Yield responses for 'llm' for each prompt generated from the root template 'prompt' and template variables in 'properties':
for await (const response of prompter.gen_responses(_vars, llm_str as LLM, num_generations, temperature, llm_params, chat_hists)) {
for await (const response of prompter.gen_responses(_vars, llm_str as LLM, num_generations, temperature, llm_params, chat_hists, should_cancel)) {
// Check for selective failure
if (response instanceof LLMResponseError) { // The request failed
@ -694,8 +704,12 @@ export async function queryLLM(id: string,
};
}
} catch (e) {
console.error(`Error generating responses for ${llm_str}: ${e.message}`);
throw e;
if (e instanceof UserForcedPrematureExit) {
throw e;
} else {
console.error(`Error generating responses for ${llm_str}: ${e.message}`);
throw e;
}
}
return {
@ -716,6 +730,7 @@ export async function queryLLM(id: string,
all_errors[result.llm_key] = result.errors;
});
} catch (e) {
if (e instanceof UserForcedPrematureExit) throw e;
console.error(`Error requesting responses: ${e.message}`);
return { error: e.message };
}
@ -758,6 +773,7 @@ export async function queryLLM(id: string,
cache_files: cache_filenames,
responses_last_run: res,
});
// Return all responses for all LLMs
return {
responses: res,
@ -1194,4 +1210,4 @@ export async function loadCachedCustomProviders(): Promise<Dict> {
}).then(function(res) {
return res.json();
});
}
}

View File

@ -0,0 +1,57 @@
import { UserForcedPrematureExit } from "./errors";
/**
* A CancelTracker allows ids to be added, to signal
* any associated processes should be 'canceled'. The tracker
* operates as a global. It does not cancel anything itself,
* but rather can be used to send a message to cancel a process
* associated with 'id' (through .add(id)), which the process itself
* checks (through .has(id)) and then performs the cancellation.
*/
export default class CancelTracker {
private static instance: CancelTracker;
private data: Set<string>;
private constructor() {
this.data = new Set();
}
// Get the canceler
public static getInstance(): CancelTracker {
if (!CancelTracker.instance)
CancelTracker.instance = new CancelTracker();
return CancelTracker.instance;
}
// Add an id to trigger cancelation
private addId(id: string): void {
this.data.add(id);
}
public static add(id: string): void {
CancelTracker.getInstance().addId(id);
}
// Canceler has the given id
private hasId(id: string): boolean {
return this.data.has(id);
}
public static has(id: string): boolean {
return CancelTracker.getInstance().hasId(id);
}
// Clear id from the canceler
private clearId(id: string): void {
if (CancelTracker.has(id))
this.data.delete(id);
}
public static clear(id: string): void {
CancelTracker.getInstance().clearId(id);
}
private clearTracker(): void {
this.data.clear();
}
public static clearAll(): void {
CancelTracker.getInstance().clearTracker();
}
}

View File

@ -6,4 +6,12 @@ export class DuplicateVariableNameError extends Error {
this.name = "DuplicateVariableNameError";
this.message = "You have multiple template variables with the same name, {" + variable + "}. Duplicate names in the same chain is not allowed. To fix, ensure that all template variable names are unique across a chain.";
}
}
export class UserForcedPrematureExit extends Error {
constructor(id?: string) {
super();
this.name = "UserForcedPrematureExit";
this.message = "You have forced the premature exit of the process" + (id !== undefined ? ` with id ${id}` : "");
}
}

View File

@ -3,6 +3,8 @@ import { LLM, NativeLLM, RATE_LIMITS } from './models';
import { Dict, LLMResponseError, LLMResponseObject, isEqualChatHistory, ChatHistoryInfo } from "./typing";
import { extract_responses, merge_response_objs, call_llm, mergeDicts } from "./utils";
import StorageCache from "./cache";
import CancelTracker from "./canceler";
import { UserForcedPrematureExit } from "./errors";
const clone = (obj) => JSON.parse(JSON.stringify(obj));
@ -129,12 +131,13 @@ export class PromptPipeline {
* and 3 different prior chat histories, it will send off 9 queries.
* @yields Yields `LLMResponseObject` if API call succeeds, or `LLMResponseError` if API call fails, for all requests.
*/
async *gen_responses(vars: Dict,
async *gen_responses( vars: Dict,
llm: LLM,
n: number = 1,
temperature: number = 1.0,
llm_params?: Dict,
chat_histories?: ChatHistoryInfo[]): AsyncGenerator<LLMResponseObject | LLMResponseError, boolean, undefined> {
n: number = 1,
temperature: number = 1.0,
llm_params?: Dict,
chat_histories?: ChatHistoryInfo[],
should_cancel?: ()=>boolean): AsyncGenerator<LLMResponseObject | LLMResponseError, boolean, undefined> {
// Load any cache'd responses
let responses = this._load_cached_responses();
@ -215,13 +218,15 @@ export class PromptPipeline {
max_req,
wait_secs,
llm_params,
chat_history));
chat_history,
should_cancel));
} else {
// Block. Await + yield a single LLM call.
let result = await this._prompt_llm(llm, prompt, n, temperature,
cached_resp, cached_resp_idx,
undefined, undefined, undefined,
llm_params, chat_history);
llm_params, chat_history,
should_cancel);
yield this.collect_LLM_response(result, llm, responses);
}
}
@ -263,7 +268,8 @@ export class PromptPipeline {
rate_limit_batch_size?: number,
rate_limit_wait_secs?: number,
llm_params?: Dict,
chat_history?: ChatHistoryInfo): Promise<_IntermediateLLMResponseType> {
chat_history?: ChatHistoryInfo,
should_cancel?: ()=>boolean): Promise<_IntermediateLLMResponseType> {
// Detect how many responses we have already (from cache obj past_resp_obj)
if (past_resp_obj) {
// How many *new* queries we need to send:
@ -283,7 +289,7 @@ export class PromptPipeline {
await sleep(wait_secs);
}
}
// Now try to call the API. If it fails for whatever reason, 'soft fail' by returning
// an LLMResponseException object as the 'response'.
let params = clone(llm_params);
@ -291,15 +297,24 @@ export class PromptPipeline {
let query: Dict | undefined;
let response: Dict | LLMResponseError;
try {
[query, response] = await call_llm(llm, prompt.toString(), n, temperature, params);
// When/if we emerge from sleep, check if this process has been canceled in the meantime:
if (should_cancel && should_cancel()) throw new UserForcedPrematureExit();
// Call the LLM, returning when the Promise returns (if it does!)
[query, response] = await call_llm(llm, prompt.toString(), n, temperature, params, should_cancel);
// When/if we emerge from getting a response, check if this process has been canceled in the meantime:
if (should_cancel && should_cancel()) throw new UserForcedPrematureExit();
} catch(err) {
if (err instanceof UserForcedPrematureExit) throw err; // bubble cancels up
return { prompt: prompt,
query: undefined,
response: new LLMResponseError(err.message),
past_resp_obj: undefined,
past_resp_obj_cache_idx: -1 };
}
return { prompt,
chat_history,
query,

View File

@ -100,7 +100,8 @@ export interface LLMAPICall {
model: LLM,
n: number,
temperature: number,
params?: Dict): Promise<[Dict, Dict]>
params?: Dict,
should_cancel?: (() => boolean)): Promise<[Dict, Dict]>,
}
/** A standard response format expected by the front-end. */

View File

@ -12,6 +12,7 @@ import { StringTemplate } from './template';
import { Configuration as OpenAIConfig, OpenAIApi } from "openai";
import { OpenAIClient as AzureOpenAIClient, AzureKeyCredential } from "@azure/openai";
import { GoogleGenerativeAI } from "@google/generative-ai";
import { UserForcedPrematureExit } from './errors';
const ANTHROPIC_HUMAN_PROMPT = "\n\nHuman:";
const ANTHROPIC_AI_PROMPT = "\n\nAssistant:";
@ -155,7 +156,7 @@ function construct_openai_chat_history(prompt: string, chat_history: ChatHistory
* Calls OpenAI models via OpenAI's API.
@returns raw query and response JSON dicts.
*/
export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict, should_cancel?: (() => boolean)): Promise<[Dict, Dict]> {
if (!OPENAI_API_KEY)
throw new Error("Could not find an OpenAI API key. Double-check that your API key is set in Settings or in your local environment.");
@ -233,7 +234,7 @@ export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, te
*
* NOTE: It is recommended to set an environment variables AZURE_OPENAI_KEY and AZURE_OPENAI_ENDPOINT
*/
export async function call_azure_openai(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
export async function call_azure_openai(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict, should_cancel?: (() => boolean)): Promise<[Dict, Dict]> {
if (!AZURE_OPENAI_KEY)
throw new Error("Could not find an Azure OpenAPI Key to use. Double-check that your key is set in Settings or in your local environment.");
if (!AZURE_OPENAI_ENDPOINT)
@ -306,7 +307,7 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
NOTE: It is recommended to set an environment variable ANTHROPIC_API_KEY with your Anthropic API key
*/
export async function call_anthropic(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
export async function call_anthropic(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict, should_cancel?: (() => boolean)): Promise<[Dict, Dict]> {
if (!ANTHROPIC_API_KEY)
throw new Error("Could not find an API key for Anthropic models. Double-check that your API key is set in Settings or in your local environment.");
@ -357,6 +358,8 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
// Repeat call n times, waiting for each response to come in:
let responses: Array<Dict> = [];
while (responses.length < n) {
// Abort if canceled
if (should_cancel && should_cancel()) throw new UserForcedPrematureExit();
if (APP_IS_RUNNING_LOCALLY()) {
// If we're running locally, route the request through the Flask backend,
@ -369,7 +372,6 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
'User-Agent': "Anthropic/JS 0.5.0",
'X-Api-Key': ANTHROPIC_API_KEY,
};
console.log(query);
const resp = await route_fetch(url, 'POST', headers, query);
responses.push(resp);
@ -389,7 +391,6 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
throw new Error(`${resp.error.type}: ${resp.error.message}`);
}
// console.log('Received Anthropic response from server proxy:', resp);
responses.push(resp);
}
}
@ -401,12 +402,12 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
* Calls a Google PaLM/Gemini model, based on the model selection from the user.
* Returns raw query and response JSON dicts.
*/
export async function call_google_ai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
export async function call_google_ai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict, should_cancel?: (() => boolean)): Promise<[Dict, Dict]> {
switch(model) {
case NativeLLM.GEMINI_PRO:
return call_google_gemini(prompt, model, n, temperature, params);
return call_google_gemini(prompt, model, n, temperature, params, should_cancel);
default:
return call_google_palm(prompt, model, n, temperature, params);
return call_google_palm(prompt, model, n, temperature, params, should_cancel);
}
}
@ -414,7 +415,7 @@ export async function call_google_ai(prompt: string, model: LLM, n: number = 1,
* Calls a Google PaLM model.
* Returns raw query and response JSON dicts.
*/
export async function call_google_palm(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
export async function call_google_palm(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict, should_cancel?: (() => boolean)): Promise<[Dict, Dict]> {
if (!GOOGLE_PALM_API_KEY)
throw new Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment.");
const is_chat_model = model.toString().includes('chat');
@ -523,7 +524,7 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
return [query, completion];
}
export async function call_google_gemini(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
export async function call_google_gemini(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict, should_cancel?: (() => boolean)): Promise<[Dict, Dict]> {
if (!GOOGLE_PALM_API_KEY)
throw new Error("Could not find an API key for Google Gemini models. Double-check that your API key is set in Settings or in your local environment.");
@ -604,6 +605,8 @@ export async function call_google_gemini(prompt: string, model: LLM, n: number =
let responses: Array<Dict> = [];
while(responses.length < n) {
if (should_cancel && should_cancel()) throw new UserForcedPrematureExit();
const chat = gemini_model.startChat(
{
history: gemini_chat_context.history,
@ -624,7 +627,7 @@ export async function call_google_gemini(prompt: string, model: LLM, n: number =
return [query, responses];
}
export async function call_dalai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
export async function call_dalai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict, should_cancel?: (() => boolean)): Promise<[Dict, Dict]> {
if (APP_IS_RUNNING_LOCALLY()) {
// Try to call Dalai server, through Flask:
let {query, response, error} = await call_flask_backend('callDalai', {
@ -641,7 +644,7 @@ export async function call_dalai(prompt: string, model: LLM, n: number = 1, temp
}
export async function call_huggingface(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
export async function call_huggingface(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict, should_cancel?: (() => boolean)): Promise<[Dict, Dict]> {
// Whether we should notice a given param in 'params'
const param_exists = (p: any) => (p !== undefined && !((typeof p === 'number' && p < 0) || (typeof p === 'string' && p.trim().length === 0)));
const set_param_if_exists = (name: string, query: Dict) => {
@ -710,7 +713,11 @@ export async function call_huggingface(prompt: string, model: LLM, n: number = 1
let continued_response: Dict = { generated_text: "" };
let curr_cont = 0;
let curr_text = prompt;
while (curr_cont <= num_continuations) {
// Abort if user canceled the query operation
if (should_cancel && should_cancel()) throw new UserForcedPrematureExit();
const inputs = (model_type === 'chat')
? ({ text: curr_text,
past_user_inputs: hf_chat_hist.past_user_inputs,
@ -749,7 +756,7 @@ export async function call_huggingface(prompt: string, model: LLM, n: number = 1
}
export async function call_alephalpha(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
export async function call_alephalpha(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict, should_cancel?: (() => boolean)): Promise<[Dict, Dict]> {
if (!ALEPH_ALPHA_API_KEY)
throw Error("Could not find an API key for Aleph Alpha models. Double-check that your API key is set in Settings or in your local environment.");
@ -784,7 +791,7 @@ export async function call_alephalpha(prompt: string, model: LLM, n: number = 1,
return [query, responses];
}
export async function call_ollama_provider(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
export async function call_ollama_provider(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict, should_cancel?: (() => boolean)): Promise<[Dict, Dict]> {
let url: string = appendEndSlashIfMissing(params?.ollama_url);
const ollama_model: string = params?.ollamaModel.toString();
const model_type: string = params?.model_type ?? "text";
@ -820,10 +827,15 @@ export async function call_ollama_provider(prompt: string, model: LLM, n: number
// Call Ollama API
let resps : Response[] = [];
for (let i = 0; i < n; i++) {
// Abort if the user canceled
if (should_cancel && should_cancel()) throw new UserForcedPrematureExit();
// Query Ollama and collect the response
const response = await fetch(url, {
method: "POST",
body: JSON.stringify(query),
});
resps.push(response);
}
@ -842,7 +854,7 @@ export async function call_ollama_provider(prompt: string, model: LLM, n: number
return [query, responses];
}
async function call_custom_provider(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
async function call_custom_provider(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict, should_cancel?: (() => boolean)): Promise<[Dict, Dict]> {
if (!APP_IS_RUNNING_LOCALLY())
throw new Error("The ChainForge app does not appear to be running locally. You can only call custom model providers if you are running ChainForge on your local machine, from a Flask app.")
@ -859,6 +871,10 @@ async function call_custom_provider(prompt: string, model: LLM, n: number = 1, t
// Call the custom provider n times
while (responses.length < n) {
// Abort if the user canceled
if (should_cancel && should_cancel()) throw new UserForcedPrematureExit();
// Collect response from the custom provider
let {response, error} = await call_flask_backend('callCustomProvider',
{ 'name': provider_name,
'params': {
@ -878,7 +894,7 @@ async function call_custom_provider(prompt: string, model: LLM, n: number = 1, t
/**
* Switcher that routes the request to the appropriate API call function. If call doesn't exist, throws error.
*/
export async function call_llm(llm: LLM, prompt: string, n: number, temperature: number, params?: Dict): Promise<[Dict, Dict]> {
export async function call_llm(llm: LLM, prompt: string, n: number, temperature: number, params?: Dict, should_cancel?: (() => boolean)): Promise<[Dict, Dict]> {
// Get the correct API call for the given LLM:
let call_api: LLMAPICall | undefined;
let llm_provider: LLMProvider = getProvider(llm);
@ -905,7 +921,7 @@ export async function call_llm(llm: LLM, prompt: string, n: number, temperature:
else if (llm_provider === LLMProvider.Custom)
call_api = call_custom_provider;
return call_api(prompt, llm, n, temperature, params);
return call_api(prompt, llm, n, temperature, params, should_cancel);
}

View File

@ -15,7 +15,7 @@ async function _route_to_js_backend(route, params) {
case 'generatePrompts':
return generatePrompts(params.prompt, clone(params.vars));
case 'queryllm':
return queryLLM(params.id, clone(params.llm), params.n, params.prompt, clone(params.vars), params.chat_histories, params.api_keys, params.no_cache, params.progress_listener, params.cont_only_w_prior_llms);
return queryLLM(params.id, clone(params.llm), params.n, params.prompt, clone(params.vars), params.chat_histories, params.api_keys, params.no_cache, params.progress_listener, params.cont_only_w_prior_llms, params.cancel_id);
case 'executejs':
return executejs(params.id, params.code, params.responses, params.scope, params.process_type);
case 'executepy':

View File

@ -70,8 +70,8 @@
display: inline-block;
position: relative;
margin-left: 4px;
width: 18px;
height: 18px;
width: 16px;
height: 16px;
}
.lds-ring div {
box-sizing: border-box;
@ -748,6 +748,7 @@
.AmitSahoo45-button-3 {
position: relative;
padding: 2px 10px;
height: 20px;
margin-top: -5px;
margin-right: 3px;
border-radius: 5px;

View File

@ -6,7 +6,7 @@ def readme():
setup(
name='chainforge',
version='0.2.8.9',
version='0.2.9.0',
packages=find_packages(),
author="Ian Arawjo",
description="A Visual Programming Environment for Prompt Engineering",