mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-15 00:36:29 +00:00
wip prompt node
This commit is contained in:
parent
6139e5aa4a
commit
a576eba6f8
@ -5,7 +5,7 @@ import React, {
|
||||
useCallback,
|
||||
useMemo,
|
||||
} from "react";
|
||||
import { Handle } from "reactflow";
|
||||
import { Handle, Position } from "reactflow";
|
||||
import { v4 as uuid } from "uuid";
|
||||
import {
|
||||
Switch,
|
||||
@ -38,12 +38,16 @@ import {
|
||||
getLLMsInPulledInputData,
|
||||
extractSettingsVars,
|
||||
truncStr,
|
||||
genDebounceFunc,
|
||||
} from "./backend/utils";
|
||||
import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer";
|
||||
import CancelTracker from "./backend/canceler";
|
||||
import { UserForcedPrematureExit } from "./backend/errors";
|
||||
import { ChatHistoryInfo, Dict, LLMSpec, QueryProgress, StandardizedLLMResponse, TemplateVarInfo } from "./backend/typing";
|
||||
import { AlertModalHandles } from "./AlertModal";
|
||||
import { Status } from "./StatusIndicatorComponent";
|
||||
|
||||
const getUniqueLLMMetavarKey = (responses) => {
|
||||
const getUniqueLLMMetavarKey = (responses: StandardizedLLMResponse[]) => {
|
||||
const metakeys = new Set(
|
||||
responses.map((resp_obj) => Object.keys(resp_obj.metavars)).flat(),
|
||||
);
|
||||
@ -51,7 +55,7 @@ const getUniqueLLMMetavarKey = (responses) => {
|
||||
while (metakeys.has(`LLM_${i}`)) i += 1;
|
||||
return `LLM_${i}`;
|
||||
};
|
||||
const bucketChatHistoryInfosByLLM = (chat_hist_infos) => {
|
||||
const bucketChatHistoryInfosByLLM = (chat_hist_infos: ChatHistoryInfo[]) => {
|
||||
const chats_by_llm = {};
|
||||
chat_hist_infos.forEach((chat_hist_info) => {
|
||||
if (chat_hist_info.llm in chats_by_llm)
|
||||
@ -62,16 +66,16 @@ const bucketChatHistoryInfosByLLM = (chat_hist_infos) => {
|
||||
};
|
||||
|
||||
class PromptInfo {
|
||||
prompt; // string
|
||||
settings; // dict for any settings vars
|
||||
prompt: string;
|
||||
settings: Dict;
|
||||
|
||||
constructor(prompt, settings) {
|
||||
constructor(prompt: string, settings: Dict) {
|
||||
this.prompt = prompt;
|
||||
this.settings = settings;
|
||||
}
|
||||
}
|
||||
|
||||
const displayPromptInfos = (promptInfos, wideFormat) =>
|
||||
const displayPromptInfos = (promptInfos: PromptInfo[], wideFormat: boolean) =>
|
||||
promptInfos.map((info, idx) => (
|
||||
<div key={idx}>
|
||||
<div className="prompt-preview">{info.prompt}</div>
|
||||
@ -81,7 +85,7 @@ const displayPromptInfos = (promptInfos, wideFormat) =>
|
||||
<div key={key} className="settings-var-inline response-var-inline">
|
||||
<span className="response-var-name">{key} = </span>
|
||||
<span className="response-var-value wrap-line">
|
||||
{truncStr(val, wideFormat ? 512 : 72)}
|
||||
{truncStr(val.toString(), wideFormat ? 512 : 72)}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
@ -169,22 +173,22 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
// API Keys (set by user in popup GlobalSettingsModal)
|
||||
const apiKeys = useStore((state) => state.apiKeys);
|
||||
|
||||
const [jsonResponses, setJSONResponses] = useState(null);
|
||||
const [templateVars, setTemplateVars] = useState(data.vars || []);
|
||||
const [promptText, setPromptText] = useState(data.prompt || "");
|
||||
const [promptTextOnLastRun, setPromptTextOnLastRun] = useState(null);
|
||||
const [status, setStatus] = useState("none");
|
||||
const [numGenerations, setNumGenerations] = useState(data.n || 1);
|
||||
const [numGenerationsLastRun, setNumGenerationsLastRun] = useState(
|
||||
data.n || 1,
|
||||
const [jsonResponses, setJSONResponses] = useState<StandardizedLLMResponse[] | null>(null);
|
||||
const [templateVars, setTemplateVars] = useState<string[]>(data.vars ?? []);
|
||||
const [promptText, setPromptText] = useState<string>(data.prompt ?? "");
|
||||
const [promptTextOnLastRun, setPromptTextOnLastRun] = useState<string | null>(null);
|
||||
const [status, setStatus] = useState(Status.NONE);
|
||||
const [numGenerations, setNumGenerations] = useState<number>(data.n ?? 1);
|
||||
const [numGenerationsLastRun, setNumGenerationsLastRun] = useState<number>(
|
||||
data.n ?? 1,
|
||||
);
|
||||
|
||||
// The LLM items container
|
||||
const llmListContainer = useRef(null);
|
||||
const [llmItemsCurrState, setLLMItemsCurrState] = useState([]);
|
||||
const [llmItemsCurrState, setLLMItemsCurrState] = useState<LLMSpec[]>([]);
|
||||
|
||||
// For displaying error messages to user
|
||||
const alertModal = useRef(null);
|
||||
const alertModal = useRef<AlertModalHandles | null>(null);
|
||||
|
||||
// For a way to inspect responses without having to attach a dedicated node
|
||||
const inspectModal = useRef(null);
|
||||
@ -194,10 +198,10 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
const [showDrawer, setShowDrawer] = useState(false);
|
||||
|
||||
// For continuing with prior LLMs toggle
|
||||
const [contWithPriorLLMs, setContWithPriorLLMs] = useState(
|
||||
const [contWithPriorLLMs, setContWithPriorLLMs] = useState<boolean>(
|
||||
data.contChat !== undefined ? data.contChat : node_type === "chat",
|
||||
);
|
||||
const [showContToggle, setShowContToggle] = useState(node_type === "chat");
|
||||
const [showContToggle, setShowContToggle] = useState<boolean>(node_type === "chat");
|
||||
const [contToggleDisabled, setContChatToggleDisabled] = useState(false);
|
||||
|
||||
// For an info pop-up that shows all the prompts that will be sent off
|
||||
@ -206,9 +210,9 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
useDisclosure(false);
|
||||
|
||||
// Progress when querying responses
|
||||
const [progress, setProgress] = useState(undefined);
|
||||
const [progress, setProgress] = useState<QueryProgress | undefined>(undefined);
|
||||
const [progressAnimated, setProgressAnimated] = useState(true);
|
||||
const [runTooltip, setRunTooltip] = useState(null);
|
||||
const [runTooltip, setRunTooltip] = useState<string | null>(undefined);
|
||||
|
||||
// Cancelation of pending queries
|
||||
const [cancelId, setCancelId] = useState(Date.now());
|
||||
@ -216,19 +220,10 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
|
||||
// Debounce helpers
|
||||
const debounceTimeoutRef = useRef(null);
|
||||
const debounce = (func, delay) => {
|
||||
return (...args) => {
|
||||
if (debounceTimeoutRef.current) {
|
||||
clearTimeout(debounceTimeoutRef.current);
|
||||
}
|
||||
debounceTimeoutRef.current = setTimeout(() => {
|
||||
func(...args);
|
||||
}, delay);
|
||||
};
|
||||
};
|
||||
const debounce = genDebounceFunc(debounceTimeoutRef);
|
||||
|
||||
const triggerAlert = useCallback(
|
||||
(msg) => {
|
||||
(msg: string) => {
|
||||
setProgress(undefined);
|
||||
llmListContainer?.current?.resetLLMItemsProgress();
|
||||
alertModal?.current?.trigger(msg);
|
||||
@ -245,8 +240,8 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
|
||||
// Signal that prompt node state is dirty; user should re-run:
|
||||
const signalDirty = useCallback(() => {
|
||||
if (promptTextOnLastRun !== null && status === "ready")
|
||||
setStatus("warning");
|
||||
if (promptTextOnLastRun !== null && status === Status.READY)
|
||||
setStatus(Status.WARNING);
|
||||
}, [promptTextOnLastRun, status]);
|
||||
|
||||
const onLLMListItemsChange = useCallback(
|
||||
@ -309,7 +304,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
}, [templateVars, id, pullInputData, updateShowContToggle]);
|
||||
|
||||
const refreshTemplateHooks = useCallback(
|
||||
(text) => {
|
||||
(text: string) => {
|
||||
// Update template var fields + handles
|
||||
const found_template_vars = new Set(extractBracketedSubstrings(text)); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this}
|
||||
|
||||
@ -357,7 +352,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
if (json.responses && json.responses.length > 0) {
|
||||
// Store responses and set status to green checkmark
|
||||
setJSONResponses(json.responses);
|
||||
setStatus("ready");
|
||||
setStatus(Status.READY);
|
||||
}
|
||||
});
|
||||
}, []);
|
||||
@ -371,7 +366,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
useEffect(() => {
|
||||
if (refresh === true) {
|
||||
setDataPropsForNode(id, { refresh: false });
|
||||
setStatus("warning");
|
||||
setStatus(Status.WARNING);
|
||||
handleOnConnect();
|
||||
} else if (refreshLLMList === true) {
|
||||
llmListContainer?.current?.refreshLLMProviderList();
|
||||
@ -387,17 +382,16 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
|
||||
// For storing the unique LLMs in past_chats:
|
||||
const llm_names = new Set();
|
||||
const past_chat_llms = [];
|
||||
const past_chat_llms: (LLMSpec | string)[] = [];
|
||||
|
||||
// We need to calculate the conversation history from the pulled responses.
|
||||
// Note that TemplateVarInfo might have a 'chat_history' component, but this does not
|
||||
// include the most recent prompt and response --for that, we need to use the 'prompt' and 'text' items.
|
||||
// We need to create a revised chat history that concatenates the past history with the last AI + human turns:
|
||||
const past_chats = pulled_data.__past_chats.map((info) => {
|
||||
const past_chats = pulled_data.__past_chats.map((info: TemplateVarInfo) => {
|
||||
// Add to unique LLMs list, if necessary
|
||||
const llm_name = info?.llm?.name;
|
||||
if (llm_name !== undefined && !llm_names.has(llm_name)) {
|
||||
llm_names.add(llm_name);
|
||||
if (typeof info?.llm !== "string" && info?.llm?.name !== undefined && !llm_names.has(info.llm.name)) {
|
||||
llm_names.add(info.llm.name);
|
||||
past_chat_llms.push(info.llm);
|
||||
}
|
||||
|
||||
@ -414,6 +408,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
|
||||
// Append any present system message retroactively as the first message in the chat history:
|
||||
if (
|
||||
typeof info?.llm !== "string" &&
|
||||
info?.llm?.settings?.system_msg !== undefined &&
|
||||
updated_chat_hist[0].role !== "system"
|
||||
)
|
||||
@ -426,7 +421,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
messages: updated_chat_hist,
|
||||
fill_history: info.fill_history,
|
||||
metavars: info.metavars,
|
||||
llm: llm_name,
|
||||
llm: info?.llm?.name,
|
||||
batch_id: uuid(),
|
||||
};
|
||||
});
|
||||
@ -480,7 +475,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
}).then((prompts) => {
|
||||
setPromptPreviews(
|
||||
prompts.map(
|
||||
(p) =>
|
||||
(p: string) =>
|
||||
new PromptInfo(p.toString(), extractSettingsVars(p.fill_history)),
|
||||
),
|
||||
);
|
||||
@ -565,7 +560,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
setResponsesWillChange(true);
|
||||
|
||||
// Tally how many queries per LLM:
|
||||
const queries_per_llm = {};
|
||||
const queries_per_llm: Dict<number> = {};
|
||||
Object.keys(counts).forEach((llm_key) => {
|
||||
queries_per_llm[llm_key] = Object.keys(counts[llm_key]).reduce(
|
||||
(acc, prompt) => acc + counts[llm_key][prompt],
|
||||
@ -711,12 +706,12 @@ Soft failing by replacing undefined with empty strings.`,
|
||||
}
|
||||
|
||||
// Set status indicator
|
||||
setStatus("loading");
|
||||
setStatus(Status.LOADING);
|
||||
setContChatToggleDisabled(true);
|
||||
setJSONResponses([]);
|
||||
setProgressAnimated(true);
|
||||
|
||||
const rejected = (err) => {
|
||||
const rejected = (err: Error | string) => {
|
||||
if (
|
||||
err instanceof UserForcedPrematureExit ||
|
||||
CancelTracker.has(cancelId)
|
||||
@ -724,9 +719,9 @@ Soft failing by replacing undefined with empty strings.`,
|
||||
// Handle a premature cancelation
|
||||
console.log("Canceled.");
|
||||
} else {
|
||||
setStatus("error");
|
||||
setStatus(Status.ERROR);
|
||||
setContChatToggleDisabled(false);
|
||||
triggerAlert(err.message || err);
|
||||
triggerAlert(typeof err === "string" ? err : err?.message);
|
||||
}
|
||||
};
|
||||
|
||||
@ -889,7 +884,7 @@ Soft failing by replacing undefined with empty strings.`,
|
||||
llmListContainer?.current?.ensureLLMItemsErrorProgress(llms_w_errors);
|
||||
|
||||
// Set error status
|
||||
setStatus("error");
|
||||
setStatus(Status.ERROR);
|
||||
setContChatToggleDisabled(false);
|
||||
|
||||
// Trigger alert and display one error message per LLM of all collected errors:
|
||||
@ -934,7 +929,7 @@ Soft failing by replacing undefined with empty strings.`,
|
||||
setNumGenerationsLastRun(numGenerations);
|
||||
|
||||
// All responses collected! Change status to 'ready':
|
||||
setStatus("ready");
|
||||
setStatus(Status.READY);
|
||||
|
||||
// Ping any inspect nodes attached to this node to refresh their contents:
|
||||
pingOutputNodes(id);
|
||||
@ -961,7 +956,7 @@ Soft failing by replacing undefined with empty strings.`,
|
||||
debounce(() => {}, 1)(); // erase any pending debounces
|
||||
|
||||
// Set error status
|
||||
setStatus("none");
|
||||
setStatus(Status.NONE);
|
||||
setContChatToggleDisabled(false);
|
||||
llmListContainer?.current?.resetLLMItemsProgress();
|
||||
}, [cancelId, refreshCancelId]);
|
||||
@ -972,8 +967,8 @@ Soft failing by replacing undefined with empty strings.`,
|
||||
if (!isNaN(n) && n.length > 0 && /^\d+$/.test(n)) {
|
||||
// n is an integer; save it
|
||||
n = parseInt(n);
|
||||
if (n !== numGenerationsLastRun && status === "ready")
|
||||
setStatus("warning");
|
||||
if (n !== numGenerationsLastRun && status === Status.READY)
|
||||
setStatus(Status.WARNING);
|
||||
setNumGenerations(n);
|
||||
setDataPropsForNode(id, { n });
|
||||
}
|
||||
@ -982,14 +977,14 @@ Soft failing by replacing undefined with empty strings.`,
|
||||
);
|
||||
|
||||
const hideStatusIndicator = () => {
|
||||
if (status !== "none") setStatus("none");
|
||||
if (status !== Status.NONE) setStatus(Status.NONE);
|
||||
};
|
||||
|
||||
// Dynamically update the textareas and position of the template hooks
|
||||
const textAreaRef = useRef(null);
|
||||
const textAreaRef = useRef<HTMLTextAreaElement | HTMLDivElement | null>(null);
|
||||
const [hooksY, setHooksY] = useState(138);
|
||||
const setRef = useCallback(
|
||||
(elem) => {
|
||||
(elem: HTMLDivElement | HTMLTextAreaElement) => {
|
||||
// To listen for resize events of the textarea, we need to use a ResizeObserver.
|
||||
// We initialize the ResizeObserver only once, when the 'ref' is first set, and only on the div wrapping textfields.
|
||||
// NOTE: This won't work on older browsers, but there's no alternative solution.
|
||||
@ -1099,7 +1094,7 @@ Soft failing by replacing undefined with empty strings.`,
|
||||
|
||||
<Handle
|
||||
type="source"
|
||||
position="right"
|
||||
position={Position.Right}
|
||||
id="prompt"
|
||||
className="grouped-handle"
|
||||
style={{ top: "50%" }}
|
||||
@ -1108,6 +1103,7 @@ Soft failing by replacing undefined with empty strings.`,
|
||||
vars={templateVars}
|
||||
nodeId={id}
|
||||
startY={hooksY}
|
||||
position={Position.Left}
|
||||
ignoreHandles={["__past_chats"]}
|
||||
/>
|
||||
<hr />
|
||||
@ -1140,7 +1136,7 @@ Soft failing by replacing undefined with empty strings.`,
|
||||
checked={contWithPriorLLMs}
|
||||
disabled={contToggleDisabled}
|
||||
onChange={(event) => {
|
||||
setStatus("warning");
|
||||
setStatus(Status.WARNING);
|
||||
setContWithPriorLLMs(event.currentTarget.checked);
|
||||
setDataPropsForNode(id, {
|
||||
contChat: event.currentTarget.checked,
|
@ -5,6 +5,7 @@ export enum Status {
|
||||
READY = "ready",
|
||||
ERROR = "error",
|
||||
LOADING = "loading",
|
||||
NONE = "none",
|
||||
}
|
||||
interface StatusIndicatorProps {
|
||||
status: Status;
|
||||
|
@ -23,21 +23,6 @@ export interface OpenAIFunctionCall {
|
||||
description?: string;
|
||||
}
|
||||
|
||||
/** The outputs of prompt nodes, text fields or other data passed internally in the front-end and to the PromptTemplate backend.
|
||||
* Used to populate prompt templates and carry variables/metavariables along the chain. */
|
||||
export interface TemplateVarInfo {
|
||||
text: string;
|
||||
fill_history: Dict<string>;
|
||||
metavars?: Dict<string>;
|
||||
associate_id?: string;
|
||||
llm?: string | Dict;
|
||||
}
|
||||
|
||||
export type PromptVarType = string | TemplateVarInfo;
|
||||
export type PromptVarsDict = {
|
||||
[key: string]: PromptVarType[];
|
||||
};
|
||||
|
||||
/** OpenAI chat message format */
|
||||
export interface ChatMessage {
|
||||
role: string;
|
||||
@ -183,3 +168,24 @@ export type LLMSpec = {
|
||||
settings?: Dict;
|
||||
progress?: Dict<number>; // only used for front-end to display progress collecting responses for this LLM
|
||||
};
|
||||
|
||||
/** The outputs of prompt nodes, text fields or other data passed internally in the front-end and to the PromptTemplate backend.
|
||||
* Used to populate prompt templates and carry variables/metavariables along the chain. */
|
||||
export interface TemplateVarInfo {
|
||||
text: string;
|
||||
fill_history: Dict<string>;
|
||||
metavars?: Dict<string>;
|
||||
associate_id?: string;
|
||||
llm?: string | LLMSpec;
|
||||
chat_history?: ChatHistory;
|
||||
}
|
||||
|
||||
export type PromptVarType = string | TemplateVarInfo;
|
||||
export type PromptVarsDict = {
|
||||
[key: string]: PromptVarType[];
|
||||
};
|
||||
|
||||
export type QueryProgress = {
|
||||
success: number;
|
||||
error: number;
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user