diff --git a/chainforge/app.py b/chainforge/app.py index 1df830b..9f98819 100644 --- a/chainforge/app.py +++ b/chainforge/app.py @@ -25,6 +25,10 @@ def main(): serve_parser.add_argument('--host', help="The host to run the server on. Defaults to 'localhost'.", type=str, default="localhost", nargs='?') + serve_parser.add_argument('--dir', + help="Set a custom directory to use for saving flows and autosaving. By default, ChainForge uses the user data location suggested by the `platformdirs` module. Should be the full path.", + type=str, + default=None) args = parser.parse_args() @@ -34,10 +38,13 @@ def main(): exit(0) port = args.port if args.port else 8000 - host = args.host if args.host else "localhost" + host = args.host if args.host else "localhost" + + if args.dir: + print(f"Using directory for storing flows: {args.dir}") print(f"Serving Flask server on {host} on port {port}...") - run_server(host=host, port=port, cmd_args=args) + run_server(host=host, port=port, flows_dir=args.dir) if __name__ == "__main__": main() \ No newline at end of file diff --git a/chainforge/flask_app.py b/chainforge/flask_app.py index e2876bf..cfef251 100644 --- a/chainforge/flask_app.py +++ b/chainforge/flask_app.py @@ -1,4 +1,4 @@ -import json, os, sys, asyncio, time +import json, os, sys, asyncio, time, shutil from dataclasses import dataclass from enum import Enum from typing import List @@ -759,6 +759,17 @@ def get_flow(filename): except FileNotFoundError: return jsonify({"error": "Flow not found"}), 404 +@app.route('/api/flowExists/<filename>', methods=['GET']) +def get_flow_exists(filename): + """Return the content of a specific flow""" + if not filename.endswith('.cforge'): + filename += '.cforge' + try: + is_file = os.path.isfile(os.path.join(FLOWS_DIR, filename)) + return jsonify({"exists": is_file}) + except FileNotFoundError: + return jsonify({"error": "Flow not found"}), 404 + @app.route('/api/flows/<filename>', methods=['DELETE']) def delete_flow(filename): """Delete a flow""" @@ -772,7 +783,7 @@ def delete_flow(filename): @app.route('/api/flows/<filename>', methods=['PUT']) def save_or_rename_flow(filename): - """Save or rename a flow""" + """Save, rename, or duplicate a flow""" data = request.json if not filename.endswith('.cforge'): @@ -781,11 +792,18 @@ def save_or_rename_flow(filename): if data.get('flow'): # Save flow (overwriting any existing flow file with the same name) flow_data = data.get('flow') + also_autosave = data.get('alsoAutosave') try: filepath = os.path.join(FLOWS_DIR, filename) with open(filepath, 'w') as f: json.dump(flow_data, f) + + # If we should also autosave, then attempt to override the autosave cache file: + if also_autosave: + autosave_filepath = os.path.join(FLOWS_DIR, '__autosave.cforge') + shutil.copy2(filepath, autosave_filepath) # copy the file to __autosave + return jsonify({"message": f"Flow '{filename}' saved!"}) except FileNotFoundError: return jsonify({"error": f"Could not save flow '{filename}' to local filesystem. See terminal for more details."}), 404 @@ -805,6 +823,36 @@ def save_or_rename_flow(filename): return jsonify({"message": f"Flow renamed from {filename} to {new_name}"}) except Exception as error: return jsonify({"error": str(error)}), 404 + + elif data.get('duplicate'): + # Duplicate flow + try: + # Check for name clashes (if a flow already exists with the new name) + copy_name = _get_unique_flow_name(filename, "Copy of ") + # Copy the file to the new (safe) path, and copy metadata too: + shutil.copy2(os.path.join(FLOWS_DIR, filename), os.path.join(FLOWS_DIR, f"{copy_name}.cforge")) + # Return the new filename + return jsonify({"copyName": copy_name}) + except Exception as error: + return jsonify({"error": str(error)}), 404 + +def _get_unique_flow_name(filename: str, prefix: str = None) -> str: + base, ext = os.path.splitext(filename) + if ext is None or len(ext) == 0: + ext = ".cforge" + unique_filename = base + ext + if prefix is not None: + unique_filename = prefix + unique_filename + i = 1 + + # Find the first non-clashing filename of the form <filename>(i).cforge where i=1,2,3 etc + while os.path.isfile(os.path.join(FLOWS_DIR, unique_filename)): + unique_filename = f"{base}({i}){ext}" + if prefix is not None: + unique_filename = prefix + unique_filename + i += 1 + + return unique_filename.replace(".cforge", "") @app.route('/api/getUniqueFlowFilename', methods=['PUT']) def get_unique_flow_name(): @@ -813,25 +861,17 @@ def get_unique_flow_name(): filename = data.get("name") try: - base, ext = os.path.splitext(filename) - if ext is None or len(ext) == 0: - ext = ".cforge" - unique_filename = base + ext - i = 1 - - # Find the first non-clashing filename of the form <filename>(i).cforge where i=1,2,3 etc - while os.path.isfile(os.path.join(FLOWS_DIR, unique_filename)): - unique_filename = f"{base}({i}){ext}" - i += 1 - - return jsonify(unique_filename.replace(".cforge", "")) + new_name = _get_unique_flow_name(filename) + return jsonify(new_name) except Exception as e: return jsonify({"error": str(e)}), 404 -def run_server(host="", port=8000, cmd_args=None): - global HOSTNAME, PORT +def run_server(host="", port=8000, flows_dir=None): + global HOSTNAME, PORT, FLOWS_DIR HOSTNAME = host - PORT = port + PORT = port + if flows_dir: + FLOWS_DIR = flows_dir app.run(host=host, port=port) if __name__ == '__main__': diff --git a/chainforge/react-server/src/App.tsx b/chainforge/react-server/src/App.tsx index ba203e3..9ca353c 100644 --- a/chainforge/react-server/src/App.tsx +++ b/chainforge/react-server/src/App.tsx @@ -6,7 +6,6 @@ import React, { useContext, useMemo, useTransition, - KeyboardEventHandler, KeyboardEvent, } from "react"; import ReactFlow, { Controls, Background, ReactFlowInstance } from "reactflow"; @@ -63,6 +62,7 @@ import { getDefaultModelSettings, } from "./ModelSettingSchemas"; import { v4 as uuid } from "uuid"; +import axios from "axios"; import LZString from "lz-string"; import { EXAMPLEFLOW_1 } from "./example_flows"; @@ -78,7 +78,11 @@ import "lazysizes/plugins/attrchange/ls.attrchange"; import { shallow } from "zustand/shallow"; import useStore, { StoreHandles } from "./store"; import StorageCache, { StringLookup } from "./backend/cache"; -import { APP_IS_RUNNING_LOCALLY, browserTabIsActive } from "./backend/utils"; +import { + APP_IS_RUNNING_LOCALLY, + browserTabIsActive, + FLASK_BASE_URL, +} from "./backend/utils"; import { Dict, JSONCompatible, LLMSpec } from "./backend/typing"; import { ensureUniqueFlowFilename, @@ -113,6 +117,14 @@ const IS_ACCEPTED_BROWSER = // we have access to the Flask backend for, e.g., Python code evaluation. const IS_RUNNING_LOCALLY = APP_IS_RUNNING_LOCALLY(); +const SAVE_FLOW_FILENAME_TO_BROWSER_CACHE = (name: string) => { + console.log("Saving flow filename", name); + // Save the current filename of the user's working flow + StorageCache.saveToLocalStorage("chainforge-cur-file", { + flowFileName: name, + }); +}; + const selector = (state: StoreHandles) => ({ nodes: state.nodes, edges: state.edges, @@ -266,6 +278,11 @@ const App = () => { const safeSetFlowFileName = useCallback(async (newName: string) => { const uniqueName = await ensureUniqueFlowFilename(newName); setFlowFileName(uniqueName); + SAVE_FLOW_FILENAME_TO_BROWSER_CACHE(uniqueName); + }, []); + const setFlowFileNameAndCache = useCallback((newName: string) => { + setFlowFileName(newName); + SAVE_FLOW_FILENAME_TO_BROWSER_CACHE(newName); }, []); // For 'share' button @@ -387,6 +404,7 @@ const App = () => { flowData?: unknown, saveToLocalFilesystem?: string, hideErrorAlert?: boolean, + onError?: () => void, ) => { if (!rfInstance && !flowData) return; @@ -406,11 +424,16 @@ const App = () => { // Save! const flowFile = `${saveToLocalFilesystem ?? flowFileName}.cforge`; if (saveToLocalFilesystem !== undefined) - return saveFlowToLocalFilesystem(flow_and_cache, flowFile); + return saveFlowToLocalFilesystem( + flow_and_cache, + flowFile, + saveToLocalFilesystem !== "__autosave", + ); // @ts-expect-error The exported RF instance is JSON compatible but TypeScript won't read it as such. else downloadJSON(flow_and_cache, flowFile); }) .catch((err) => { + if (onError) onError(); if (hideErrorAlert) console.error(err); else handleError(err); }); @@ -432,14 +455,18 @@ const App = () => { setShowSaveSuccess(false); startSaveTransition(() => { - // NOTE: This currently only saves the front-end state. Cache files - // are not pulled or overwritten upon loading from localStorage. + // Get current flow state const flow = rf.toObject(); - StorageCache.saveToLocalStorage("chainforge-flow", flow); - // Attempt to save the current state of the back-end state, - // the StorageCache. (This does LZ compression to save space.) - StorageCache.saveToLocalStorage("chainforge-state"); + const saveToLocalStorage = () => { + // This line only saves the front-end state. Cache files + // are not pulled or overwritten upon loading from localStorage. + StorageCache.saveToLocalStorage("chainforge-flow", flow); + + // Attempt to save the current back-end state, + // in the StorageCache. (This does LZ compression to save space.) + StorageCache.saveToLocalStorage("chainforge-state"); + }; const onFlowSaved = () => { console.log("Flow saved!"); @@ -452,10 +479,18 @@ const App = () => { // If running locally, aattempt to save a copy of the flow to the lcoal filesystem, // so it shows up in the list of saved flows. if (IS_RUNNING_LOCALLY) - exportFlow(flow, fileName ?? flowFileName, hideErrorAlert)?.then( - onFlowSaved, - ); - else onFlowSaved(); + // SAVE TO LOCAL FILESYSTEM (only), and if that fails, try to save to localStorage + exportFlow( + flow, + fileName ?? flowFileName, + hideErrorAlert, + saveToLocalStorage, + )?.then(onFlowSaved); + else { + // SAVE TO BROWSER LOCALSTORAGE + saveToLocalStorage(); + onFlowSaved(); + } }); }, [rfInstance, exportFlow, flowFileName], @@ -475,8 +510,13 @@ const App = () => { // Initialize auto-saving const initAutosaving = useCallback( - (rf_inst: ReactFlowInstance) => { - if (autosavingInterval !== undefined) return; // autosaving interval already set + (rf_inst: ReactFlowInstance, reinit?: boolean) => { + if (autosavingInterval !== undefined) { + // Autosaving interval already set + if (reinit) + clearInterval(autosavingInterval); // reinitialize interval, clearing the current one + else return; // do nothing + } console.log("Init autosaving"); // Autosave the flow to localStorage every minute: @@ -539,7 +579,9 @@ const App = () => { StorageCache.clear(); // New flow filename - setFlowFileName(`flow-${Date.now()}`); + const new_filename = `flow-${Date.now()}`; + setFlowFileNameAndCache(new_filename); + if (rfInstance) rfInstance.setViewport({ x: 200, y: 80, zoom: 1 }); }, [setNodes, setEdges, resetLLMColors, rfInstance]); @@ -575,7 +617,7 @@ const App = () => { }, 10); // Start auto-saving, if it's not already enabled - if (rf_inst) initAutosaving(rf_inst); + if (rf_inst) initAutosaving(rf_inst, true); }, [resetLLMColors, setNodes, setEdges, initAutosaving], ); @@ -584,23 +626,28 @@ const App = () => { importState(StorageCache.getAllMatching((key) => key.startsWith("r."))); }, [importState]); - const autosavedFlowExists = useCallback(() => { - return window.localStorage.getItem("chainforge-flow") !== null; - }, []); - const loadFlowFromAutosave = useCallback( - async (rf_inst: ReactFlowInstance) => { - const saved_flow = StorageCache.loadFromLocalStorage( - "chainforge-flow", - false, - ) as Dict; - if (saved_flow) { - StorageCache.loadFromLocalStorage("chainforge-state", true); - importGlobalStateFromCache(); - loadFlow(saved_flow, rf_inst); + // Find the autosaved flow, if it exists, returning + // whether it exists and the location ("browser" or "filesystem") that it exists at. + const autosavedFlowExists = useCallback(async () => { + if (IS_RUNNING_LOCALLY) { + // If running locally, we try to fetch a flow autosaved on the user's local machine first: + try { + const response = await axios.get( + `${FLASK_BASE_URL}api/flowExists/__autosave`, + ); + const autosave_file_exists = response.data.exists as boolean; + if (autosave_file_exists) + return { exists: autosave_file_exists, location: "filesystem" }; + } catch (error) { + // Soft fail, continuing onwards to checking localStorage instead } - }, - [importGlobalStateFromCache, loadFlow], - ); + } + + return { + exists: window.localStorage.getItem("chainforge-flow") !== null, + location: "browser", + }; + }, []); // Import data to the cache stored on the local filesystem (in backend) const handleImportCache = useCallback( @@ -715,6 +762,38 @@ const App = () => { fetchOpenAIEval(evalname).then(importFlowFromJSON).catch(handleError); }; + const loadFlowFromAutosave = useCallback( + async (rf_inst: ReactFlowInstance, fromFilesystem?: boolean) => { + if (fromFilesystem) { + // From local filesystem + // Fetch the flow + const response = await axios.get( + `${FLASK_BASE_URL}api/flows/__autosave`, + ); + + // Attempt to load flow into the UI + try { + importFlowFromJSON(response.data, rf_inst); + console.log("Loaded flow from autosave on local machine."); + } catch (error) { + handleError(error as Error); + } + } else { + // From browser localStorage + const saved_flow = StorageCache.loadFromLocalStorage( + "chainforge-flow", + false, + ) as Dict; + if (saved_flow) { + StorageCache.loadFromLocalStorage("chainforge-state", true); + importGlobalStateFromCache(); + loadFlow(saved_flow, rf_inst); + } + } + }, + [importGlobalStateFromCache, loadFlow, importFlowFromJSON, handleError], + ); + // Load flow from examples modal const onSelectExampleFlow = (name: string, example_category?: string) => { // Trigger the 'loading' modal @@ -723,7 +802,7 @@ const App = () => { // Detect a special category of the example flow, and use the right loader for it: if (example_category === "openai-eval") { importFlowFromOpenAIEval(name); - setFlowFileName(`flow-${Date.now()}`); + setFlowFileNameAndCache(`flow-${Date.now()}`); return; } @@ -732,7 +811,7 @@ const App = () => { .then(function (flowJSON) { // We have the data, import it: importFlowFromJSON(flowJSON); - setFlowFileName(`flow-${Date.now()}`); + setFlowFileNameAndCache(`flow-${Date.now()}`); }) .catch(handleError); }; @@ -871,6 +950,20 @@ const App = () => { err.message, ); }); + + // We also need to fetch the current flowFileName + // Attempt to get the last working filename on component mount + const last_working_flow_filename = StorageCache.loadFromLocalStorage( + "chainforge-cur-file", + ); + if ( + last_working_flow_filename && + typeof last_working_flow_filename === "object" && + "flowFileName" in last_working_flow_filename + ) { + // Use last working flow name + setFlowFileName(last_working_flow_filename.flowFileName as string); + } } else { // Check if there's a shared flow UID in the URL as a GET param // If so, we need to look it up in the database and attempt to load it: @@ -910,14 +1003,19 @@ const App = () => { } // Attempt to load an autosaved flow, if one exists: - if (autosavedFlowExists()) loadFlowFromAutosave(rf_inst); - else { - // Load an interesting default starting flow for new users - importFlowFromJSON(EXAMPLEFLOW_1, rf_inst); + autosavedFlowExists().then(({ exists, location }) => { + if (!exists) { + // Load an interesting default starting flow for new users + importFlowFromJSON(EXAMPLEFLOW_1, rf_inst); - // Open a welcome pop-up - // openWelcomeModal(); - } + // Open a welcome pop-up + // openWelcomeModal(); + } else if (location === "browser") { + loadFlowFromAutosave(rf_inst, false); + } else if (location === "filesystem") { + loadFlowFromAutosave(rf_inst, true); + } + }); // Turn off loading wheel setIsLoading(false); @@ -1218,7 +1316,9 @@ const App = () => { <FlowSidebar currentFlow={flowFileName} onLoadFlow={(flowData, name) => { - if (name !== undefined) setFlowFileName(name); + if (name !== undefined) { + setFlowFileNameAndCache(name); + } if (flowData !== undefined) { try { importFlowFromJSON(flowData); @@ -1231,7 +1331,7 @@ const App = () => { }} /> ); - }, [flowFileName, importFlowFromJSON, showAlert]); + }, [flowFileName, importFlowFromJSON, showAlert, setFlowFileNameAndCache]); if (!IS_ACCEPTED_BROWSER) { return ( @@ -1334,6 +1434,7 @@ const App = () => { ml="sm" size="1.625rem" onClick={() => saveFlow()} + bg="#eee" loading={isSaving} disabled={isLoading || isSaving} > diff --git a/chainforge/react-server/src/AreYouSureModal.tsx b/chainforge/react-server/src/AreYouSureModal.tsx index b0ac59c..9578175 100644 --- a/chainforge/react-server/src/AreYouSureModal.tsx +++ b/chainforge/react-server/src/AreYouSureModal.tsx @@ -5,6 +5,7 @@ import { useDisclosure } from "@mantine/hooks"; export interface AreYouSureModalProps { title: string; message: string; + color?: string; onConfirm?: () => void; } @@ -14,7 +15,7 @@ export interface AreYouSureModalRef { /** Modal that lets user rename a single value, using a TextInput field. */ const AreYouSureModal = forwardRef<AreYouSureModalRef, AreYouSureModalProps>( - function AreYouSureModal({ title, message, onConfirm }, ref) { + function AreYouSureModal({ title, message, color, onConfirm }, ref) { const [opened, { open, close }] = useDisclosure(false); const description = message || "Are you sure?"; @@ -37,7 +38,7 @@ const AreYouSureModal = forwardRef<AreYouSureModalRef, AreYouSureModalProps>( onClose={close} title={title} styles={{ - header: { backgroundColor: "orange", color: "white" }, + header: { backgroundColor: color ?? "orange", color: "white" }, root: { position: "relative", left: "-5%" }, }} > @@ -54,7 +55,7 @@ const AreYouSureModal = forwardRef<AreYouSureModalRef, AreYouSureModalProps>( > <Button variant="light" - color="orange" + color={color ?? "orange"} type="submit" w="40%" onClick={close} diff --git a/chainforge/react-server/src/FlowSidebar.tsx b/chainforge/react-server/src/FlowSidebar.tsx index c7566bc..e9612b3 100644 --- a/chainforge/react-server/src/FlowSidebar.tsx +++ b/chainforge/react-server/src/FlowSidebar.tsx @@ -5,6 +5,7 @@ import { IconMenu2, IconX, IconCheck, + IconCopy, } from "@tabler/icons-react"; import axios from "axios"; import { AlertModalContext } from "./AlertModal"; @@ -20,6 +21,7 @@ import { Flex, Divider, ScrollArea, + Tooltip, } from "@mantine/core"; import { FLASK_BASE_URL } from "./backend/utils"; @@ -112,6 +114,26 @@ const FlowSidebar: React.FC<FlowSidebarProps> = ({ setNewEditName(flowFile); }; + // 'Duplicate' the flow + const handleDuplicateFlow = async ( + flowFile: string, + event: React.MouseEvent<HTMLButtonElement, MouseEvent>, + ) => { + event.stopPropagation(); // Prevent triggering the parent click + await axios + .put(`${FLASK_BASE_URL}api/flows/${flowFile}`, { + duplicate: true, + }) + .then((resp) => { + onLoadFlow(undefined, resp.data.copyName as string); // Tell the parent that the filename has changed. This won't replace the flow. + fetchSavedFlowList(); // Refresh the list + }) + .catch((err) => { + console.error(err); + if (showAlert) showAlert(err); + }); + }; + // Cancel editing const handleCancelEdit = ( event: React.MouseEvent<HTMLButtonElement, MouseEvent>, @@ -191,7 +213,7 @@ const FlowSidebar: React.FC<FlowSidebarProps> = ({ onClose={() => setIsOpen(false)} title="Saved Flows" position="left" - size="250px" // Adjust sidebar width + size="350px" // Adjust sidebar width padding="md" withCloseButton={true} scrollAreaComponent={ScrollArea.Autosize} @@ -261,18 +283,45 @@ const FlowSidebar: React.FC<FlowSidebarProps> = ({ {flow.name} </Text> <Flex gap="0px"> - <ActionIcon - color="blue" - onClick={(e) => handleEditClick(flow.name, e)} + <Tooltip + label="Edit name" + withArrow + arrowPosition="center" + withinPortal > - <IconEdit size={18} /> - </ActionIcon> - <ActionIcon - color="red" - onClick={(e) => handleDeleteFlow(flow.name, e)} + <ActionIcon + color="blue" + onClick={(e) => handleEditClick(flow.name, e)} + > + <IconEdit size={18} /> + </ActionIcon> + </Tooltip> + <Tooltip + label="Duplicate this flow" + withArrow + arrowPosition="center" + withinPortal > - <IconTrash size={18} /> - </ActionIcon> + <ActionIcon + color="blue" + onClick={(e) => handleDuplicateFlow(flow.name, e)} + > + <IconCopy size={18} /> + </ActionIcon> + </Tooltip> + <Tooltip + label="Delete this flow" + withArrow + arrowPosition="center" + withinPortal + > + <ActionIcon + color="red" + onClick={(e) => handleDeleteFlow(flow.name, e)} + > + <IconTrash size={18} /> + </ActionIcon> + </Tooltip> </Flex> </Flex> <Text size="xs" color="gray"> diff --git a/chainforge/react-server/src/ItemsNode.tsx b/chainforge/react-server/src/ItemsNode.tsx index 16e6b3d..0f0198e 100644 --- a/chainforge/react-server/src/ItemsNode.tsx +++ b/chainforge/react-server/src/ItemsNode.tsx @@ -55,7 +55,7 @@ const ItemsNode: React.FC<ItemsNodeProps> = ({ data, id }) => { const flags = useStore((state) => state.flags); const [contentDiv, setContentDiv] = useState<React.ReactNode | null>(null); - const [isEditing, setIsEditing] = useState(true); + const [isEditing, setIsEditing] = useState(false); const [csvInput, setCsvInput] = useState<React.ReactNode | null>(null); const [countText, setCountText] = useState<React.ReactNode | null>(null); diff --git a/chainforge/react-server/src/LLMListComponent.tsx b/chainforge/react-server/src/LLMListComponent.tsx index 900870f..f111443 100644 --- a/chainforge/react-server/src/LLMListComponent.tsx +++ b/chainforge/react-server/src/LLMListComponent.tsx @@ -23,39 +23,18 @@ import { StrictModeDroppable } from "./StrictModeDroppable"; import ModelSettingsModal, { ModelSettingsModalRef, } from "./ModelSettingsModal"; -import { - getDefaultModelFormData, - getDefaultModelSettings, -} from "./ModelSettingSchemas"; +import { getDefaultModelSettings } from "./ModelSettingSchemas"; import useStore, { initLLMProviders, initLLMProviderMenu } from "./store"; import { Dict, JSONCompatible, LLMGroup, LLMSpec } from "./backend/typing"; import { useContextMenu } from "mantine-contextmenu"; import { ContextMenuItemOptions } from "mantine-contextmenu/dist/types"; +import { ensureUniqueName } from "./backend/utils"; // The LLM(s) to include by default on a PromptNode whenever one is created. // Defaults to ChatGPT (GPT3.5) when running locally, and HF-hosted falcon-7b for online version since it's free. const DEFAULT_INIT_LLMS = [initLLMProviders[0]]; // Helper funcs -// Ensure that a name is 'unique'; if not, return an amended version with a count tacked on (e.g. "GPT-4 (2)") -const ensureUniqueName = (_name: string, _prev_names: string[]) => { - // Strip whitespace around names - const prev_names = _prev_names.map((n) => n.trim()); - const name = _name.trim(); - - // Check if name is unique - if (!prev_names.includes(name)) return name; - - // Name isn't unique; find a unique one: - let i = 2; - let new_name = `${name} (${i})`; - while (prev_names.includes(new_name)) { - i += 1; - new_name = `${name} (${i})`; - } - return new_name; -}; - /** Get position CSS style below and left-aligned to the input element */ const getPositionCSSStyle = ( elem: HTMLButtonElement, diff --git a/chainforge/react-server/src/LLMResponseInspectorModal.tsx b/chainforge/react-server/src/LLMResponseInspectorModal.tsx index bfc3b3d..54d6f00 100644 --- a/chainforge/react-server/src/LLMResponseInspectorModal.tsx +++ b/chainforge/react-server/src/LLMResponseInspectorModal.tsx @@ -71,7 +71,10 @@ const LLMResponseInspectorModal = forwardRef< </button> </div> } - styles={{ title: { justifyContent: "space-between", width: "100%" } }} + styles={{ + title: { justifyContent: "space-between", width: "100%" }, + header: { paddingBottom: "0px" }, + }} > <div className="inspect-modal-response-container" diff --git a/chainforge/react-server/src/MultiEvalNode.tsx b/chainforge/react-server/src/MultiEvalNode.tsx index 10db55e..9f15cea 100644 --- a/chainforge/react-server/src/MultiEvalNode.tsx +++ b/chainforge/react-server/src/MultiEvalNode.tsx @@ -221,10 +221,10 @@ const MultiEvalNode: React.FC<MultiEvalNodeProps> = ({ data, id }) => { const bringNodeToFront = useStore((state) => state.bringNodeToFront); const inputEdgesForNode = useStore((state) => state.inputEdgesForNode); - const flags = useStore((state) => state.flags); - const AI_SUPPORT_ENABLED = useMemo(() => { - return flags.aiSupport; - }, [flags]); + // const flags = useStore((state) => state.flags); + // const AI_SUPPORT_ENABLED = useMemo(() => { + // return flags.aiSupport; + // }, [flags]); const [status, setStatus] = useState<Status>(Status.NONE); // For displaying error messages to user diff --git a/chainforge/react-server/src/PromptNode.tsx b/chainforge/react-server/src/PromptNode.tsx index 93af5a8..ce811a0 100644 --- a/chainforge/react-server/src/PromptNode.tsx +++ b/chainforge/react-server/src/PromptNode.tsx @@ -18,9 +18,20 @@ import { Modal, Box, Tooltip, + Flex, + Button, + ActionIcon, + Divider, } from "@mantine/core"; import { useDisclosure } from "@mantine/hooks"; -import { IconEraser, IconList } from "@tabler/icons-react"; +import { + IconArrowLeft, + IconArrowRight, + IconEraser, + IconList, + IconPlus, + IconTrash, +} from "@tabler/icons-react"; import useStore from "./store"; import BaseNode from "./BaseNode"; import NodeLabel from "./NodeLabelComponent"; @@ -41,6 +52,7 @@ import { extractSettingsVars, truncStr, genDebounceFunc, + ensureUniqueName, } from "./backend/utils"; import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer"; import CancelTracker from "./backend/canceler"; @@ -64,6 +76,8 @@ import { queryLLM, } from "./backend/backend"; import { StringLookup } from "./backend/cache"; +import { union } from "./backend/setUtils"; +import AreYouSureModal, { AreYouSureModalRef } from "./AreYouSureModal"; const getUniqueLLMMetavarKey = (responses: LLMResponse[]) => { const metakeys = new Set( @@ -82,22 +96,44 @@ const bucketChatHistoryInfosByLLM = (chat_hist_infos: ChatHistoryInfo[]) => { }); return chats_by_llm; }; +const getRootPromptFor = ( + promptTexts: string | string[], + varNameForRootTemplate: string, +) => { + if (typeof promptTexts === "string") return promptTexts; + else if (promptTexts.length === 1) return promptTexts[0]; + else return `{${varNameForRootTemplate}}`; +}; export class PromptInfo { prompt: string; - settings: Dict; + settings?: Dict; + label?: string; - constructor(prompt: string, settings: Dict) { + constructor(prompt: string, settings?: Dict, label?: string) { this.prompt = prompt; this.settings = settings; + this.label = label; } } -const displayPromptInfos = (promptInfos: PromptInfo[], wideFormat: boolean) => +const displayPromptInfos = ( + promptInfos: PromptInfo[], + wideFormat: boolean, + bgColor?: string, +) => promptInfos.map((info, idx) => ( <div key={idx}> - <div className="prompt-preview">{info.prompt}</div> - {info.settings ? ( + <div className="prompt-preview" style={{ backgroundColor: bgColor }}> + {info.label && ( + <Text c="black" size="xs" fw="bold" mb={0}> + {info.label} + <hr /> + </Text> + )} + {info.prompt} + </div> + {info.settings && Object.entries(info.settings).map(([key, val]) => { return ( <div key={key} className="settings-var-inline response-var-inline"> @@ -107,10 +143,7 @@ const displayPromptInfos = (promptInfos: PromptInfo[], wideFormat: boolean) => </span> </div> ); - }) - ) : ( - <></> - )} + })} </div> )); @@ -118,12 +151,14 @@ export interface PromptListPopoverProps { promptInfos: PromptInfo[]; onHover: () => void; onClick: () => void; + promptTemplates?: string[] | string; } export const PromptListPopover: React.FC<PromptListPopoverProps> = ({ promptInfos, onHover, onClick, + promptTemplates, }) => { const [opened, { close, open }] = useDisclosure(false); @@ -172,6 +207,29 @@ export const PromptListPopover: React.FC<PromptListPopoverProps> = ({ Preview of generated prompts ({promptInfos.length} total) </Text> </Center> + {Array.isArray(promptTemplates) && promptTemplates.length > 1 && ( + <Box> + <Divider + my="xs" + label="Prompt variants" + fw="bold" + labelPosition="center" + /> + {displayPromptInfos( + promptTemplates.map( + (t, i) => new PromptInfo(t, undefined, `Variant ${i + 1}`), + ), + false, + "#ddf1f8", + )} + <Divider + my="xs" + label="Concrete prompts" + fw="bold" + labelPosition="center" + /> + </Box> + )} {displayPromptInfos(promptInfos, false)} </Popover.Dropdown> </Popover> @@ -182,12 +240,14 @@ export interface PromptListModalProps { promptPreviews: PromptInfo[]; infoModalOpened: boolean; closeInfoModal: () => void; + promptTemplates?: string[] | string; } export const PromptListModal: React.FC<PromptListModalProps> = ({ promptPreviews, infoModalOpened, closeInfoModal, + promptTemplates, }) => { return ( <Modal @@ -205,6 +265,29 @@ export const PromptListModal: React.FC<PromptListModalProps> = ({ }} > <Box m="lg" mt="xl"> + {Array.isArray(promptTemplates) && promptTemplates.length > 1 && ( + <Box> + <Divider + my="xs" + label="Prompt variants" + fw="bold" + labelPosition="center" + /> + {displayPromptInfos( + promptTemplates.map( + (t, i) => new PromptInfo(t, undefined, `Variant ${i + 1}`), + ), + true, + "#ddf1f8", + )} + <Divider + my="xs" + label="Concrete prompts (filled in)" + fw="bold" + labelPosition="center" + /> + </Box> + )} {displayPromptInfos(promptPreviews, true)} </Box> </Modal> @@ -221,6 +304,7 @@ export interface PromptNodeProps { contChat: boolean; refresh: boolean; refreshLLMList: boolean; + idxPromptVariantShown?: number; }; id: string; type: string; @@ -257,10 +341,15 @@ const PromptNode: React.FC<PromptNodeProps> = ({ null, ); const [templateVars, setTemplateVars] = useState<string[]>(data.vars ?? []); - const [promptText, setPromptText] = useState<string>(data.prompt ?? ""); - const [promptTextOnLastRun, setPromptTextOnLastRun] = useState<string | null>( - null, + const [promptText, setPromptText] = useState<string | string[]>( + data.prompt ?? "", ); + const [idxPromptVariantShown, setIdxPromptVariantShown] = useState<number>( + data.idxPromptVariantShown ?? 0, + ); + const [promptTextOnLastRun, setPromptTextOnLastRun] = useState< + string | string[] | null + >(null); const [status, setStatus] = useState(Status.NONE); const [numGenerations, setNumGenerations] = useState<number>(data.n ?? 1); const [numGenerationsLastRun, setNumGenerationsLastRun] = useState<number>( @@ -391,10 +480,17 @@ const PromptNode: React.FC<PromptNodeProps> = ({ }, [templateVars, id, pullInputData, updateShowContToggle]); const refreshTemplateHooks = useCallback( - (text: string) => { - // Update template var fields + handles - const found_template_vars = new Set(extractBracketedSubstrings(text)); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this} + (text: string | string[]) => { + const texts = typeof text === "string" ? [text] : text; + // Get all template vars in the prompt(s) + let found_template_vars = new Set<string>(); + for (const t of texts) { + const substrs = extractBracketedSubstrings(t); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this} + found_template_vars = union(found_template_vars, new Set(substrs)); + } + + // Update template var fields + handles if (!setsAreEqual(found_template_vars, new Set(templateVars))) { if (node_type !== "chat") { try { @@ -413,27 +509,29 @@ const PromptNode: React.FC<PromptNodeProps> = ({ const handleInputChange = useCallback( (event: React.ChangeEvent<HTMLTextAreaElement>) => { - const value = event.target.value; + const value = event.target.value as string; const updateStatus = promptTextOnLastRun !== null && status !== Status.WARNING && value !== promptTextOnLastRun; - // Store prompt text - data.prompt = value; - // Debounce the global state change to happen only after 500ms, as it forces a costly rerender: - debounce((_value, _updateStatus) => { - setPromptText(_value); - setDataPropsForNode(id, { prompt: _value }); - refreshTemplateHooks(_value); + debounce((_value: string, _updateStatus, _idxPromptVariantShown) => { + setPromptText((prompts) => { + if (typeof prompts === "string") prompts = _value; + else prompts[_idxPromptVariantShown] = _value; + setDataPropsForNode(id, { prompt: prompts }); + refreshTemplateHooks(prompts); + return prompts; + }); if (_updateStatus) setStatus(Status.WARNING); - }, 300)(value, updateStatus); + }, 300)(value, updateStatus, idxPromptVariantShown); // Debounce refreshing the template hooks so we don't annoy the user // debounce((_value) => refreshTemplateHooks(_value), 500)(value); }, [ + idxPromptVariantShown, promptTextOnLastRun, status, refreshTemplateHooks, @@ -552,7 +650,7 @@ const PromptNode: React.FC<PromptNodeProps> = ({ // Ask the backend how many responses it needs to collect, given the input data: const fetchResponseCounts = useCallback( ( - prompt: string, + prompt: string | string[], vars: Dict, llms: (StringOrHash | LLMSpec)[], chat_histories?: @@ -592,14 +690,24 @@ const PromptNode: React.FC<PromptNodeProps> = ({ const pulled_vars = pullInputData(templateVars, id); updateShowContToggle(pulled_vars); - generatePrompts(promptText, pulled_vars).then((prompts) => { - setPromptPreviews( - prompts.map( - (p: PromptTemplate) => - new PromptInfo(p.toString(), extractSettingsVars(p.fill_history)), - ), - ); - }); + const prompts = + typeof promptText === "string" ? [promptText] : promptText; + + Promise.all(prompts.map((p) => generatePrompts(p, pulled_vars))).then( + (results) => { + // Handle all the results here + const all_concrete_prompts = results.flatMap((ps) => + ps.map( + (p: PromptTemplate) => + new PromptInfo( + p.toString(), + extractSettingsVars(p.fill_history), + ), + ), + ); + setPromptPreviews(all_concrete_prompts); + }, + ); pullInputChats(); } catch (err) { @@ -827,9 +935,18 @@ Soft failing by replacing undefined with empty strings.`, // Pull the data to fill in template input variables, if any let pulled_data: Dict<(string | TemplateVarInfo)[]> = {}; + let var_for_prompt_templates: string; try { // Try to pull inputs pulled_data = pullInputData(templateVars, id); + + // Add a special new variable for the root prompt template(s) + var_for_prompt_templates = ensureUniqueName( + "prompt", + Object.keys(pulled_data), + ); + if (typeof promptText !== "string" && promptText.length > 1) + pulled_data[var_for_prompt_templates] = promptText; // this will be filled in when calling queryLLMs } catch (err) { if (showAlert) showAlert((err as Error)?.message ?? err); console.error(err); @@ -873,7 +990,9 @@ Soft failing by replacing undefined with empty strings.`, // Fetch info about the number of queries we'll need to make const fetch_resp_count = () => fetchResponseCounts( - prompt_template, + typeof prompt_template === "string" + ? prompt_template + : `{${var_for_prompt_templates}}`, // Use special root prompt if there's multiple prompt variants pulled_data, _llmItemsCurrState, pulled_chats as ChatHistoryInfo[], @@ -951,9 +1070,9 @@ Soft failing by replacing undefined with empty strings.`, const query_llms = () => { return queryLLM( id, - _llmItemsCurrState, // deep clone it first + _llmItemsCurrState, numGenerations, - prompt_template, + getRootPromptFor(prompt_template, var_for_prompt_templates), // Use special root prompt if there's multiple prompt variants pulled_data, chat_hist_by_llm, apiKeys || {}, @@ -1015,7 +1134,7 @@ Soft failing by replacing undefined with empty strings.`, o.metavars = resp_obj.metavars ?? {}; // Add a metavar for the prompt *template* in this PromptNode - o.metavars.__pt = prompt_template; + // o.metavars.__pt = prompt_template; // Carry over any chat history if (resp_obj.chat_history) @@ -1162,6 +1281,16 @@ Soft failing by replacing undefined with empty strings.`, // Dynamically update the textareas and position of the template hooks const textAreaRef = useRef<HTMLTextAreaElement | HTMLDivElement | null>(null); + const resizeTextarea = () => { + const textarea = textAreaRef.current; + + if (textarea) { + textarea.style.height = "auto"; // Reset height to shrink if needed + const newHeight = Math.min(textarea.scrollHeight, 600); + textarea.style.height = `${newHeight}px`; + } + }; + const [hooksY, setHooksY] = useState(138); const setRef = useCallback( (elem: HTMLDivElement | HTMLTextAreaElement | null) => { @@ -1188,6 +1317,147 @@ Soft failing by replacing undefined with empty strings.`, [textAreaRef], ); + const deleteVariantConfirmModal = useRef<AreYouSureModalRef>(null); + const handleAddPromptVariant = useCallback(() => { + // Pushes a new prompt variant, updating the prompts list and duplicating the current shown prompt + const prompts = typeof promptText === "string" ? [promptText] : promptText; + const curIdx = Math.max( + 0, + Math.min(prompts.length - 1, idxPromptVariantShown), + ); // clamp + const curShownPrompt = prompts[curIdx]; + setPromptText(prompts.concat([curShownPrompt])); + setIdxPromptVariantShown(prompts.length); + }, [promptText, idxPromptVariantShown]); + + const gotoPromptVariant = useCallback( + (shift: number) => { + const prompts = + typeof promptText === "string" ? [promptText] : promptText; + const newIdx = Math.max( + 0, + Math.min(prompts.length - 1, idxPromptVariantShown + shift), + ); // clamp + setIdxPromptVariantShown(newIdx); + resizeTextarea(); + }, + [promptText, idxPromptVariantShown], + ); + + const handleRemovePromptVariant = useCallback(() => { + setPromptText((prompts) => { + if (typeof prompts === "string" || prompts.length === 1) return prompts; // cannot remove the last one + prompts.splice(idxPromptVariantShown, 1); // remove the indexed variant + const newIdx = Math.max(0, idxPromptVariantShown - 1); + setIdxPromptVariantShown(newIdx); // goto the previous variant, if possible + + if (textAreaRef.current) { + // We have to force an update here since idxPromptVariantShown might've not changed + // @ts-expect-error Mantine has a 'value' property on Textareas, but TypeScript doesn't know this + textAreaRef.current.value = prompts[newIdx]; + resizeTextarea(); + } + + return [...prompts]; + }); + }, [idxPromptVariantShown, textAreaRef]); + + // Whenever idx of prompt variant changes, we need to refresh the Textarea: + useEffect(() => { + if (textAreaRef.current && Array.isArray(promptText)) { + // @ts-expect-error Mantine has a 'value' property on Textareas, but TypeScript doesn't know this + textAreaRef.current.value = promptText[idxPromptVariantShown]; + resizeTextarea(); + } + }, [idxPromptVariantShown]); + + const promptVariantControls = useMemo(() => { + return ( + <Flex justify="right" pos="absolute" right={10}> + {typeof promptText === "string" || promptText.length === 1 ? ( + <Tooltip + label="Add prompt variant. This duplicates the current prompt, allowing you to tweak it to test variations. (You can also accomplish the same thing by template chaining.)" + multiline + position="right" + withArrow + arrowSize={8} + w={220} + withinPortal + > + <Button + size="xs" + variant="subtle" + color="gray" + mt={3} + mr={3} + p={0} + fw="normal" + h="1.0rem" + onClick={handleAddPromptVariant} + > + + add variant + </Button> + </Tooltip> + ) : ( + <> + <ActionIcon + size="xs" + c="black" + onClick={() => gotoPromptVariant(-1)} + > + <IconArrowLeft size={19} /> + </ActionIcon> + + <Text size="xs"> + Variant {idxPromptVariantShown + 1} of{" "} + {typeof promptText === "string" ? 1 : promptText.length} + </Text> + + <ActionIcon + size="xs" + c="black" + mr={3} + onClick={() => gotoPromptVariant(1)} + > + <IconArrowRight size={19} /> + </ActionIcon> + + <Tooltip + label="Add prompt variant" + position="right" + withArrow + withinPortal + > + <ActionIcon + size="xs" + c="black" + mr={2} + onClick={handleAddPromptVariant} + > + <IconPlus size={19} /> + </ActionIcon> + </Tooltip> + + <Tooltip + label="Remove this variant" + position="right" + withArrow + withinPortal + > + <ActionIcon + size="xs" + c="black" + onClick={() => deleteVariantConfirmModal?.current?.trigger()} + > + <IconTrash size={19} /> + </ActionIcon> + </Tooltip> + </> + )} + </Flex> + ); + }, [idxPromptVariantShown, promptText, deleteVariantConfirmModal]); + // Add custom context menu options on right-click. // 1. Convert TextFields to Items Node, for convenience. const customContextMenuItems = useMemo( @@ -1229,6 +1499,7 @@ Soft failing by replacing undefined with empty strings.`, <PromptListPopover key="prompt-previews" promptInfos={promptPreviews} + promptTemplates={promptText} onHover={handlePreviewHover} onClick={openInfoModal} />, @@ -1240,9 +1511,17 @@ Soft failing by replacing undefined with empty strings.`, /> <PromptListModal promptPreviews={promptPreviews} + promptTemplates={promptText} infoModalOpened={infoModalOpened} closeInfoModal={closeInfoModal} /> + <AreYouSureModal + ref={deleteVariantConfirmModal} + title="Delete prompt variant" + message="Are you sure you want to delete this prompt variant? This action is irreversible." + color="red" + onConfirm={handleRemovePromptVariant} + /> {node_type === "chat" ? ( <div ref={setRef}> @@ -1254,7 +1533,12 @@ Soft failing by replacing undefined with empty strings.`, key={0} className="prompt-field-fixed nodrag nowheel" minRows={4} - defaultValue={data.prompt} + defaultValue={ + typeof data.prompt === "string" + ? data.prompt + : data.prompt && + data.prompt[data.idxPromptVariantShown ?? 0] + } onChange={handleInputChange} miw={230} styles={{ @@ -1273,15 +1557,22 @@ Soft failing by replacing undefined with empty strings.`, ) : ( <Textarea ref={setRef} - autosize + // autosize className="prompt-field-fixed nodrag nowheel" - minRows={4} + minRows={5} maxRows={12} - defaultValue={data.prompt} + defaultValue={ + typeof data.prompt === "string" + ? data.prompt + : data.prompt && data.prompt[data.idxPromptVariantShown ?? 0] + } onChange={handleInputChange} + // value={typeof promptText === "string" ? promptText : promptText[idxPromptVariantShown]} /> )} + {promptVariantControls} + <Handle type="source" position={Position.Right} @@ -1289,13 +1580,17 @@ Soft failing by replacing undefined with empty strings.`, className="grouped-handle" style={{ top: "50%" }} /> - <TemplateHooks - vars={templateVars} - nodeId={id} - startY={hooksY} - position={Position.Left} - ignoreHandles={["__past_chats"]} - /> + + <Box mih={14}> + <TemplateHooks + vars={templateVars} + nodeId={id} + startY={hooksY} + position={Position.Left} + ignoreHandles={["__past_chats"]} + /> + </Box> + <hr /> <div> <div style={{ marginBottom: "10px", padding: "4px" }}> diff --git a/chainforge/react-server/src/backend/backend.ts b/chainforge/react-server/src/backend/backend.ts index 55e3775..988c837 100644 --- a/chainforge/react-server/src/backend/backend.ts +++ b/chainforge/react-server/src/backend/backend.ts @@ -29,6 +29,8 @@ import { repairCachedResponses, deepcopy, llmResponseDataToString, + extendArray, + extendArrayDict, } from "./utils"; import StorageCache, { StringLookup } from "./cache"; import { PromptPipeline } from "./query"; @@ -520,7 +522,7 @@ export async function generatePrompts( /** * Calculates how many queries we need to make, given the passed prompt and vars. * - * @param prompt the prompt template, with any {{}} vars + * @param prompt the prompt template, with any {} vars; or alternatively, an array of such templates * @param vars a dict of the template variables to fill the prompt template with, by name. * For each var value, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.) * @param llms the list of LLMs you will query @@ -531,7 +533,7 @@ export async function generatePrompts( * If there was an error, returns a dict with a single key, 'error'. */ export async function countQueries( - prompt: string, + prompt: string | string[], vars: PromptVarsDict, llms: Array<StringOrHash | LLMSpec>, n: number, @@ -545,19 +547,27 @@ export async function countQueries( vars = deepcopy(vars); llms = deepcopy(llms); - let all_prompt_permutations: PromptTemplate[] | Dict<PromptTemplate[]>; + const prompt_templates = typeof prompt === "string" ? [prompt] : prompt; + const all_prompt_permutations: PromptTemplate[] | Dict<PromptTemplate[]> = + cont_only_w_prior_llms && Array.isArray(llms) ? {} : []; - const gen_prompts = new PromptPermutationGenerator(prompt); - if (cont_only_w_prior_llms && Array.isArray(llms)) { - all_prompt_permutations = {}; - llms.forEach((llm_spec) => { - const llm_key = extract_llm_key(llm_spec); - (all_prompt_permutations as Dict<PromptTemplate[]>)[llm_key] = Array.from( - gen_prompts.generate(filterVarsByLLM(vars, llm_key)), + for (const pt of prompt_templates) { + const gen_prompts = new PromptPermutationGenerator(pt); + if (cont_only_w_prior_llms && Array.isArray(llms)) { + llms.forEach((llm_spec) => { + const llm_key = extract_llm_key(llm_spec); + extendArrayDict( + all_prompt_permutations as Dict<PromptTemplate[]>, + llm_key, + Array.from(gen_prompts.generate(filterVarsByLLM(vars, llm_key))), + ); + }); + } else { + extendArray( + all_prompt_permutations as PromptTemplate[], + Array.from(gen_prompts.generate(vars)), ); - }); - } else { - all_prompt_permutations = Array.from(gen_prompts.generate(vars)); + } } let cache_file_lookup: Dict = {}; @@ -701,10 +711,12 @@ export async function fetchEnvironAPIKeys(): Promise<Dict<string>> { export async function saveFlowToLocalFilesystem( flowJSON: Dict, filename: string, + alsoAutosave: boolean, ): Promise<void> { try { await axios.put(`${FLASK_BASE_URL}api/flows/${filename}`, { flow: flowJSON, + alsoAutosave: alsoAutosave, }); } catch (error) { throw new Error( @@ -739,7 +751,7 @@ export async function ensureUniqueFlowFilename( * @param id a unique ID to refer to this information. Used when cache'ing responses. * @param llm a string, list of strings, or list of LLM spec dicts specifying the LLM(s) to query. * @param n the amount of generations for each prompt. All LLMs will be queried the same number of times 'n' per each prompt. - * @param prompt the prompt template, with any {{}} vars + * @param prompt the prompt template, with any {} vars * @param vars a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.) * @param chat_histories Either an array of `ChatHistory` (to use across all LLMs), or a dict indexed by LLM nicknames of `ChatHistory` arrays to use per LLM. diff --git a/chainforge/react-server/src/backend/models.ts b/chainforge/react-server/src/backend/models.ts index 09110f6..7f13863 100644 --- a/chainforge/react-server/src/backend/models.ts +++ b/chainforge/react-server/src/backend/models.ts @@ -313,6 +313,8 @@ export const RATE_LIMIT_BY_MODEL: { [key in LLM]?: number } = { }; export const RATE_LIMIT_BY_PROVIDER: { [key in LLMProvider]?: number } = { + [LLMProvider.OpenAI]: 1000, // Tier 3 pricing limit is 5000 per minute, across most models, we use 1000 to be safe. + [LLMProvider.Azure_OpenAI]: 1000, // Tier 3 pricing limit is 5000 per minute, across most models, we use 1000 to be safe. [LLMProvider.Anthropic]: 25, // Tier 1 pricing limit is 50 per minute, across all models; we halve this, to be safe. [LLMProvider.Together]: 30, // Paid tier limit is 60 per minute, across all models; we halve this, to be safe. [LLMProvider.Google]: 1000, // RPM for Google Gemini models 1.5 is quite generous; at base it is 1000 RPM. If you are using the free version it's 15 RPM, but we can expect most CF users to be using paid (and anyway you can just re-run prompt node until satisfied). diff --git a/chainforge/react-server/src/backend/utils.ts b/chainforge/react-server/src/backend/utils.ts index 55f9ff3..ce55d60 100644 --- a/chainforge/react-server/src/backend/utils.ts +++ b/chainforge/react-server/src/backend/utils.ts @@ -1967,10 +1967,12 @@ export const extractSettingsVars = (vars?: PromptVarsDict) => { vars !== undefined && Object.keys(vars).some((k) => k.charAt(0) === "=") ) { - return transformDict( - deepcopy(vars), - (k) => k.charAt(0) === "=", - (k) => k.substring(1), + return StringLookup.concretizeDict( + transformDict( + deepcopy(vars), + (k) => k.charAt(0) === "=", + (k) => k.substring(1), + ), ); } else return {}; }; @@ -2398,3 +2400,52 @@ export const compressBase64Image = (b64: string): Promise<string> => { ) .then((compressedBlob) => blobToBase64(compressedBlob as Blob)); }; + +/** + * Extends array `a` with the values of `b`. + * @param a The array to extend (in-place). + * @param b The array to add to the end of `a`. + * @returns `a`, extended. + */ +export const extendArray = <T>(a: Array<T>, b: Array<T>): Array<T> => { + for (const i in b) { + a.push(b[i]); + } + return a; +}; + +/** + * Extends the array `key` in a dict with `values`, creating a new array if the key is missing. + * @param dict The dictionary to extend (in-place). + * @param key The key of the dictionary. + * @param values The new array to append to the end of the dict value for `key`. + */ +export const extendArrayDict = <K extends string | number | symbol, V>( + dict: Record<K, V[]>, + key: K, + values: V[], +): void => { + if (!dict[key]) { + dict[key] = []; + } + extendArray(dict[key], values); +}; + +/** Ensure that a name is 'unique'; if not, return an amended version with a count tacked on (e.g. "GPT-4 (2)") */ +export const ensureUniqueName = (_name: string, _prev_names: string[]) => { + // Strip whitespace around names + const prev_names = _prev_names.map((n) => n.trim()); + const name = _name.trim(); + + // Check if name is unique + if (!prev_names.includes(name)) return name; + + // Name isn't unique; find a unique one: + let i = 2; + let new_name = `${name} (${i})`; + while (prev_names.includes(new_name)) { + i += 1; + new_name = `${name} (${i})`; + } + return new_name; +}; diff --git a/setup.py b/setup.py index b6c42e5..6648757 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ def readme(): setup( name="chainforge", - version="0.3.4.3", + version="0.3.4.4", packages=find_packages(), author="Ian Arawjo", description="A Visual Programming Environment for Prompt Engineering",