From 158fa74d97475fd877841ce496b73e2500861de5 Mon Sep 17 00:00:00 2001 From: ianarawjo Date: Sat, 15 Mar 2025 17:48:32 -0400 Subject: [PATCH] Add prompt variants option to Prompt Node (#334) * Add prompt variants feature to Prompt Node * Fix countQueries backwards compatibility * Add alertmodal for deleting prompt variant * Add prompt variants to PromptPreview modals * Autoresize textarea when switching prompt variants. Ensure auto-templating is only used for variants when length exceeds 1. * Update package version * Add an option to CLI to set custom flow directory --- chainforge/app.py | 11 +- chainforge/flask_app.py | 74 +++- chainforge/react-server/src/App.tsx | 189 +++++++-- .../react-server/src/AreYouSureModal.tsx | 7 +- chainforge/react-server/src/FlowSidebar.tsx | 71 +++- chainforge/react-server/src/ItemsNode.tsx | 2 +- .../react-server/src/LLMListComponent.tsx | 25 +- .../src/LLMResponseInspectorModal.tsx | 5 +- chainforge/react-server/src/MultiEvalNode.tsx | 8 +- chainforge/react-server/src/PromptNode.tsx | 393 +++++++++++++++--- .../react-server/src/backend/backend.ts | 40 +- chainforge/react-server/src/backend/models.ts | 2 + chainforge/react-server/src/backend/utils.ts | 59 ++- setup.py | 2 +- 14 files changed, 714 insertions(+), 174 deletions(-) diff --git a/chainforge/app.py b/chainforge/app.py index 1df830b..9f98819 100644 --- a/chainforge/app.py +++ b/chainforge/app.py @@ -25,6 +25,10 @@ def main(): serve_parser.add_argument('--host', help="The host to run the server on. Defaults to 'localhost'.", type=str, default="localhost", nargs='?') + serve_parser.add_argument('--dir', + help="Set a custom directory to use for saving flows and autosaving. By default, ChainForge uses the user data location suggested by the `platformdirs` module. Should be the full path.", + type=str, + default=None) args = parser.parse_args() @@ -34,10 +38,13 @@ def main(): exit(0) port = args.port if args.port else 8000 - host = args.host if args.host else "localhost" + host = args.host if args.host else "localhost" + + if args.dir: + print(f"Using directory for storing flows: {args.dir}") print(f"Serving Flask server on {host} on port {port}...") - run_server(host=host, port=port, cmd_args=args) + run_server(host=host, port=port, flows_dir=args.dir) if __name__ == "__main__": main() \ No newline at end of file diff --git a/chainforge/flask_app.py b/chainforge/flask_app.py index e2876bf..cfef251 100644 --- a/chainforge/flask_app.py +++ b/chainforge/flask_app.py @@ -1,4 +1,4 @@ -import json, os, sys, asyncio, time +import json, os, sys, asyncio, time, shutil from dataclasses import dataclass from enum import Enum from typing import List @@ -759,6 +759,17 @@ def get_flow(filename): except FileNotFoundError: return jsonify({"error": "Flow not found"}), 404 +@app.route('/api/flowExists/', methods=['GET']) +def get_flow_exists(filename): + """Return the content of a specific flow""" + if not filename.endswith('.cforge'): + filename += '.cforge' + try: + is_file = os.path.isfile(os.path.join(FLOWS_DIR, filename)) + return jsonify({"exists": is_file}) + except FileNotFoundError: + return jsonify({"error": "Flow not found"}), 404 + @app.route('/api/flows/', methods=['DELETE']) def delete_flow(filename): """Delete a flow""" @@ -772,7 +783,7 @@ def delete_flow(filename): @app.route('/api/flows/', methods=['PUT']) def save_or_rename_flow(filename): - """Save or rename a flow""" + """Save, rename, or duplicate a flow""" data = request.json if not filename.endswith('.cforge'): @@ -781,11 +792,18 @@ def save_or_rename_flow(filename): if data.get('flow'): # Save flow (overwriting any existing flow file with the same name) flow_data = data.get('flow') + also_autosave = data.get('alsoAutosave') try: filepath = os.path.join(FLOWS_DIR, filename) with open(filepath, 'w') as f: json.dump(flow_data, f) + + # If we should also autosave, then attempt to override the autosave cache file: + if also_autosave: + autosave_filepath = os.path.join(FLOWS_DIR, '__autosave.cforge') + shutil.copy2(filepath, autosave_filepath) # copy the file to __autosave + return jsonify({"message": f"Flow '{filename}' saved!"}) except FileNotFoundError: return jsonify({"error": f"Could not save flow '{filename}' to local filesystem. See terminal for more details."}), 404 @@ -805,6 +823,36 @@ def save_or_rename_flow(filename): return jsonify({"message": f"Flow renamed from {filename} to {new_name}"}) except Exception as error: return jsonify({"error": str(error)}), 404 + + elif data.get('duplicate'): + # Duplicate flow + try: + # Check for name clashes (if a flow already exists with the new name) + copy_name = _get_unique_flow_name(filename, "Copy of ") + # Copy the file to the new (safe) path, and copy metadata too: + shutil.copy2(os.path.join(FLOWS_DIR, filename), os.path.join(FLOWS_DIR, f"{copy_name}.cforge")) + # Return the new filename + return jsonify({"copyName": copy_name}) + except Exception as error: + return jsonify({"error": str(error)}), 404 + +def _get_unique_flow_name(filename: str, prefix: str = None) -> str: + base, ext = os.path.splitext(filename) + if ext is None or len(ext) == 0: + ext = ".cforge" + unique_filename = base + ext + if prefix is not None: + unique_filename = prefix + unique_filename + i = 1 + + # Find the first non-clashing filename of the form (i).cforge where i=1,2,3 etc + while os.path.isfile(os.path.join(FLOWS_DIR, unique_filename)): + unique_filename = f"{base}({i}){ext}" + if prefix is not None: + unique_filename = prefix + unique_filename + i += 1 + + return unique_filename.replace(".cforge", "") @app.route('/api/getUniqueFlowFilename', methods=['PUT']) def get_unique_flow_name(): @@ -813,25 +861,17 @@ def get_unique_flow_name(): filename = data.get("name") try: - base, ext = os.path.splitext(filename) - if ext is None or len(ext) == 0: - ext = ".cforge" - unique_filename = base + ext - i = 1 - - # Find the first non-clashing filename of the form (i).cforge where i=1,2,3 etc - while os.path.isfile(os.path.join(FLOWS_DIR, unique_filename)): - unique_filename = f"{base}({i}){ext}" - i += 1 - - return jsonify(unique_filename.replace(".cforge", "")) + new_name = _get_unique_flow_name(filename) + return jsonify(new_name) except Exception as e: return jsonify({"error": str(e)}), 404 -def run_server(host="", port=8000, cmd_args=None): - global HOSTNAME, PORT +def run_server(host="", port=8000, flows_dir=None): + global HOSTNAME, PORT, FLOWS_DIR HOSTNAME = host - PORT = port + PORT = port + if flows_dir: + FLOWS_DIR = flows_dir app.run(host=host, port=port) if __name__ == '__main__': diff --git a/chainforge/react-server/src/App.tsx b/chainforge/react-server/src/App.tsx index ba203e3..9ca353c 100644 --- a/chainforge/react-server/src/App.tsx +++ b/chainforge/react-server/src/App.tsx @@ -6,7 +6,6 @@ import React, { useContext, useMemo, useTransition, - KeyboardEventHandler, KeyboardEvent, } from "react"; import ReactFlow, { Controls, Background, ReactFlowInstance } from "reactflow"; @@ -63,6 +62,7 @@ import { getDefaultModelSettings, } from "./ModelSettingSchemas"; import { v4 as uuid } from "uuid"; +import axios from "axios"; import LZString from "lz-string"; import { EXAMPLEFLOW_1 } from "./example_flows"; @@ -78,7 +78,11 @@ import "lazysizes/plugins/attrchange/ls.attrchange"; import { shallow } from "zustand/shallow"; import useStore, { StoreHandles } from "./store"; import StorageCache, { StringLookup } from "./backend/cache"; -import { APP_IS_RUNNING_LOCALLY, browserTabIsActive } from "./backend/utils"; +import { + APP_IS_RUNNING_LOCALLY, + browserTabIsActive, + FLASK_BASE_URL, +} from "./backend/utils"; import { Dict, JSONCompatible, LLMSpec } from "./backend/typing"; import { ensureUniqueFlowFilename, @@ -113,6 +117,14 @@ const IS_ACCEPTED_BROWSER = // we have access to the Flask backend for, e.g., Python code evaluation. const IS_RUNNING_LOCALLY = APP_IS_RUNNING_LOCALLY(); +const SAVE_FLOW_FILENAME_TO_BROWSER_CACHE = (name: string) => { + console.log("Saving flow filename", name); + // Save the current filename of the user's working flow + StorageCache.saveToLocalStorage("chainforge-cur-file", { + flowFileName: name, + }); +}; + const selector = (state: StoreHandles) => ({ nodes: state.nodes, edges: state.edges, @@ -266,6 +278,11 @@ const App = () => { const safeSetFlowFileName = useCallback(async (newName: string) => { const uniqueName = await ensureUniqueFlowFilename(newName); setFlowFileName(uniqueName); + SAVE_FLOW_FILENAME_TO_BROWSER_CACHE(uniqueName); + }, []); + const setFlowFileNameAndCache = useCallback((newName: string) => { + setFlowFileName(newName); + SAVE_FLOW_FILENAME_TO_BROWSER_CACHE(newName); }, []); // For 'share' button @@ -387,6 +404,7 @@ const App = () => { flowData?: unknown, saveToLocalFilesystem?: string, hideErrorAlert?: boolean, + onError?: () => void, ) => { if (!rfInstance && !flowData) return; @@ -406,11 +424,16 @@ const App = () => { // Save! const flowFile = `${saveToLocalFilesystem ?? flowFileName}.cforge`; if (saveToLocalFilesystem !== undefined) - return saveFlowToLocalFilesystem(flow_and_cache, flowFile); + return saveFlowToLocalFilesystem( + flow_and_cache, + flowFile, + saveToLocalFilesystem !== "__autosave", + ); // @ts-expect-error The exported RF instance is JSON compatible but TypeScript won't read it as such. else downloadJSON(flow_and_cache, flowFile); }) .catch((err) => { + if (onError) onError(); if (hideErrorAlert) console.error(err); else handleError(err); }); @@ -432,14 +455,18 @@ const App = () => { setShowSaveSuccess(false); startSaveTransition(() => { - // NOTE: This currently only saves the front-end state. Cache files - // are not pulled or overwritten upon loading from localStorage. + // Get current flow state const flow = rf.toObject(); - StorageCache.saveToLocalStorage("chainforge-flow", flow); - // Attempt to save the current state of the back-end state, - // the StorageCache. (This does LZ compression to save space.) - StorageCache.saveToLocalStorage("chainforge-state"); + const saveToLocalStorage = () => { + // This line only saves the front-end state. Cache files + // are not pulled or overwritten upon loading from localStorage. + StorageCache.saveToLocalStorage("chainforge-flow", flow); + + // Attempt to save the current back-end state, + // in the StorageCache. (This does LZ compression to save space.) + StorageCache.saveToLocalStorage("chainforge-state"); + }; const onFlowSaved = () => { console.log("Flow saved!"); @@ -452,10 +479,18 @@ const App = () => { // If running locally, aattempt to save a copy of the flow to the lcoal filesystem, // so it shows up in the list of saved flows. if (IS_RUNNING_LOCALLY) - exportFlow(flow, fileName ?? flowFileName, hideErrorAlert)?.then( - onFlowSaved, - ); - else onFlowSaved(); + // SAVE TO LOCAL FILESYSTEM (only), and if that fails, try to save to localStorage + exportFlow( + flow, + fileName ?? flowFileName, + hideErrorAlert, + saveToLocalStorage, + )?.then(onFlowSaved); + else { + // SAVE TO BROWSER LOCALSTORAGE + saveToLocalStorage(); + onFlowSaved(); + } }); }, [rfInstance, exportFlow, flowFileName], @@ -475,8 +510,13 @@ const App = () => { // Initialize auto-saving const initAutosaving = useCallback( - (rf_inst: ReactFlowInstance) => { - if (autosavingInterval !== undefined) return; // autosaving interval already set + (rf_inst: ReactFlowInstance, reinit?: boolean) => { + if (autosavingInterval !== undefined) { + // Autosaving interval already set + if (reinit) + clearInterval(autosavingInterval); // reinitialize interval, clearing the current one + else return; // do nothing + } console.log("Init autosaving"); // Autosave the flow to localStorage every minute: @@ -539,7 +579,9 @@ const App = () => { StorageCache.clear(); // New flow filename - setFlowFileName(`flow-${Date.now()}`); + const new_filename = `flow-${Date.now()}`; + setFlowFileNameAndCache(new_filename); + if (rfInstance) rfInstance.setViewport({ x: 200, y: 80, zoom: 1 }); }, [setNodes, setEdges, resetLLMColors, rfInstance]); @@ -575,7 +617,7 @@ const App = () => { }, 10); // Start auto-saving, if it's not already enabled - if (rf_inst) initAutosaving(rf_inst); + if (rf_inst) initAutosaving(rf_inst, true); }, [resetLLMColors, setNodes, setEdges, initAutosaving], ); @@ -584,23 +626,28 @@ const App = () => { importState(StorageCache.getAllMatching((key) => key.startsWith("r."))); }, [importState]); - const autosavedFlowExists = useCallback(() => { - return window.localStorage.getItem("chainforge-flow") !== null; - }, []); - const loadFlowFromAutosave = useCallback( - async (rf_inst: ReactFlowInstance) => { - const saved_flow = StorageCache.loadFromLocalStorage( - "chainforge-flow", - false, - ) as Dict; - if (saved_flow) { - StorageCache.loadFromLocalStorage("chainforge-state", true); - importGlobalStateFromCache(); - loadFlow(saved_flow, rf_inst); + // Find the autosaved flow, if it exists, returning + // whether it exists and the location ("browser" or "filesystem") that it exists at. + const autosavedFlowExists = useCallback(async () => { + if (IS_RUNNING_LOCALLY) { + // If running locally, we try to fetch a flow autosaved on the user's local machine first: + try { + const response = await axios.get( + `${FLASK_BASE_URL}api/flowExists/__autosave`, + ); + const autosave_file_exists = response.data.exists as boolean; + if (autosave_file_exists) + return { exists: autosave_file_exists, location: "filesystem" }; + } catch (error) { + // Soft fail, continuing onwards to checking localStorage instead } - }, - [importGlobalStateFromCache, loadFlow], - ); + } + + return { + exists: window.localStorage.getItem("chainforge-flow") !== null, + location: "browser", + }; + }, []); // Import data to the cache stored on the local filesystem (in backend) const handleImportCache = useCallback( @@ -715,6 +762,38 @@ const App = () => { fetchOpenAIEval(evalname).then(importFlowFromJSON).catch(handleError); }; + const loadFlowFromAutosave = useCallback( + async (rf_inst: ReactFlowInstance, fromFilesystem?: boolean) => { + if (fromFilesystem) { + // From local filesystem + // Fetch the flow + const response = await axios.get( + `${FLASK_BASE_URL}api/flows/__autosave`, + ); + + // Attempt to load flow into the UI + try { + importFlowFromJSON(response.data, rf_inst); + console.log("Loaded flow from autosave on local machine."); + } catch (error) { + handleError(error as Error); + } + } else { + // From browser localStorage + const saved_flow = StorageCache.loadFromLocalStorage( + "chainforge-flow", + false, + ) as Dict; + if (saved_flow) { + StorageCache.loadFromLocalStorage("chainforge-state", true); + importGlobalStateFromCache(); + loadFlow(saved_flow, rf_inst); + } + } + }, + [importGlobalStateFromCache, loadFlow, importFlowFromJSON, handleError], + ); + // Load flow from examples modal const onSelectExampleFlow = (name: string, example_category?: string) => { // Trigger the 'loading' modal @@ -723,7 +802,7 @@ const App = () => { // Detect a special category of the example flow, and use the right loader for it: if (example_category === "openai-eval") { importFlowFromOpenAIEval(name); - setFlowFileName(`flow-${Date.now()}`); + setFlowFileNameAndCache(`flow-${Date.now()}`); return; } @@ -732,7 +811,7 @@ const App = () => { .then(function (flowJSON) { // We have the data, import it: importFlowFromJSON(flowJSON); - setFlowFileName(`flow-${Date.now()}`); + setFlowFileNameAndCache(`flow-${Date.now()}`); }) .catch(handleError); }; @@ -871,6 +950,20 @@ const App = () => { err.message, ); }); + + // We also need to fetch the current flowFileName + // Attempt to get the last working filename on component mount + const last_working_flow_filename = StorageCache.loadFromLocalStorage( + "chainforge-cur-file", + ); + if ( + last_working_flow_filename && + typeof last_working_flow_filename === "object" && + "flowFileName" in last_working_flow_filename + ) { + // Use last working flow name + setFlowFileName(last_working_flow_filename.flowFileName as string); + } } else { // Check if there's a shared flow UID in the URL as a GET param // If so, we need to look it up in the database and attempt to load it: @@ -910,14 +1003,19 @@ const App = () => { } // Attempt to load an autosaved flow, if one exists: - if (autosavedFlowExists()) loadFlowFromAutosave(rf_inst); - else { - // Load an interesting default starting flow for new users - importFlowFromJSON(EXAMPLEFLOW_1, rf_inst); + autosavedFlowExists().then(({ exists, location }) => { + if (!exists) { + // Load an interesting default starting flow for new users + importFlowFromJSON(EXAMPLEFLOW_1, rf_inst); - // Open a welcome pop-up - // openWelcomeModal(); - } + // Open a welcome pop-up + // openWelcomeModal(); + } else if (location === "browser") { + loadFlowFromAutosave(rf_inst, false); + } else if (location === "filesystem") { + loadFlowFromAutosave(rf_inst, true); + } + }); // Turn off loading wheel setIsLoading(false); @@ -1218,7 +1316,9 @@ const App = () => { { - if (name !== undefined) setFlowFileName(name); + if (name !== undefined) { + setFlowFileNameAndCache(name); + } if (flowData !== undefined) { try { importFlowFromJSON(flowData); @@ -1231,7 +1331,7 @@ const App = () => { }} /> ); - }, [flowFileName, importFlowFromJSON, showAlert]); + }, [flowFileName, importFlowFromJSON, showAlert, setFlowFileNameAndCache]); if (!IS_ACCEPTED_BROWSER) { return ( @@ -1334,6 +1434,7 @@ const App = () => { ml="sm" size="1.625rem" onClick={() => saveFlow()} + bg="#eee" loading={isSaving} disabled={isLoading || isSaving} > diff --git a/chainforge/react-server/src/AreYouSureModal.tsx b/chainforge/react-server/src/AreYouSureModal.tsx index b0ac59c..9578175 100644 --- a/chainforge/react-server/src/AreYouSureModal.tsx +++ b/chainforge/react-server/src/AreYouSureModal.tsx @@ -5,6 +5,7 @@ import { useDisclosure } from "@mantine/hooks"; export interface AreYouSureModalProps { title: string; message: string; + color?: string; onConfirm?: () => void; } @@ -14,7 +15,7 @@ export interface AreYouSureModalRef { /** Modal that lets user rename a single value, using a TextInput field. */ const AreYouSureModal = forwardRef( - function AreYouSureModal({ title, message, onConfirm }, ref) { + function AreYouSureModal({ title, message, color, onConfirm }, ref) { const [opened, { open, close }] = useDisclosure(false); const description = message || "Are you sure?"; @@ -37,7 +38,7 @@ const AreYouSureModal = forwardRef( onClose={close} title={title} styles={{ - header: { backgroundColor: "orange", color: "white" }, + header: { backgroundColor: color ?? "orange", color: "white" }, root: { position: "relative", left: "-5%" }, }} > @@ -54,7 +55,7 @@ const AreYouSureModal = forwardRef( > } - styles={{ title: { justifyContent: "space-between", width: "100%" } }} + styles={{ + title: { justifyContent: "space-between", width: "100%" }, + header: { paddingBottom: "0px" }, + }} >
= ({ data, id }) => { const bringNodeToFront = useStore((state) => state.bringNodeToFront); const inputEdgesForNode = useStore((state) => state.inputEdgesForNode); - const flags = useStore((state) => state.flags); - const AI_SUPPORT_ENABLED = useMemo(() => { - return flags.aiSupport; - }, [flags]); + // const flags = useStore((state) => state.flags); + // const AI_SUPPORT_ENABLED = useMemo(() => { + // return flags.aiSupport; + // }, [flags]); const [status, setStatus] = useState(Status.NONE); // For displaying error messages to user diff --git a/chainforge/react-server/src/PromptNode.tsx b/chainforge/react-server/src/PromptNode.tsx index 93af5a8..ce811a0 100644 --- a/chainforge/react-server/src/PromptNode.tsx +++ b/chainforge/react-server/src/PromptNode.tsx @@ -18,9 +18,20 @@ import { Modal, Box, Tooltip, + Flex, + Button, + ActionIcon, + Divider, } from "@mantine/core"; import { useDisclosure } from "@mantine/hooks"; -import { IconEraser, IconList } from "@tabler/icons-react"; +import { + IconArrowLeft, + IconArrowRight, + IconEraser, + IconList, + IconPlus, + IconTrash, +} from "@tabler/icons-react"; import useStore from "./store"; import BaseNode from "./BaseNode"; import NodeLabel from "./NodeLabelComponent"; @@ -41,6 +52,7 @@ import { extractSettingsVars, truncStr, genDebounceFunc, + ensureUniqueName, } from "./backend/utils"; import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer"; import CancelTracker from "./backend/canceler"; @@ -64,6 +76,8 @@ import { queryLLM, } from "./backend/backend"; import { StringLookup } from "./backend/cache"; +import { union } from "./backend/setUtils"; +import AreYouSureModal, { AreYouSureModalRef } from "./AreYouSureModal"; const getUniqueLLMMetavarKey = (responses: LLMResponse[]) => { const metakeys = new Set( @@ -82,22 +96,44 @@ const bucketChatHistoryInfosByLLM = (chat_hist_infos: ChatHistoryInfo[]) => { }); return chats_by_llm; }; +const getRootPromptFor = ( + promptTexts: string | string[], + varNameForRootTemplate: string, +) => { + if (typeof promptTexts === "string") return promptTexts; + else if (promptTexts.length === 1) return promptTexts[0]; + else return `{${varNameForRootTemplate}}`; +}; export class PromptInfo { prompt: string; - settings: Dict; + settings?: Dict; + label?: string; - constructor(prompt: string, settings: Dict) { + constructor(prompt: string, settings?: Dict, label?: string) { this.prompt = prompt; this.settings = settings; + this.label = label; } } -const displayPromptInfos = (promptInfos: PromptInfo[], wideFormat: boolean) => +const displayPromptInfos = ( + promptInfos: PromptInfo[], + wideFormat: boolean, + bgColor?: string, +) => promptInfos.map((info, idx) => (
-
{info.prompt}
- {info.settings ? ( +
+ {info.label && ( + + {info.label} +
+
+ )} + {info.prompt} +
+ {info.settings && Object.entries(info.settings).map(([key, val]) => { return (
@@ -107,10 +143,7 @@ const displayPromptInfos = (promptInfos: PromptInfo[], wideFormat: boolean) =>
); - }) - ) : ( - <> - )} + })}
)); @@ -118,12 +151,14 @@ export interface PromptListPopoverProps { promptInfos: PromptInfo[]; onHover: () => void; onClick: () => void; + promptTemplates?: string[] | string; } export const PromptListPopover: React.FC = ({ promptInfos, onHover, onClick, + promptTemplates, }) => { const [opened, { close, open }] = useDisclosure(false); @@ -172,6 +207,29 @@ export const PromptListPopover: React.FC = ({ Preview of generated prompts ({promptInfos.length} total) + {Array.isArray(promptTemplates) && promptTemplates.length > 1 && ( + + + {displayPromptInfos( + promptTemplates.map( + (t, i) => new PromptInfo(t, undefined, `Variant ${i + 1}`), + ), + false, + "#ddf1f8", + )} + + + )} {displayPromptInfos(promptInfos, false)} @@ -182,12 +240,14 @@ export interface PromptListModalProps { promptPreviews: PromptInfo[]; infoModalOpened: boolean; closeInfoModal: () => void; + promptTemplates?: string[] | string; } export const PromptListModal: React.FC = ({ promptPreviews, infoModalOpened, closeInfoModal, + promptTemplates, }) => { return ( = ({ }} > + {Array.isArray(promptTemplates) && promptTemplates.length > 1 && ( + + + {displayPromptInfos( + promptTemplates.map( + (t, i) => new PromptInfo(t, undefined, `Variant ${i + 1}`), + ), + true, + "#ddf1f8", + )} + + + )} {displayPromptInfos(promptPreviews, true)} @@ -221,6 +304,7 @@ export interface PromptNodeProps { contChat: boolean; refresh: boolean; refreshLLMList: boolean; + idxPromptVariantShown?: number; }; id: string; type: string; @@ -257,10 +341,15 @@ const PromptNode: React.FC = ({ null, ); const [templateVars, setTemplateVars] = useState(data.vars ?? []); - const [promptText, setPromptText] = useState(data.prompt ?? ""); - const [promptTextOnLastRun, setPromptTextOnLastRun] = useState( - null, + const [promptText, setPromptText] = useState( + data.prompt ?? "", ); + const [idxPromptVariantShown, setIdxPromptVariantShown] = useState( + data.idxPromptVariantShown ?? 0, + ); + const [promptTextOnLastRun, setPromptTextOnLastRun] = useState< + string | string[] | null + >(null); const [status, setStatus] = useState(Status.NONE); const [numGenerations, setNumGenerations] = useState(data.n ?? 1); const [numGenerationsLastRun, setNumGenerationsLastRun] = useState( @@ -391,10 +480,17 @@ const PromptNode: React.FC = ({ }, [templateVars, id, pullInputData, updateShowContToggle]); const refreshTemplateHooks = useCallback( - (text: string) => { - // Update template var fields + handles - const found_template_vars = new Set(extractBracketedSubstrings(text)); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this} + (text: string | string[]) => { + const texts = typeof text === "string" ? [text] : text; + // Get all template vars in the prompt(s) + let found_template_vars = new Set(); + for (const t of texts) { + const substrs = extractBracketedSubstrings(t); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this} + found_template_vars = union(found_template_vars, new Set(substrs)); + } + + // Update template var fields + handles if (!setsAreEqual(found_template_vars, new Set(templateVars))) { if (node_type !== "chat") { try { @@ -413,27 +509,29 @@ const PromptNode: React.FC = ({ const handleInputChange = useCallback( (event: React.ChangeEvent) => { - const value = event.target.value; + const value = event.target.value as string; const updateStatus = promptTextOnLastRun !== null && status !== Status.WARNING && value !== promptTextOnLastRun; - // Store prompt text - data.prompt = value; - // Debounce the global state change to happen only after 500ms, as it forces a costly rerender: - debounce((_value, _updateStatus) => { - setPromptText(_value); - setDataPropsForNode(id, { prompt: _value }); - refreshTemplateHooks(_value); + debounce((_value: string, _updateStatus, _idxPromptVariantShown) => { + setPromptText((prompts) => { + if (typeof prompts === "string") prompts = _value; + else prompts[_idxPromptVariantShown] = _value; + setDataPropsForNode(id, { prompt: prompts }); + refreshTemplateHooks(prompts); + return prompts; + }); if (_updateStatus) setStatus(Status.WARNING); - }, 300)(value, updateStatus); + }, 300)(value, updateStatus, idxPromptVariantShown); // Debounce refreshing the template hooks so we don't annoy the user // debounce((_value) => refreshTemplateHooks(_value), 500)(value); }, [ + idxPromptVariantShown, promptTextOnLastRun, status, refreshTemplateHooks, @@ -552,7 +650,7 @@ const PromptNode: React.FC = ({ // Ask the backend how many responses it needs to collect, given the input data: const fetchResponseCounts = useCallback( ( - prompt: string, + prompt: string | string[], vars: Dict, llms: (StringOrHash | LLMSpec)[], chat_histories?: @@ -592,14 +690,24 @@ const PromptNode: React.FC = ({ const pulled_vars = pullInputData(templateVars, id); updateShowContToggle(pulled_vars); - generatePrompts(promptText, pulled_vars).then((prompts) => { - setPromptPreviews( - prompts.map( - (p: PromptTemplate) => - new PromptInfo(p.toString(), extractSettingsVars(p.fill_history)), - ), - ); - }); + const prompts = + typeof promptText === "string" ? [promptText] : promptText; + + Promise.all(prompts.map((p) => generatePrompts(p, pulled_vars))).then( + (results) => { + // Handle all the results here + const all_concrete_prompts = results.flatMap((ps) => + ps.map( + (p: PromptTemplate) => + new PromptInfo( + p.toString(), + extractSettingsVars(p.fill_history), + ), + ), + ); + setPromptPreviews(all_concrete_prompts); + }, + ); pullInputChats(); } catch (err) { @@ -827,9 +935,18 @@ Soft failing by replacing undefined with empty strings.`, // Pull the data to fill in template input variables, if any let pulled_data: Dict<(string | TemplateVarInfo)[]> = {}; + let var_for_prompt_templates: string; try { // Try to pull inputs pulled_data = pullInputData(templateVars, id); + + // Add a special new variable for the root prompt template(s) + var_for_prompt_templates = ensureUniqueName( + "prompt", + Object.keys(pulled_data), + ); + if (typeof promptText !== "string" && promptText.length > 1) + pulled_data[var_for_prompt_templates] = promptText; // this will be filled in when calling queryLLMs } catch (err) { if (showAlert) showAlert((err as Error)?.message ?? err); console.error(err); @@ -873,7 +990,9 @@ Soft failing by replacing undefined with empty strings.`, // Fetch info about the number of queries we'll need to make const fetch_resp_count = () => fetchResponseCounts( - prompt_template, + typeof prompt_template === "string" + ? prompt_template + : `{${var_for_prompt_templates}}`, // Use special root prompt if there's multiple prompt variants pulled_data, _llmItemsCurrState, pulled_chats as ChatHistoryInfo[], @@ -951,9 +1070,9 @@ Soft failing by replacing undefined with empty strings.`, const query_llms = () => { return queryLLM( id, - _llmItemsCurrState, // deep clone it first + _llmItemsCurrState, numGenerations, - prompt_template, + getRootPromptFor(prompt_template, var_for_prompt_templates), // Use special root prompt if there's multiple prompt variants pulled_data, chat_hist_by_llm, apiKeys || {}, @@ -1015,7 +1134,7 @@ Soft failing by replacing undefined with empty strings.`, o.metavars = resp_obj.metavars ?? {}; // Add a metavar for the prompt *template* in this PromptNode - o.metavars.__pt = prompt_template; + // o.metavars.__pt = prompt_template; // Carry over any chat history if (resp_obj.chat_history) @@ -1162,6 +1281,16 @@ Soft failing by replacing undefined with empty strings.`, // Dynamically update the textareas and position of the template hooks const textAreaRef = useRef(null); + const resizeTextarea = () => { + const textarea = textAreaRef.current; + + if (textarea) { + textarea.style.height = "auto"; // Reset height to shrink if needed + const newHeight = Math.min(textarea.scrollHeight, 600); + textarea.style.height = `${newHeight}px`; + } + }; + const [hooksY, setHooksY] = useState(138); const setRef = useCallback( (elem: HTMLDivElement | HTMLTextAreaElement | null) => { @@ -1188,6 +1317,147 @@ Soft failing by replacing undefined with empty strings.`, [textAreaRef], ); + const deleteVariantConfirmModal = useRef(null); + const handleAddPromptVariant = useCallback(() => { + // Pushes a new prompt variant, updating the prompts list and duplicating the current shown prompt + const prompts = typeof promptText === "string" ? [promptText] : promptText; + const curIdx = Math.max( + 0, + Math.min(prompts.length - 1, idxPromptVariantShown), + ); // clamp + const curShownPrompt = prompts[curIdx]; + setPromptText(prompts.concat([curShownPrompt])); + setIdxPromptVariantShown(prompts.length); + }, [promptText, idxPromptVariantShown]); + + const gotoPromptVariant = useCallback( + (shift: number) => { + const prompts = + typeof promptText === "string" ? [promptText] : promptText; + const newIdx = Math.max( + 0, + Math.min(prompts.length - 1, idxPromptVariantShown + shift), + ); // clamp + setIdxPromptVariantShown(newIdx); + resizeTextarea(); + }, + [promptText, idxPromptVariantShown], + ); + + const handleRemovePromptVariant = useCallback(() => { + setPromptText((prompts) => { + if (typeof prompts === "string" || prompts.length === 1) return prompts; // cannot remove the last one + prompts.splice(idxPromptVariantShown, 1); // remove the indexed variant + const newIdx = Math.max(0, idxPromptVariantShown - 1); + setIdxPromptVariantShown(newIdx); // goto the previous variant, if possible + + if (textAreaRef.current) { + // We have to force an update here since idxPromptVariantShown might've not changed + // @ts-expect-error Mantine has a 'value' property on Textareas, but TypeScript doesn't know this + textAreaRef.current.value = prompts[newIdx]; + resizeTextarea(); + } + + return [...prompts]; + }); + }, [idxPromptVariantShown, textAreaRef]); + + // Whenever idx of prompt variant changes, we need to refresh the Textarea: + useEffect(() => { + if (textAreaRef.current && Array.isArray(promptText)) { + // @ts-expect-error Mantine has a 'value' property on Textareas, but TypeScript doesn't know this + textAreaRef.current.value = promptText[idxPromptVariantShown]; + resizeTextarea(); + } + }, [idxPromptVariantShown]); + + const promptVariantControls = useMemo(() => { + return ( + + {typeof promptText === "string" || promptText.length === 1 ? ( + + + + ) : ( + <> + gotoPromptVariant(-1)} + > + + + + + Variant {idxPromptVariantShown + 1} of{" "} + {typeof promptText === "string" ? 1 : promptText.length} + + + gotoPromptVariant(1)} + > + + + + + + + + + + + deleteVariantConfirmModal?.current?.trigger()} + > + + + + + )} + + ); + }, [idxPromptVariantShown, promptText, deleteVariantConfirmModal]); + // Add custom context menu options on right-click. // 1. Convert TextFields to Items Node, for convenience. const customContextMenuItems = useMemo( @@ -1229,6 +1499,7 @@ Soft failing by replacing undefined with empty strings.`, , @@ -1240,9 +1511,17 @@ Soft failing by replacing undefined with empty strings.`, /> + {node_type === "chat" ? (
@@ -1254,7 +1533,12 @@ Soft failing by replacing undefined with empty strings.`, key={0} className="prompt-field-fixed nodrag nowheel" minRows={4} - defaultValue={data.prompt} + defaultValue={ + typeof data.prompt === "string" + ? data.prompt + : data.prompt && + data.prompt[data.idxPromptVariantShown ?? 0] + } onChange={handleInputChange} miw={230} styles={{ @@ -1273,15 +1557,22 @@ Soft failing by replacing undefined with empty strings.`, ) : (