diff --git a/chainforge/react-server/src/PromptNode.js b/chainforge/react-server/src/PromptNode.tsx similarity index 92% rename from chainforge/react-server/src/PromptNode.js rename to chainforge/react-server/src/PromptNode.tsx index fc74324..b72587f 100644 --- a/chainforge/react-server/src/PromptNode.js +++ b/chainforge/react-server/src/PromptNode.tsx @@ -5,7 +5,7 @@ import React, { useCallback, useMemo, } from "react"; -import { Handle } from "reactflow"; +import { Handle, Position } from "reactflow"; import { v4 as uuid } from "uuid"; import { Switch, @@ -38,12 +38,16 @@ import { getLLMsInPulledInputData, extractSettingsVars, truncStr, + genDebounceFunc, } from "./backend/utils"; import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer"; import CancelTracker from "./backend/canceler"; import { UserForcedPrematureExit } from "./backend/errors"; +import { ChatHistoryInfo, Dict, LLMSpec, QueryProgress, StandardizedLLMResponse, TemplateVarInfo } from "./backend/typing"; +import { AlertModalHandles } from "./AlertModal"; +import { Status } from "./StatusIndicatorComponent"; -const getUniqueLLMMetavarKey = (responses) => { +const getUniqueLLMMetavarKey = (responses: StandardizedLLMResponse[]) => { const metakeys = new Set( responses.map((resp_obj) => Object.keys(resp_obj.metavars)).flat(), ); @@ -51,7 +55,7 @@ const getUniqueLLMMetavarKey = (responses) => { while (metakeys.has(`LLM_${i}`)) i += 1; return `LLM_${i}`; }; -const bucketChatHistoryInfosByLLM = (chat_hist_infos) => { +const bucketChatHistoryInfosByLLM = (chat_hist_infos: ChatHistoryInfo[]) => { const chats_by_llm = {}; chat_hist_infos.forEach((chat_hist_info) => { if (chat_hist_info.llm in chats_by_llm) @@ -62,16 +66,16 @@ const bucketChatHistoryInfosByLLM = (chat_hist_infos) => { }; class PromptInfo { - prompt; // string - settings; // dict for any settings vars + prompt: string; + settings: Dict; - constructor(prompt, settings) { + constructor(prompt: string, settings: Dict) { this.prompt = prompt; this.settings = settings; } } -const displayPromptInfos = (promptInfos, wideFormat) => +const displayPromptInfos = (promptInfos: PromptInfo[], wideFormat: boolean) => promptInfos.map((info, idx) => (
{info.prompt}
@@ -81,7 +85,7 @@ const displayPromptInfos = (promptInfos, wideFormat) =>
{key} =  - {truncStr(val, wideFormat ? 512 : 72)} + {truncStr(val.toString(), wideFormat ? 512 : 72)}
); @@ -169,22 +173,22 @@ const PromptNode = ({ data, id, type: node_type }) => { // API Keys (set by user in popup GlobalSettingsModal) const apiKeys = useStore((state) => state.apiKeys); - const [jsonResponses, setJSONResponses] = useState(null); - const [templateVars, setTemplateVars] = useState(data.vars || []); - const [promptText, setPromptText] = useState(data.prompt || ""); - const [promptTextOnLastRun, setPromptTextOnLastRun] = useState(null); - const [status, setStatus] = useState("none"); - const [numGenerations, setNumGenerations] = useState(data.n || 1); - const [numGenerationsLastRun, setNumGenerationsLastRun] = useState( - data.n || 1, + const [jsonResponses, setJSONResponses] = useState(null); + const [templateVars, setTemplateVars] = useState(data.vars ?? []); + const [promptText, setPromptText] = useState(data.prompt ?? ""); + const [promptTextOnLastRun, setPromptTextOnLastRun] = useState(null); + const [status, setStatus] = useState(Status.NONE); + const [numGenerations, setNumGenerations] = useState(data.n ?? 1); + const [numGenerationsLastRun, setNumGenerationsLastRun] = useState( + data.n ?? 1, ); // The LLM items container const llmListContainer = useRef(null); - const [llmItemsCurrState, setLLMItemsCurrState] = useState([]); + const [llmItemsCurrState, setLLMItemsCurrState] = useState([]); // For displaying error messages to user - const alertModal = useRef(null); + const alertModal = useRef(null); // For a way to inspect responses without having to attach a dedicated node const inspectModal = useRef(null); @@ -194,10 +198,10 @@ const PromptNode = ({ data, id, type: node_type }) => { const [showDrawer, setShowDrawer] = useState(false); // For continuing with prior LLMs toggle - const [contWithPriorLLMs, setContWithPriorLLMs] = useState( + const [contWithPriorLLMs, setContWithPriorLLMs] = useState( data.contChat !== undefined ? data.contChat : node_type === "chat", ); - const [showContToggle, setShowContToggle] = useState(node_type === "chat"); + const [showContToggle, setShowContToggle] = useState(node_type === "chat"); const [contToggleDisabled, setContChatToggleDisabled] = useState(false); // For an info pop-up that shows all the prompts that will be sent off @@ -206,9 +210,9 @@ const PromptNode = ({ data, id, type: node_type }) => { useDisclosure(false); // Progress when querying responses - const [progress, setProgress] = useState(undefined); + const [progress, setProgress] = useState(undefined); const [progressAnimated, setProgressAnimated] = useState(true); - const [runTooltip, setRunTooltip] = useState(null); + const [runTooltip, setRunTooltip] = useState(undefined); // Cancelation of pending queries const [cancelId, setCancelId] = useState(Date.now()); @@ -216,19 +220,10 @@ const PromptNode = ({ data, id, type: node_type }) => { // Debounce helpers const debounceTimeoutRef = useRef(null); - const debounce = (func, delay) => { - return (...args) => { - if (debounceTimeoutRef.current) { - clearTimeout(debounceTimeoutRef.current); - } - debounceTimeoutRef.current = setTimeout(() => { - func(...args); - }, delay); - }; - }; + const debounce = genDebounceFunc(debounceTimeoutRef); const triggerAlert = useCallback( - (msg) => { + (msg: string) => { setProgress(undefined); llmListContainer?.current?.resetLLMItemsProgress(); alertModal?.current?.trigger(msg); @@ -245,8 +240,8 @@ const PromptNode = ({ data, id, type: node_type }) => { // Signal that prompt node state is dirty; user should re-run: const signalDirty = useCallback(() => { - if (promptTextOnLastRun !== null && status === "ready") - setStatus("warning"); + if (promptTextOnLastRun !== null && status === Status.READY) + setStatus(Status.WARNING); }, [promptTextOnLastRun, status]); const onLLMListItemsChange = useCallback( @@ -309,7 +304,7 @@ const PromptNode = ({ data, id, type: node_type }) => { }, [templateVars, id, pullInputData, updateShowContToggle]); const refreshTemplateHooks = useCallback( - (text) => { + (text: string) => { // Update template var fields + handles const found_template_vars = new Set(extractBracketedSubstrings(text)); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this} @@ -357,7 +352,7 @@ const PromptNode = ({ data, id, type: node_type }) => { if (json.responses && json.responses.length > 0) { // Store responses and set status to green checkmark setJSONResponses(json.responses); - setStatus("ready"); + setStatus(Status.READY); } }); }, []); @@ -371,7 +366,7 @@ const PromptNode = ({ data, id, type: node_type }) => { useEffect(() => { if (refresh === true) { setDataPropsForNode(id, { refresh: false }); - setStatus("warning"); + setStatus(Status.WARNING); handleOnConnect(); } else if (refreshLLMList === true) { llmListContainer?.current?.refreshLLMProviderList(); @@ -387,17 +382,16 @@ const PromptNode = ({ data, id, type: node_type }) => { // For storing the unique LLMs in past_chats: const llm_names = new Set(); - const past_chat_llms = []; + const past_chat_llms: (LLMSpec | string)[] = []; // We need to calculate the conversation history from the pulled responses. // Note that TemplateVarInfo might have a 'chat_history' component, but this does not // include the most recent prompt and response --for that, we need to use the 'prompt' and 'text' items. // We need to create a revised chat history that concatenates the past history with the last AI + human turns: - const past_chats = pulled_data.__past_chats.map((info) => { + const past_chats = pulled_data.__past_chats.map((info: TemplateVarInfo) => { // Add to unique LLMs list, if necessary - const llm_name = info?.llm?.name; - if (llm_name !== undefined && !llm_names.has(llm_name)) { - llm_names.add(llm_name); + if (typeof info?.llm !== "string" && info?.llm?.name !== undefined && !llm_names.has(info.llm.name)) { + llm_names.add(info.llm.name); past_chat_llms.push(info.llm); } @@ -414,6 +408,7 @@ const PromptNode = ({ data, id, type: node_type }) => { // Append any present system message retroactively as the first message in the chat history: if ( + typeof info?.llm !== "string" && info?.llm?.settings?.system_msg !== undefined && updated_chat_hist[0].role !== "system" ) @@ -426,7 +421,7 @@ const PromptNode = ({ data, id, type: node_type }) => { messages: updated_chat_hist, fill_history: info.fill_history, metavars: info.metavars, - llm: llm_name, + llm: info?.llm?.name, batch_id: uuid(), }; }); @@ -480,7 +475,7 @@ const PromptNode = ({ data, id, type: node_type }) => { }).then((prompts) => { setPromptPreviews( prompts.map( - (p) => + (p: string) => new PromptInfo(p.toString(), extractSettingsVars(p.fill_history)), ), ); @@ -565,7 +560,7 @@ const PromptNode = ({ data, id, type: node_type }) => { setResponsesWillChange(true); // Tally how many queries per LLM: - const queries_per_llm = {}; + const queries_per_llm: Dict = {}; Object.keys(counts).forEach((llm_key) => { queries_per_llm[llm_key] = Object.keys(counts[llm_key]).reduce( (acc, prompt) => acc + counts[llm_key][prompt], @@ -711,12 +706,12 @@ Soft failing by replacing undefined with empty strings.`, } // Set status indicator - setStatus("loading"); + setStatus(Status.LOADING); setContChatToggleDisabled(true); setJSONResponses([]); setProgressAnimated(true); - const rejected = (err) => { + const rejected = (err: Error | string) => { if ( err instanceof UserForcedPrematureExit || CancelTracker.has(cancelId) @@ -724,9 +719,9 @@ Soft failing by replacing undefined with empty strings.`, // Handle a premature cancelation console.log("Canceled."); } else { - setStatus("error"); + setStatus(Status.ERROR); setContChatToggleDisabled(false); - triggerAlert(err.message || err); + triggerAlert(typeof err === "string" ? err : err?.message); } }; @@ -889,7 +884,7 @@ Soft failing by replacing undefined with empty strings.`, llmListContainer?.current?.ensureLLMItemsErrorProgress(llms_w_errors); // Set error status - setStatus("error"); + setStatus(Status.ERROR); setContChatToggleDisabled(false); // Trigger alert and display one error message per LLM of all collected errors: @@ -934,7 +929,7 @@ Soft failing by replacing undefined with empty strings.`, setNumGenerationsLastRun(numGenerations); // All responses collected! Change status to 'ready': - setStatus("ready"); + setStatus(Status.READY); // Ping any inspect nodes attached to this node to refresh their contents: pingOutputNodes(id); @@ -961,7 +956,7 @@ Soft failing by replacing undefined with empty strings.`, debounce(() => {}, 1)(); // erase any pending debounces // Set error status - setStatus("none"); + setStatus(Status.NONE); setContChatToggleDisabled(false); llmListContainer?.current?.resetLLMItemsProgress(); }, [cancelId, refreshCancelId]); @@ -972,8 +967,8 @@ Soft failing by replacing undefined with empty strings.`, if (!isNaN(n) && n.length > 0 && /^\d+$/.test(n)) { // n is an integer; save it n = parseInt(n); - if (n !== numGenerationsLastRun && status === "ready") - setStatus("warning"); + if (n !== numGenerationsLastRun && status === Status.READY) + setStatus(Status.WARNING); setNumGenerations(n); setDataPropsForNode(id, { n }); } @@ -982,14 +977,14 @@ Soft failing by replacing undefined with empty strings.`, ); const hideStatusIndicator = () => { - if (status !== "none") setStatus("none"); + if (status !== Status.NONE) setStatus(Status.NONE); }; // Dynamically update the textareas and position of the template hooks - const textAreaRef = useRef(null); + const textAreaRef = useRef(null); const [hooksY, setHooksY] = useState(138); const setRef = useCallback( - (elem) => { + (elem: HTMLDivElement | HTMLTextAreaElement) => { // To listen for resize events of the textarea, we need to use a ResizeObserver. // We initialize the ResizeObserver only once, when the 'ref' is first set, and only on the div wrapping textfields. // NOTE: This won't work on older browsers, but there's no alternative solution. @@ -1099,7 +1094,7 @@ Soft failing by replacing undefined with empty strings.`,
@@ -1140,7 +1136,7 @@ Soft failing by replacing undefined with empty strings.`, checked={contWithPriorLLMs} disabled={contToggleDisabled} onChange={(event) => { - setStatus("warning"); + setStatus(Status.WARNING); setContWithPriorLLMs(event.currentTarget.checked); setDataPropsForNode(id, { contChat: event.currentTarget.checked, diff --git a/chainforge/react-server/src/StatusIndicatorComponent.tsx b/chainforge/react-server/src/StatusIndicatorComponent.tsx index 47458b7..f6c9544 100644 --- a/chainforge/react-server/src/StatusIndicatorComponent.tsx +++ b/chainforge/react-server/src/StatusIndicatorComponent.tsx @@ -5,6 +5,7 @@ export enum Status { READY = "ready", ERROR = "error", LOADING = "loading", + NONE = "none", } interface StatusIndicatorProps { status: Status; diff --git a/chainforge/react-server/src/backend/typing.ts b/chainforge/react-server/src/backend/typing.ts index 2d7aff6..3b8bf16 100644 --- a/chainforge/react-server/src/backend/typing.ts +++ b/chainforge/react-server/src/backend/typing.ts @@ -23,21 +23,6 @@ export interface OpenAIFunctionCall { description?: string; } -/** The outputs of prompt nodes, text fields or other data passed internally in the front-end and to the PromptTemplate backend. - * Used to populate prompt templates and carry variables/metavariables along the chain. */ -export interface TemplateVarInfo { - text: string; - fill_history: Dict; - metavars?: Dict; - associate_id?: string; - llm?: string | Dict; -} - -export type PromptVarType = string | TemplateVarInfo; -export type PromptVarsDict = { - [key: string]: PromptVarType[]; -}; - /** OpenAI chat message format */ export interface ChatMessage { role: string; @@ -183,3 +168,24 @@ export type LLMSpec = { settings?: Dict; progress?: Dict; // only used for front-end to display progress collecting responses for this LLM }; + +/** The outputs of prompt nodes, text fields or other data passed internally in the front-end and to the PromptTemplate backend. + * Used to populate prompt templates and carry variables/metavariables along the chain. */ + export interface TemplateVarInfo { + text: string; + fill_history: Dict; + metavars?: Dict; + associate_id?: string; + llm?: string | LLMSpec; + chat_history?: ChatHistory; +} + +export type PromptVarType = string | TemplateVarInfo; +export type PromptVarsDict = { + [key: string]: PromptVarType[]; +}; + +export type QueryProgress = { + success: number; + error: number; +} \ No newline at end of file