wip prompt node

This commit is contained in:
Ian Arawjo 2024-03-02 23:15:42 -05:00
parent 6139e5aa4a
commit a576eba6f8
3 changed files with 78 additions and 75 deletions

View File

@ -5,7 +5,7 @@ import React, {
useCallback,
useMemo,
} from "react";
import { Handle } from "reactflow";
import { Handle, Position } from "reactflow";
import { v4 as uuid } from "uuid";
import {
Switch,
@ -38,12 +38,16 @@ import {
getLLMsInPulledInputData,
extractSettingsVars,
truncStr,
genDebounceFunc,
} from "./backend/utils";
import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer";
import CancelTracker from "./backend/canceler";
import { UserForcedPrematureExit } from "./backend/errors";
import { ChatHistoryInfo, Dict, LLMSpec, QueryProgress, StandardizedLLMResponse, TemplateVarInfo } from "./backend/typing";
import { AlertModalHandles } from "./AlertModal";
import { Status } from "./StatusIndicatorComponent";
const getUniqueLLMMetavarKey = (responses) => {
const getUniqueLLMMetavarKey = (responses: StandardizedLLMResponse[]) => {
const metakeys = new Set(
responses.map((resp_obj) => Object.keys(resp_obj.metavars)).flat(),
);
@ -51,7 +55,7 @@ const getUniqueLLMMetavarKey = (responses) => {
while (metakeys.has(`LLM_${i}`)) i += 1;
return `LLM_${i}`;
};
const bucketChatHistoryInfosByLLM = (chat_hist_infos) => {
const bucketChatHistoryInfosByLLM = (chat_hist_infos: ChatHistoryInfo[]) => {
const chats_by_llm = {};
chat_hist_infos.forEach((chat_hist_info) => {
if (chat_hist_info.llm in chats_by_llm)
@ -62,16 +66,16 @@ const bucketChatHistoryInfosByLLM = (chat_hist_infos) => {
};
class PromptInfo {
prompt; // string
settings; // dict for any settings vars
prompt: string;
settings: Dict;
constructor(prompt, settings) {
constructor(prompt: string, settings: Dict) {
this.prompt = prompt;
this.settings = settings;
}
}
const displayPromptInfos = (promptInfos, wideFormat) =>
const displayPromptInfos = (promptInfos: PromptInfo[], wideFormat: boolean) =>
promptInfos.map((info, idx) => (
<div key={idx}>
<div className="prompt-preview">{info.prompt}</div>
@ -81,7 +85,7 @@ const displayPromptInfos = (promptInfos, wideFormat) =>
<div key={key} className="settings-var-inline response-var-inline">
<span className="response-var-name">{key}&nbsp;=&nbsp;</span>
<span className="response-var-value wrap-line">
{truncStr(val, wideFormat ? 512 : 72)}
{truncStr(val.toString(), wideFormat ? 512 : 72)}
</span>
</div>
);
@ -169,22 +173,22 @@ const PromptNode = ({ data, id, type: node_type }) => {
// API Keys (set by user in popup GlobalSettingsModal)
const apiKeys = useStore((state) => state.apiKeys);
const [jsonResponses, setJSONResponses] = useState(null);
const [templateVars, setTemplateVars] = useState(data.vars || []);
const [promptText, setPromptText] = useState(data.prompt || "");
const [promptTextOnLastRun, setPromptTextOnLastRun] = useState(null);
const [status, setStatus] = useState("none");
const [numGenerations, setNumGenerations] = useState(data.n || 1);
const [numGenerationsLastRun, setNumGenerationsLastRun] = useState(
data.n || 1,
const [jsonResponses, setJSONResponses] = useState<StandardizedLLMResponse[] | null>(null);
const [templateVars, setTemplateVars] = useState<string[]>(data.vars ?? []);
const [promptText, setPromptText] = useState<string>(data.prompt ?? "");
const [promptTextOnLastRun, setPromptTextOnLastRun] = useState<string | null>(null);
const [status, setStatus] = useState(Status.NONE);
const [numGenerations, setNumGenerations] = useState<number>(data.n ?? 1);
const [numGenerationsLastRun, setNumGenerationsLastRun] = useState<number>(
data.n ?? 1,
);
// The LLM items container
const llmListContainer = useRef(null);
const [llmItemsCurrState, setLLMItemsCurrState] = useState([]);
const [llmItemsCurrState, setLLMItemsCurrState] = useState<LLMSpec[]>([]);
// For displaying error messages to user
const alertModal = useRef(null);
const alertModal = useRef<AlertModalHandles | null>(null);
// For a way to inspect responses without having to attach a dedicated node
const inspectModal = useRef(null);
@ -194,10 +198,10 @@ const PromptNode = ({ data, id, type: node_type }) => {
const [showDrawer, setShowDrawer] = useState(false);
// For continuing with prior LLMs toggle
const [contWithPriorLLMs, setContWithPriorLLMs] = useState(
const [contWithPriorLLMs, setContWithPriorLLMs] = useState<boolean>(
data.contChat !== undefined ? data.contChat : node_type === "chat",
);
const [showContToggle, setShowContToggle] = useState(node_type === "chat");
const [showContToggle, setShowContToggle] = useState<boolean>(node_type === "chat");
const [contToggleDisabled, setContChatToggleDisabled] = useState(false);
// For an info pop-up that shows all the prompts that will be sent off
@ -206,9 +210,9 @@ const PromptNode = ({ data, id, type: node_type }) => {
useDisclosure(false);
// Progress when querying responses
const [progress, setProgress] = useState(undefined);
const [progress, setProgress] = useState<QueryProgress | undefined>(undefined);
const [progressAnimated, setProgressAnimated] = useState(true);
const [runTooltip, setRunTooltip] = useState(null);
const [runTooltip, setRunTooltip] = useState<string | null>(undefined);
// Cancelation of pending queries
const [cancelId, setCancelId] = useState(Date.now());
@ -216,19 +220,10 @@ const PromptNode = ({ data, id, type: node_type }) => {
// Debounce helpers
const debounceTimeoutRef = useRef(null);
const debounce = (func, delay) => {
return (...args) => {
if (debounceTimeoutRef.current) {
clearTimeout(debounceTimeoutRef.current);
}
debounceTimeoutRef.current = setTimeout(() => {
func(...args);
}, delay);
};
};
const debounce = genDebounceFunc(debounceTimeoutRef);
const triggerAlert = useCallback(
(msg) => {
(msg: string) => {
setProgress(undefined);
llmListContainer?.current?.resetLLMItemsProgress();
alertModal?.current?.trigger(msg);
@ -245,8 +240,8 @@ const PromptNode = ({ data, id, type: node_type }) => {
// Signal that prompt node state is dirty; user should re-run:
const signalDirty = useCallback(() => {
if (promptTextOnLastRun !== null && status === "ready")
setStatus("warning");
if (promptTextOnLastRun !== null && status === Status.READY)
setStatus(Status.WARNING);
}, [promptTextOnLastRun, status]);
const onLLMListItemsChange = useCallback(
@ -309,7 +304,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
}, [templateVars, id, pullInputData, updateShowContToggle]);
const refreshTemplateHooks = useCallback(
(text) => {
(text: string) => {
// Update template var fields + handles
const found_template_vars = new Set(extractBracketedSubstrings(text)); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this}
@ -357,7 +352,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
if (json.responses && json.responses.length > 0) {
// Store responses and set status to green checkmark
setJSONResponses(json.responses);
setStatus("ready");
setStatus(Status.READY);
}
});
}, []);
@ -371,7 +366,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
useEffect(() => {
if (refresh === true) {
setDataPropsForNode(id, { refresh: false });
setStatus("warning");
setStatus(Status.WARNING);
handleOnConnect();
} else if (refreshLLMList === true) {
llmListContainer?.current?.refreshLLMProviderList();
@ -387,17 +382,16 @@ const PromptNode = ({ data, id, type: node_type }) => {
// For storing the unique LLMs in past_chats:
const llm_names = new Set();
const past_chat_llms = [];
const past_chat_llms: (LLMSpec | string)[] = [];
// We need to calculate the conversation history from the pulled responses.
// Note that TemplateVarInfo might have a 'chat_history' component, but this does not
// include the most recent prompt and response --for that, we need to use the 'prompt' and 'text' items.
// We need to create a revised chat history that concatenates the past history with the last AI + human turns:
const past_chats = pulled_data.__past_chats.map((info) => {
const past_chats = pulled_data.__past_chats.map((info: TemplateVarInfo) => {
// Add to unique LLMs list, if necessary
const llm_name = info?.llm?.name;
if (llm_name !== undefined && !llm_names.has(llm_name)) {
llm_names.add(llm_name);
if (typeof info?.llm !== "string" && info?.llm?.name !== undefined && !llm_names.has(info.llm.name)) {
llm_names.add(info.llm.name);
past_chat_llms.push(info.llm);
}
@ -414,6 +408,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
// Append any present system message retroactively as the first message in the chat history:
if (
typeof info?.llm !== "string" &&
info?.llm?.settings?.system_msg !== undefined &&
updated_chat_hist[0].role !== "system"
)
@ -426,7 +421,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
messages: updated_chat_hist,
fill_history: info.fill_history,
metavars: info.metavars,
llm: llm_name,
llm: info?.llm?.name,
batch_id: uuid(),
};
});
@ -480,7 +475,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
}).then((prompts) => {
setPromptPreviews(
prompts.map(
(p) =>
(p: string) =>
new PromptInfo(p.toString(), extractSettingsVars(p.fill_history)),
),
);
@ -565,7 +560,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
setResponsesWillChange(true);
// Tally how many queries per LLM:
const queries_per_llm = {};
const queries_per_llm: Dict<number> = {};
Object.keys(counts).forEach((llm_key) => {
queries_per_llm[llm_key] = Object.keys(counts[llm_key]).reduce(
(acc, prompt) => acc + counts[llm_key][prompt],
@ -711,12 +706,12 @@ Soft failing by replacing undefined with empty strings.`,
}
// Set status indicator
setStatus("loading");
setStatus(Status.LOADING);
setContChatToggleDisabled(true);
setJSONResponses([]);
setProgressAnimated(true);
const rejected = (err) => {
const rejected = (err: Error | string) => {
if (
err instanceof UserForcedPrematureExit ||
CancelTracker.has(cancelId)
@ -724,9 +719,9 @@ Soft failing by replacing undefined with empty strings.`,
// Handle a premature cancelation
console.log("Canceled.");
} else {
setStatus("error");
setStatus(Status.ERROR);
setContChatToggleDisabled(false);
triggerAlert(err.message || err);
triggerAlert(typeof err === "string" ? err : err?.message);
}
};
@ -889,7 +884,7 @@ Soft failing by replacing undefined with empty strings.`,
llmListContainer?.current?.ensureLLMItemsErrorProgress(llms_w_errors);
// Set error status
setStatus("error");
setStatus(Status.ERROR);
setContChatToggleDisabled(false);
// Trigger alert and display one error message per LLM of all collected errors:
@ -934,7 +929,7 @@ Soft failing by replacing undefined with empty strings.`,
setNumGenerationsLastRun(numGenerations);
// All responses collected! Change status to 'ready':
setStatus("ready");
setStatus(Status.READY);
// Ping any inspect nodes attached to this node to refresh their contents:
pingOutputNodes(id);
@ -961,7 +956,7 @@ Soft failing by replacing undefined with empty strings.`,
debounce(() => {}, 1)(); // erase any pending debounces
// Set error status
setStatus("none");
setStatus(Status.NONE);
setContChatToggleDisabled(false);
llmListContainer?.current?.resetLLMItemsProgress();
}, [cancelId, refreshCancelId]);
@ -972,8 +967,8 @@ Soft failing by replacing undefined with empty strings.`,
if (!isNaN(n) && n.length > 0 && /^\d+$/.test(n)) {
// n is an integer; save it
n = parseInt(n);
if (n !== numGenerationsLastRun && status === "ready")
setStatus("warning");
if (n !== numGenerationsLastRun && status === Status.READY)
setStatus(Status.WARNING);
setNumGenerations(n);
setDataPropsForNode(id, { n });
}
@ -982,14 +977,14 @@ Soft failing by replacing undefined with empty strings.`,
);
const hideStatusIndicator = () => {
if (status !== "none") setStatus("none");
if (status !== Status.NONE) setStatus(Status.NONE);
};
// Dynamically update the textareas and position of the template hooks
const textAreaRef = useRef(null);
const textAreaRef = useRef<HTMLTextAreaElement | HTMLDivElement | null>(null);
const [hooksY, setHooksY] = useState(138);
const setRef = useCallback(
(elem) => {
(elem: HTMLDivElement | HTMLTextAreaElement) => {
// To listen for resize events of the textarea, we need to use a ResizeObserver.
// We initialize the ResizeObserver only once, when the 'ref' is first set, and only on the div wrapping textfields.
// NOTE: This won't work on older browsers, but there's no alternative solution.
@ -1099,7 +1094,7 @@ Soft failing by replacing undefined with empty strings.`,
<Handle
type="source"
position="right"
position={Position.Right}
id="prompt"
className="grouped-handle"
style={{ top: "50%" }}
@ -1108,6 +1103,7 @@ Soft failing by replacing undefined with empty strings.`,
vars={templateVars}
nodeId={id}
startY={hooksY}
position={Position.Left}
ignoreHandles={["__past_chats"]}
/>
<hr />
@ -1140,7 +1136,7 @@ Soft failing by replacing undefined with empty strings.`,
checked={contWithPriorLLMs}
disabled={contToggleDisabled}
onChange={(event) => {
setStatus("warning");
setStatus(Status.WARNING);
setContWithPriorLLMs(event.currentTarget.checked);
setDataPropsForNode(id, {
contChat: event.currentTarget.checked,

View File

@ -5,6 +5,7 @@ export enum Status {
READY = "ready",
ERROR = "error",
LOADING = "loading",
NONE = "none",
}
interface StatusIndicatorProps {
status: Status;

View File

@ -23,21 +23,6 @@ export interface OpenAIFunctionCall {
description?: string;
}
/** The outputs of prompt nodes, text fields or other data passed internally in the front-end and to the PromptTemplate backend.
* Used to populate prompt templates and carry variables/metavariables along the chain. */
export interface TemplateVarInfo {
text: string;
fill_history: Dict<string>;
metavars?: Dict<string>;
associate_id?: string;
llm?: string | Dict;
}
export type PromptVarType = string | TemplateVarInfo;
export type PromptVarsDict = {
[key: string]: PromptVarType[];
};
/** OpenAI chat message format */
export interface ChatMessage {
role: string;
@ -183,3 +168,24 @@ export type LLMSpec = {
settings?: Dict;
progress?: Dict<number>; // only used for front-end to display progress collecting responses for this LLM
};
/** The outputs of prompt nodes, text fields or other data passed internally in the front-end and to the PromptTemplate backend.
* Used to populate prompt templates and carry variables/metavariables along the chain. */
export interface TemplateVarInfo {
text: string;
fill_history: Dict<string>;
metavars?: Dict<string>;
associate_id?: string;
llm?: string | LLMSpec;
chat_history?: ChatHistory;
}
export type PromptVarType = string | TemplateVarInfo;
export type PromptVarsDict = {
[key: string]: PromptVarType[];
};
export type QueryProgress = {
success: number;
error: number;
}