From ce282459ca7955667dc940b1c21f2537a2d67a31 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Mon, 11 Mar 2024 10:34:43 -0400 Subject: [PATCH] Turn AlertModal into a Provider with useContext --- chainforge/flask_app.py | 4 +- chainforge/react-server/src/AiPopover.tsx | 20 ++- chainforge/react-server/src/AlertModal.tsx | 95 +++++++++---- chainforge/react-server/src/App.tsx | 126 +++++++++++------- .../react-server/src/CodeEvaluatorNode.tsx | 10 +- .../react-server/src/GlobalSettingsModal.tsx | 56 ++++---- chainforge/react-server/src/LLMEvalNode.js | 9 +- .../react-server/src/NodeLabelComponent.tsx | 5 +- chainforge/react-server/src/PromptNode.tsx | 24 ++-- chainforge/react-server/src/SimpleEvalNode.js | 16 ++- .../react-server/src/TabularDataNode.tsx | 16 ++- .../react-server/src/backend/backend.ts | 97 +++++++++----- chainforge/react-server/src/backend/cache.ts | 5 +- chainforge/react-server/src/store.tsx | 4 +- 14 files changed, 301 insertions(+), 186 deletions(-) diff --git a/chainforge/flask_app.py b/chainforge/flask_app.py index 81aa29e..fb2217a 100644 --- a/chainforge/flask_app.py +++ b/chainforge/flask_app.py @@ -268,13 +268,13 @@ def index(): def executepy(): """ Executes a Python function sent from JavaScript, - over all the `StandardizedLLMResponse` objects passed in from the front-end. + over all the `LLMResponse` objects passed in from the front-end. POST'd data should be in the form: { 'id': # a unique ID to refer to this information. Used when cache'ing responses. 'code': str, # the body of the lambda function to evaluate, in form: lambda responses: - 'responses': List[StandardizedLLMResponse] # the responses to run on. + 'responses': List[LLMResponse] # the responses to run on. 'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch. # If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt. 'process_type': 'evaluator' | 'processor' # the type of processing to perform. Evaluators only 'score'/annotate responses. Processors change responses (e.g. text). diff --git a/chainforge/react-server/src/AiPopover.tsx b/chainforge/react-server/src/AiPopover.tsx index d0eeabd..688935c 100644 --- a/chainforge/react-server/src/AiPopover.tsx +++ b/chainforge/react-server/src/AiPopover.tsx @@ -1,4 +1,4 @@ -import React, { useCallback, useMemo, useRef, useState } from "react"; +import React, { useCallback, useContext, useMemo, useState } from "react"; import { Stack, NumberInput, @@ -18,7 +18,7 @@ import { getAIFeaturesModels, } from "./backend/ai"; import { IconSparkles, IconAlertCircle } from "@tabler/icons-react"; -import AlertModal, { AlertModalRef } from "./AlertModal"; +import { AlertModalContext } from "./AlertModal"; import useStore from "./store"; import { INFO_CODEBLOCK_JS, @@ -30,7 +30,7 @@ import { queryLLM } from "./backend/backend"; import { splitText } from "./SplitNode"; import { escapeBraces } from "./backend/template"; import { cleanMetavarsFilterFunc } from "./backend/utils"; -import { Dict, TemplateVarInfo, VarsContext } from "./backend/typing"; +import { VarsContext } from "./backend/typing"; const zeroGap = { gap: "0rem" }; const popoverShadow = "rgb(38, 57, 77) 0px 10px 30px -14px"; @@ -255,7 +255,7 @@ export function AIGenReplaceItemsPopover({ const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider); // Alerts - const alertModal = useRef(null); + const showAlert = useContext(AlertModalContext); // Command Fill state const [commandFillNumber, setCommandFillNumber] = useState(3); @@ -298,7 +298,7 @@ export function AIGenReplaceItemsPopover({ if (e instanceof AIError) { setDidCommandFillError(true); } else { - if (alertModal.current) alertModal.current.trigger(e?.message); + if (showAlert) showAlert(e?.message); else console.error(e); } }) @@ -321,7 +321,7 @@ export function AIGenReplaceItemsPopover({ console.log(e); setDidGenerateAndReplaceError(true); } else { - if (alertModal.current) alertModal.current.trigger(e?.message); + if (showAlert) showAlert(e?.message); else console.error(e); } }) @@ -470,7 +470,6 @@ export function AIGenReplaceItemsPopover({ {replaceUI} - ); } @@ -508,7 +507,7 @@ export function AIGenCodeEvaluatorPopover({ const [awaitingResponse, setAwaitingResponse] = useState(false); // Alerts - const alertModal = useRef(null); + const showAlert = useContext(AlertModalContext); const [didEncounterError, setDidEncounterError] = useState(false); // Handle errors @@ -518,9 +517,9 @@ export function AIGenCodeEvaluatorPopover({ if (onLoadingChange) onLoadingChange(false); setDidEncounterError(true); if (typeof err !== "string") console.error(err); - alertModal.current?.trigger(typeof err === "string" ? err : err?.message); + if (showAlert) showAlert(typeof err === "string" ? err : err?.message); }, - [setAwaitingResponse, onLoadingChange, setDidEncounterError, alertModal], + [setAwaitingResponse, onLoadingChange, setDidEncounterError, showAlert], ); // Generate an evaluate function, given the user-specified prompt, in the proper programming language @@ -731,7 +730,6 @@ ${currentEvalCode} - ); } diff --git a/chainforge/react-server/src/AlertModal.tsx b/chainforge/react-server/src/AlertModal.tsx index fe1fa3b..a7a31a5 100644 --- a/chainforge/react-server/src/AlertModal.tsx +++ b/chainforge/react-server/src/AlertModal.tsx @@ -1,5 +1,12 @@ /** An alert popup for displaying errors */ -import React, { useState, forwardRef, useImperativeHandle } from "react"; +import React, { + useState, + forwardRef, + useImperativeHandle, + createContext, + useRef, + useMemo, +} from "react"; import { useDisclosure } from "@mantine/hooks"; import { Modal, ModalBaseStylesNames, Styles } from "@mantine/core"; @@ -9,35 +16,69 @@ const ALERT_MODAL_STYLE = { } as Styles; export interface AlertModalRef { - trigger: (msg?: string) => void; + trigger: (msg?: string | Error) => void; } -const AlertModal = forwardRef(function AlertModal(props, ref) { - // Mantine modal popover for alerts - const [opened, { open, close }] = useDisclosure(false); - const [alertMsg, setAlertMsg] = useState(""); +/** + * The Alert Modal displays error messages to the user in a pop-up dialog. + */ +export const AlertModal = forwardRef( + function AlertModal(props, ref) { + // Mantine modal popover for alerts + const [opened, { open, close }] = useDisclosure(false); + const [alertMsg, setAlertMsg] = useState(""); - // This gives the parent access to triggering the modal alert - const trigger = (msg?: string) => { - if (!msg) msg = "Unknown error."; - console.error(msg); - setAlertMsg(msg); - open(); - }; - useImperativeHandle(ref, () => ({ - trigger, - })); + // This gives the parent access to triggering the modal alert + const trigger = (msg?: string | Error) => { + if (!msg) msg = "Unknown error."; + else if (typeof msg !== "string") msg = msg.message; + console.error(msg); + setAlertMsg(msg); + open(); + }; + useImperativeHandle(ref, () => ({ + trigger, + })); + + return ( + +

{alertMsg}

+
+ ); + }, +); +export default AlertModal; + +export const AlertModalContext = createContext< + ((msg?: string | Error) => void) | undefined +>(undefined); + +/** + * Wraps children components to provide the same AlertModal to everywhere in the component tree. + * Saves space and reduces duplicate declarations. + */ +export const AlertModalProvider = ({ + children, +}: { + children: React.ReactNode[]; +}) => { + // Create one AlertModal for the entire application + const alertModal = useRef(null); + + // We have to wrap trigger() in a memoized function, as passing it down directly will trigger re-renders every frame. + const showAlert = useMemo(() => { + return (msg?: string | Error) => alertModal?.current?.trigger(msg); + }, [alertModal]); return ( - -

{alertMsg}

-
+ + + {children} + ); -}); - -export default AlertModal; +}; diff --git a/chainforge/react-server/src/App.tsx b/chainforge/react-server/src/App.tsx index 38343c8..22856d6 100644 --- a/chainforge/react-server/src/App.tsx +++ b/chainforge/react-server/src/App.tsx @@ -1,4 +1,10 @@ -import React, { useState, useCallback, useRef, useEffect } from "react"; +import React, { + useState, + useCallback, + useRef, + useEffect, + useContext, +} from "react"; import ReactFlow, { Controls, Background, ReactFlowInstance } from "reactflow"; import { Button, @@ -31,13 +37,15 @@ import CodeEvaluatorNode from "./CodeEvaluatorNode"; import VisNode from "./VisNode"; import InspectNode from "./InspectorNode"; import ScriptNode from "./ScriptNode"; -import AlertModal, { AlertModalRef } from "./AlertModal"; +import { AlertModalProvider, AlertModalContext } from "./AlertModal"; import ItemsNode from "./ItemsNode"; import TabularDataNode from "./TabularDataNode"; import JoinNode from "./JoinNode"; import SplitNode from "./SplitNode"; import CommentNode from "./CommentNode"; -import GlobalSettingsModal, { GlobalSettingsModalRef } from "./GlobalSettingsModal"; +import GlobalSettingsModal, { + GlobalSettingsModalRef, +} from "./GlobalSettingsModal"; import ExampleFlowsModal, { ExampleFlowsModalRef } from "./ExampleFlowsModal"; import AreYouSureModal, { AreYouSureModalRef } from "./AreYouSureModal"; import LLMEvaluatorNode from "./LLMEvalNode"; @@ -60,7 +68,13 @@ import useStore, { StoreHandles } from "./store"; import StorageCache from "./backend/cache"; import { APP_IS_RUNNING_LOCALLY, browserTabIsActive } from "./backend/utils"; import { Dict, JSONCompatible, LLMSpec } from "./backend/typing"; -import { exportCache, fetchEnvironAPIKeys, fetchExampleFlow, fetchOpenAIEval, importCache } from "./backend/backend"; +import { + exportCache, + fetchEnvironAPIKeys, + fetchExampleFlow, + fetchOpenAIEval, + importCache, +} from "./backend/backend"; // Device / Browser detection import { @@ -217,7 +231,9 @@ const App = () => { // For saving / loading const [rfInstance, setRfInstance] = useState(null); - const [autosavingInterval, setAutosavingInterval] = useState(undefined); + const [autosavingInterval, setAutosavingInterval] = useState< + NodeJS.Timeout | undefined + >(undefined); // For 'share' button const clipboard = useClipboard({ timeout: 1500 }); @@ -232,9 +248,16 @@ const App = () => { // For an info pop-up that welcomes new users // const [welcomeModalOpened, { open: openWelcomeModal, close: closeWelcomeModal }] = useDisclosure(false); + // For displaying alerts + const showAlert = useContext(AlertModalContext); + // For confirmation popup const confirmationModal = useRef(null); - const [confirmationDialogProps, setConfirmationDialogProps] = useState<{title: string, message: string, onConfirm?: () => void}>({ + const [confirmationDialogProps, setConfirmationDialogProps] = useState<{ + title: string; + message: string; + onConfirm?: () => void; + }>({ title: "Confirm action", message: "Are you sure?", }); @@ -266,13 +289,22 @@ const App = () => { return { x: -(x / zoom) + centerX / zoom, y: -(y / zoom) + centerY / zoom }; }; - const addNode = (id: string, type?: string, data?: Dict, offsetX?: number, offsetY?: number) => { + const addNode = ( + id: string, + type?: string, + data?: Dict, + offsetX?: number, + offsetY?: number, + ) => { const { x, y } = getViewportCenter(); addNodeToStore({ id: `${id}-` + Date.now(), type: type ?? id, data: data ?? {}, - position: { x: x - 200 + (offsetX ? offsetX : 0), y: y - 100 + (offsetY ? offsetY : 0)}, + position: { + x: x - 200 + (offsetX || 0), + y: y - 100 + (offsetY || 0), + }, }); }; @@ -317,7 +349,7 @@ const App = () => { const msg = typeof err === "string" ? err : err.message; setIsLoading(false); setWaitingForShare(false); - if (alertModal.current) alertModal.current.trigger(msg); + if (showAlert) showAlert(msg); console.error(msg); }; @@ -529,39 +561,44 @@ const App = () => { input.accept = ".cforge, .json"; // Handle file selection - // @ts-expect-error The event is correctly typed here, but for some reason TS doesn't pick up on it. - input.addEventListener("change", function (event: React.ChangeEvent) { - // Start loading spinner - setIsLoading(false); + input.addEventListener( + "change", + // @ts-expect-error The event is correctly typed here, but for some reason TS doesn't pick up on it. + function (event: React.ChangeEvent) { + // Start loading spinner + setIsLoading(false); - const files = event.target.files; - if (!files || !Array.isArray(files) || files.length === 0) { - console.error("No files found to load."); - return; - } - - const file = files[0]; - const reader = new window.FileReader(); - - // Handle file load event - reader.addEventListener("load", function () { - try { - if (typeof reader.result !== "string") - throw new Error("File could not be read: Unknown format or empty."); - - // We try to parse the JSON response - const flow_and_cache = JSON.parse(reader.result); - - // Import it to React Flow and import cache data on the backend - importFlowFromJSON(flow_and_cache); - } catch (error) { - handleError(error as Error); + const files = event.target.files; + if (!files || !Array.isArray(files) || files.length === 0) { + console.error("No files found to load."); + return; } - }); - // Read the selected file as text - reader.readAsText(file); - }); + const file = files[0]; + const reader = new window.FileReader(); + + // Handle file load event + reader.addEventListener("load", function () { + try { + if (typeof reader.result !== "string") + throw new Error( + "File could not be read: Unknown format or empty.", + ); + + // We try to parse the JSON response + const flow_and_cache = JSON.parse(reader.result); + + // Import it to React Flow and import cache data on the backend + importFlowFromJSON(flow_and_cache); + } catch (error) { + handleError(error as Error); + } + }); + + // Read the selected file as text + reader.readAsText(file); + }, + ); // Trigger the file selector input.click(); @@ -571,9 +608,7 @@ const App = () => { const importFlowFromOpenAIEval = (evalname: string) => { setIsLoading(true); - fetchOpenAIEval(evalname) - .then(importFlowFromJSON) - .catch(handleError); + fetchOpenAIEval(evalname).then(importFlowFromJSON).catch(handleError); }; // Load flow from examples modal @@ -863,9 +898,8 @@ const App = () => { ); } else return ( -
- - + + { Send us feedback
- + ); }; diff --git a/chainforge/react-server/src/CodeEvaluatorNode.tsx b/chainforge/react-server/src/CodeEvaluatorNode.tsx index 34b8d0b..05d39b2 100644 --- a/chainforge/react-server/src/CodeEvaluatorNode.tsx +++ b/chainforge/react-server/src/CodeEvaluatorNode.tsx @@ -6,6 +6,7 @@ import React, { useMemo, forwardRef, useImperativeHandle, + useContext, } from "react"; import { Handle, NodeProps, Position } from "reactflow"; import { @@ -58,7 +59,7 @@ import { } from "./backend/typing"; import { Status } from "./StatusIndicatorComponent"; import { executejs, executepy, grabResponses } from "./backend/backend"; -import { AlertModalRef } from "./AlertModal"; +import { AlertModalContext } from "./AlertModal"; // Whether we are running on localhost or not, and hence whether // we have access to the Flask backend for, e.g., Python code evaluation. @@ -387,7 +388,7 @@ const CodeEvaluatorNode: React.FC = ({ const [lastContext, setLastContext] = useState({}); // For displaying error messages to user - const alertModal = useRef(null); + const showAlert = useContext(AlertModalContext); // For an info pop-up that explains the type of ResponseInfo const [infoModalOpened, { open: openInfoModal, close: closeInfoModal }] = @@ -413,9 +414,9 @@ const CodeEvaluatorNode: React.FC = ({ setStatus(Status.ERROR); setLastRunSuccess(false); if (typeof err !== "string") console.error(err); - alertModal.current?.trigger(typeof err === "string" ? err : err?.message); + if (showAlert) showAlert(typeof err === "string" ? err : err?.message); }, - [alertModal], + [showAlert], ); const pullInputs = useCallback(() => { @@ -788,7 +789,6 @@ The Python interpeter in the browser is Pyodide. You may not be able to run some onEdit={hideStatusIndicator} icon={} status={status} - alertModal={alertModal} handleRunClick={handleRunClick} runButtonTooltip={run_tooltip} customButtons={customButtons} diff --git a/chainforge/react-server/src/GlobalSettingsModal.tsx b/chainforge/react-server/src/GlobalSettingsModal.tsx index 88c5de3..1b6deb9 100644 --- a/chainforge/react-server/src/GlobalSettingsModal.tsx +++ b/chainforge/react-server/src/GlobalSettingsModal.tsx @@ -4,6 +4,7 @@ import React, { useImperativeHandle, useCallback, useEffect, + useContext, } from "react"; import { TextInput, @@ -34,12 +35,15 @@ import { import { Dropzone, FileWithPath } from "@mantine/dropzone"; import useStore from "./store"; import { APP_IS_RUNNING_LOCALLY } from "./backend/utils"; -import fetch_from_backend from "./fetch_from_backend"; import { setCustomProviders } from "./ModelSettingSchemas"; import { getAIFeaturesModelProviders } from "./backend/ai"; import { CustomLLMProviderSpec, Dict } from "./backend/typing"; -import { initCustomProvider, loadCachedCustomProviders, removeCustomProvider } from "./backend/backend"; -import { AlertModalRef } from "./AlertModal"; +import { + initCustomProvider, + loadCachedCustomProviders, + removeCustomProvider, +} from "./backend/backend"; +import { AlertModalContext } from "./AlertModal"; const _LINK_STYLE = { color: "#1E90FF", textDecoration: "none" }; @@ -47,7 +51,10 @@ const _LINK_STYLE = { color: "#1E90FF", textDecoration: "none" }; let LOADED_CUSTOM_PROVIDERS = false; // Read a file as text and pass the text to a cb (callback) function -const read_file = (file: FileWithPath, cb: (contents: string | ArrayBuffer | null) => void) => { +const read_file = ( + file: FileWithPath, + cb: (contents: string | ArrayBuffer | null) => void, +) => { const reader = new window.FileReader(); reader.onload = function (event) { const fileContent = event.target?.result; @@ -67,7 +74,9 @@ interface CustomProviderScriptDropzoneProps { /** A Dropzone to load a Python `.py` script that registers a `CustomModelProvider` in the Flask backend. * If successful, the list of custom model providers in the ChainForge UI dropdown is updated. * */ -const CustomProviderScriptDropzone: React.FC = ({ onError, onSetProviders }) => { +const CustomProviderScriptDropzone: React.FC< + CustomProviderScriptDropzoneProps +> = ({ onError, onSetProviders }) => { const theme = useMantineTheme(); const [isLoading, setIsLoading] = useState(false); @@ -106,9 +115,7 @@ const CustomProviderScriptDropzone: React.FC onReject={(files) => console.log("rejected files", files)} maxSize={3 * 1024 ** 2} > - +
void; } -export interface GlobalSettingsModalProps { - alertModal?: React.RefObject; -} - -const GlobalSettingsModal = forwardRef( +const GlobalSettingsModal = forwardRef( function GlobalSettingsModal(props, ref) { const [opened, { open, close }] = useDisclosure(false); const setAPIKeys = useStore((state) => state.setAPIKeys); const getFlag = useStore((state) => state.getFlag); const setFlag = useStore((state) => state.setFlag); const AvailableLLMs = useStore((state) => state.AvailableLLMs); + const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider); const setAvailableLLMs = useStore((state) => state.setAvailableLLMs); const nodes = useStore((state) => state.nodes); const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); - const alertModal = props?.alertModal; - const setAIFeaturesProvider = useStore( - (state) => state.setAIFeaturesProvider, - ); - const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider); + + const showAlert = useContext(AlertModalContext); const [aiSupportActive, setAISupportActive] = useState( getFlag("aiSupport") as boolean, @@ -204,12 +205,14 @@ const GlobalSettingsModal = forwardRef { const msg = typeof err === "string" ? err : err.message; - if (alertModal && alertModal.current) alertModal.current.trigger(msg); + if (showAlert) showAlert(msg); }, - [alertModal], + [showAlert], ); - const [customProviders, setLocalCustomProviders] = useState([]); + const [customProviders, setLocalCustomProviders] = useState< + CustomLLMProviderSpec[] + >([]); const refreshLLMProviderLists = useCallback(() => { // We unfortunately have to force all prompt/chat nodes to refresh their LLM lists, bc // apparently the update to the AvailableLLMs list is not immediately propagated to them. @@ -223,7 +226,8 @@ const GlobalSettingsModal = forwardRef { - removeCustomProvider(name).then(() => { + removeCustomProvider(name) + .then(() => { // Successfully deleted the custom provider from backend; // now updated the front-end UI to reflect this: setAvailableLLMs(AvailableLLMs.filter((p) => p.name !== name)); @@ -243,13 +247,13 @@ const GlobalSettingsModal = forwardRef { + loadCachedCustomProviders() + .then((providers) => { // Success; pass custom providers list to store: setCustomProviders(providers); setLocalCustomProviders(providers); - }, - ).catch(console.error); + }) + .catch(console.error); } }, []); diff --git a/chainforge/react-server/src/LLMEvalNode.js b/chainforge/react-server/src/LLMEvalNode.js index c49db02..9442f15 100644 --- a/chainforge/react-server/src/LLMEvalNode.js +++ b/chainforge/react-server/src/LLMEvalNode.js @@ -5,6 +5,7 @@ import React, { useEffect, forwardRef, useImperativeHandle, + useContext, } from "react"; import { Handle } from "reactflow"; import { Group, NativeSelect, Progress, Text, Textarea } from "@mantine/core"; @@ -20,6 +21,7 @@ import LLMResponseInspectorModal from "./LLMResponseInspectorModal"; import InspectFooter from "./InspectFooter"; import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer"; import { stripLLMDetailsFromResponses } from "./backend/utils"; +import { AlertModalContext } from "./AlertModal"; // The default prompt shown in gray highlights to give people a good example of an evaluation prompt. const PLACEHOLDER_PROMPT = @@ -211,7 +213,7 @@ const LLMEvaluatorNode = ({ data, id }) => { const llmEvaluatorRef = useRef(null); const [status, setStatus] = useState("none"); - const alertModal = useRef(null); + const showAlert = useContext(AlertModalContext); const inspectModal = useRef(null); // eslint-disable-next-line @@ -243,7 +245,7 @@ const LLMEvaluatorNode = ({ data, id }) => { setStatus("error"); setProgress(undefined); if (typeof err !== "string") console.error(err); - alertModal.current.trigger(typeof err === "string" ? err : err?.message); + if (showAlert) showAlert(typeof err === "string" ? err : err?.message); }; // Fetch info about the number of queries we'll need to make @@ -296,7 +298,7 @@ const LLMEvaluatorNode = ({ data, id }) => { pingOutputNodes, setStatus, showDrawer, - alertModal, + showAlert, ]); const showResponseInspector = useCallback(() => { @@ -334,7 +336,6 @@ const LLMEvaluatorNode = ({ data, id }) => { nodeId={id} icon={} status={status} - alertModal={alertModal} handleRunClick={handleRunClick} runButtonTooltip="Run scorer over inputs" /> diff --git a/chainforge/react-server/src/NodeLabelComponent.tsx b/chainforge/react-server/src/NodeLabelComponent.tsx index 01f0e9f..c699a49 100644 --- a/chainforge/react-server/src/NodeLabelComponent.tsx +++ b/chainforge/react-server/src/NodeLabelComponent.tsx @@ -4,13 +4,13 @@ import React, { useState, useEffect, useCallback, + useContext, } from "react"; import { Tooltip } from "@mantine/core"; import { EditText, onSaveProps } from "react-edit-text"; import "react-edit-text/dist/index.css"; import useStore from "./store"; import StatusIndicator, { Status } from "./StatusIndicatorComponent"; -import AlertModal, { AlertModalRef } from "./AlertModal"; import AreYouSureModal, { AreYouSureModalRef } from "./AreYouSureModal"; export interface NodeLabelProps { @@ -22,7 +22,6 @@ export interface NodeLabelProps { editable?: boolean; status?: Status; isRunning?: boolean; - alertModal?: React.Ref; customButtons?: React.ReactElement[]; handleRunClick?: () => void; handleStopClick?: (nodeId: string) => void; @@ -45,7 +44,6 @@ export const NodeLabel: React.FC = ({ editable, status, isRunning, - alertModal, customButtons, handleRunClick, handleStopClick, @@ -167,7 +165,6 @@ export const NodeLabel: React.FC = ({ readonly={editable !== undefined ? !editable : false} /> {statusIndicator} -
{customButtons ?? <>} {isRunning ? stopButton : runButton} diff --git a/chainforge/react-server/src/PromptNode.tsx b/chainforge/react-server/src/PromptNode.tsx index a4a6469..724f781 100644 --- a/chainforge/react-server/src/PromptNode.tsx +++ b/chainforge/react-server/src/PromptNode.tsx @@ -4,6 +4,7 @@ import React, { useRef, useCallback, useMemo, + useContext, } from "react"; import { Handle, Position } from "reactflow"; import { v4 as uuid } from "uuid"; @@ -51,7 +52,7 @@ import { LLMResponse, TemplateVarInfo, } from "./backend/typing"; -import { AlertModalRef } from "./AlertModal"; +import { AlertModalContext } from "./AlertModal"; import { Status } from "./StatusIndicatorComponent"; const getUniqueLLMMetavarKey = (responses: LLMResponse[]) => { @@ -167,6 +168,8 @@ const PromptNode = ({ data, id, type: node_type }) => { [node_type], ); + console.log("re-render"); + // Get state from the Zustand store: const edges = useStore((state) => state.edges); const pullInputData = useStore((state) => state.pullInputData); @@ -199,7 +202,7 @@ const PromptNode = ({ data, id, type: node_type }) => { const [llmItemsCurrState, setLLMItemsCurrState] = useState([]); // For displaying error messages to user - const alertModal = useRef(null); + const showAlert = useContext(AlertModalContext); // For a way to inspect responses without having to attach a dedicated node const inspectModal = useRef(null); @@ -241,9 +244,9 @@ const PromptNode = ({ data, id, type: node_type }) => { (msg: string) => { setProgress(undefined); llmListContainer?.current?.resetLLMItemsProgress(); - alertModal?.current?.trigger(msg); + if (showAlert) showAlert(msg); }, - [llmListContainer, alertModal], + [llmListContainer, showAlert], ); const showResponseInspector = useCallback(() => { @@ -313,7 +316,6 @@ const PromptNode = ({ data, id, type: node_type }) => { const pulled_data = pullInputData(templateVars, id); updateShowContToggle(pulled_data); } catch (err) { - // alertModal.current?.trigger(err.message); console.error(err); } }, [templateVars, id, pullInputData, updateShowContToggle]); @@ -705,7 +707,7 @@ Soft failing by replacing undefined with empty strings.`, // Try to pull inputs pulled_data = pullInputData(templateVars, id); } catch (err) { - alertModal.current?.trigger((err as Error)?.message ?? err); + if (showAlert) showAlert((err as Error)?.message ?? err); console.error(err); return; // early exit } @@ -923,10 +925,11 @@ Soft failing by replacing undefined with empty strings.`, "\n"; }); // We trigger the alert directly (don't use triggerAlert) here because we want to keep the progress bar: - alertModal?.current?.trigger( - "Errors collecting responses. Re-run prompt node to retry.\n\n" + - combined_err_msg, - ); + if (showAlert) + showAlert( + "Errors collecting responses. Re-run prompt node to retry.\n\n" + + combined_err_msg, + ); return; } @@ -1039,7 +1042,6 @@ Soft failing by replacing undefined with empty strings.`, icon={node_icon} status={status} isRunning={status === "loading"} - alertModal={alertModal} handleRunClick={handleRunClick} handleStopClick={handleStopClick} handleRunHover={handleRunHover} diff --git a/chainforge/react-server/src/SimpleEvalNode.js b/chainforge/react-server/src/SimpleEvalNode.js index 43df1bb..af2a52e 100644 --- a/chainforge/react-server/src/SimpleEvalNode.js +++ b/chainforge/react-server/src/SimpleEvalNode.js @@ -1,4 +1,10 @@ -import React, { useState, useCallback, useEffect, useRef } from "react"; +import React, { + useState, + useCallback, + useEffect, + useRef, + useContext, +} from "react"; import { Handle } from "reactflow"; import { NativeSelect, @@ -29,6 +35,7 @@ import { toStandardResponseFormat, } from "./backend/utils"; import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer"; +import { AlertModalContext } from "./AlertModal"; const createJSEvalCodeFor = (responseFormat, operation, value, valueType) => { let responseObj = "r.text"; @@ -80,7 +87,7 @@ const SimpleEvalNode = ({ data, id }) => { const [pastInputs, setPastInputs] = useState([]); const [status, setStatus] = useState("none"); - const alertModal = useRef(null); + const showAlert = useContext(AlertModalContext); const inspectModal = useRef(null); // eslint-disable-next-line @@ -156,7 +163,7 @@ const SimpleEvalNode = ({ data, id }) => { const rejected = (err_msg) => { setStatus("error"); - alertModal.current.trigger(err_msg); + if (showAlert) showAlert(err_msg); }; // Generate JS code for the user's spec @@ -200,7 +207,7 @@ const SimpleEvalNode = ({ data, id }) => { handlePullInputs, pingOutputNodes, setStatus, - alertModal, + showAlert, status, varValue, varValueType, @@ -260,7 +267,6 @@ const SimpleEvalNode = ({ data, id }) => { nodeId={id} icon={} status={status} - alertModal={alertModal} handleRunClick={handleRunClick} runButtonTooltip="Run evaluator over inputs" /> diff --git a/chainforge/react-server/src/TabularDataNode.tsx b/chainforge/react-server/src/TabularDataNode.tsx index c1ea325..04a015d 100644 --- a/chainforge/react-server/src/TabularDataNode.tsx +++ b/chainforge/react-server/src/TabularDataNode.tsx @@ -1,4 +1,10 @@ -import React, { useState, useRef, useEffect, useCallback } from "react"; +import React, { + useState, + useRef, + useEffect, + useCallback, + useContext, +} from "react"; import { Menu, NumberInput, Switch, Text, Tooltip } from "@mantine/core"; import EditableTable from "./EditableTable"; import * as XLSX from "xlsx"; @@ -12,7 +18,7 @@ import { import TemplateHooks from "./TemplateHooksComponent"; import BaseNode from "./BaseNode"; import NodeLabel from "./NodeLabelComponent"; -import AlertModal, { AlertModalRef } from "./AlertModal"; +import { AlertModalContext } from "./AlertModal"; import RenameValueModal, { RenameValueModalRef } from "./RenameValueModal"; import useStore from "./store"; import { sampleRandomElements } from "./backend/utils"; @@ -91,7 +97,7 @@ const TabularDataNode: React.FC = ({ data, id }) => { const [hooksY, setHooksY] = useState(120); // For displaying error messages to user - const alertModal = useRef(null); + const showAlert = useContext(AlertModalContext); // For renaming a column const renameColumnModal = useRef(null); @@ -467,7 +473,7 @@ const TabularDataNode: React.FC = ({ data, id }) => { }, [tableData, tableColumns, shouldSample, sampleNum]); const handleError = (err: Error) => { - if (alertModal.current) alertModal.current?.trigger(err.message); + if (showAlert) showAlert(err.message); console.error(err.message); }; @@ -520,8 +526,6 @@ const TabularDataNode: React.FC = ({ data, id }) => { ]} /> - - { "Access-Control-Allow-Origin": "*", }, body: JSON.stringify({ name: evalname }), - }).then(function (res) { - return res.json(); - }).then(function (json) { - if (json?.error !== undefined || !json?.data) - throw new Error(json.error as string ?? "Request to fetch example flow was sent to backend server, but there was no response."); - return json.data as Dict; - }); + }) + .then(function (res) { + return res.json(); + }) + .then(function (json) { + if (json?.error !== undefined || !json?.data) + throw new Error( + (json.error as string) ?? + "Request to fetch example flow was sent to backend server, but there was no response.", + ); + return json.data as Dict; + }); } // App is not running locally, but hosted on a site. @@ -1476,13 +1481,18 @@ export async function fetchOpenAIEval(evalname: string): Promise { "Access-Control-Allow-Origin": "*", }, body: JSON.stringify({ name: evalname }), - }).then(function (res) { - return res.json(); - }).then(function (json) { - if (json?.error !== undefined || !json?.data) - throw new Error(json.error as string ?? "Request to fetch OpenAI eval was sent to backend server, but there was no response."); - return json.data as Dict; - }); + }) + .then(function (res) { + return res.json(); + }) + .then(function (json) { + if (json?.error !== undefined || !json?.data) + throw new Error( + (json.error as string) ?? + "Request to fetch OpenAI eval was sent to backend server, but there was no response.", + ); + return json.data as Dict; + }); } // App is not running locally, but hosted on a site. @@ -1500,7 +1510,9 @@ export async function fetchOpenAIEval(evalname: string): Promise { * @returns a Promise with the JSON of the response. Will include 'error' key if error'd; if success, * a 'providers' key with a list of all loaded custom provider callbacks, as dicts. */ -export async function initCustomProvider(code: string): Promise { +export async function initCustomProvider( + code: string, +): Promise { // Attempt to fetch the example flow from the local filesystem // by querying the Flask server: return fetch(`${FLASK_BASE_URL}app/initCustomProvider`, { @@ -1510,13 +1522,15 @@ export async function initCustomProvider(code: string): Promise { "Access-Control-Allow-Origin": "*", }, body: JSON.stringify({ name }), - }).then(function (res) { - return res.json(); - }).then(function (json) { - if (!json || json.error || !json.success) - throw new Error(json.error ?? "Unknown error"); - return true; - }); + }) + .then(function (res) { + return res.json(); + }) + .then(function (json) { + if (!json || json.error || !json.success) + throw new Error(json.error ?? "Unknown error"); + return true; + }); } /** @@ -1551,7 +1567,9 @@ export async function removeCustomProvider(name: string): Promise { * @returns a Promise with the JSON of the response. Will include 'error' key if error'd; if success, * a 'providers' key with all loaded custom providers in an array. If there were none, returns empty array. */ -export async function loadCachedCustomProviders(): Promise { +export async function loadCachedCustomProviders(): Promise< + CustomLLMProviderSpec[] +> { return fetch(`${FLASK_BASE_URL}app/loadCachedCustomProviders`, { method: "POST", headers: { @@ -1559,11 +1577,16 @@ export async function loadCachedCustomProviders(): Promise((set, get) => ({ }, onConnect: (connection) => { // Get the target node information - const target = connection.target ? get().getNode(connection.target) : undefined; + const target = connection.target + ? get().getNode(connection.target) + : undefined; if (target === undefined) return; if (