From d98de2b7cac8e922d3a39c5d868618fdb244e232 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Sun, 10 Mar 2024 23:16:59 -0400 Subject: [PATCH] WIP. Refactor addNodes in App.tsx to be simpler. --- chainforge/react-server/package-lock.json | 12 ++ chainforge/react-server/package.json | 1 + chainforge/react-server/src/App.tsx | 201 ++++-------------- .../react-server/src/CodeEvaluatorNode.tsx | 2 +- .../react-server/src/ExampleFlowsModal.tsx | 6 +- ...ttingsModal.js => GlobalSettingsModal.tsx} | 111 +++++----- ...mButtonGroup.js => LLMItemButtonGroup.tsx} | 17 +- .../src/{LLMListItem.js => LLMListItem.tsx} | 30 ++- .../react-server/src/ModelSettingSchemas.tsx | 6 +- .../react-server/src/backend/backend.ts | 23 +- chainforge/react-server/src/backend/cache.ts | 8 +- chainforge/react-server/src/store.tsx | 6 +- 12 files changed, 183 insertions(+), 240 deletions(-) rename chainforge/react-server/src/{GlobalSettingsModal.js => GlobalSettingsModal.tsx} (88%) rename chainforge/react-server/src/{LLMItemButtonGroup.js => LLMItemButtonGroup.tsx} (81%) rename chainforge/react-server/src/{LLMListItem.js => LLMListItem.tsx} (82%) diff --git a/chainforge/react-server/package-lock.json b/chainforge/react-server/package-lock.json index 4360d16..e7a073b 100644 --- a/chainforge/react-server/package-lock.json +++ b/chainforge/react-server/package-lock.json @@ -107,6 +107,7 @@ "@types/papaparse": "^5.3.14", "@types/react-beautiful-dnd": "^13.1.8", "@types/react-edit-text": "^5.0.4", + "@types/styled-components": "^5.1.34", "eslint": "^8.56.0", "eslint-config-prettier": "^9.1.0", "eslint-config-semistandard": "^17.0.0", @@ -7111,6 +7112,17 @@ "resolved": "https://registry.npmjs.org/@types/stack-utils/-/stack-utils-2.0.1.tgz", "integrity": "sha512-Hl219/BT5fLAaz6NDkSuhzasy49dwQS/DSdu4MdggFB8zcXv7vflBI3xp7FEmkmdDkBUI2bPUNeMttp2knYdxw==" }, + "node_modules/@types/styled-components": { + "version": "5.1.34", + "resolved": "https://registry.npmjs.org/@types/styled-components/-/styled-components-5.1.34.tgz", + "integrity": "sha512-mmiVvwpYklFIv9E8qfxuPyIt/OuyIrn6gMOAMOFUO3WJfSrSE+sGUoa4PiZj77Ut7bKZpaa6o1fBKS/4TOEvnA==", + "dev": true, + "dependencies": { + "@types/hoist-non-react-statics": "*", + "@types/react": "*", + "csstype": "^3.0.2" + } + }, "node_modules/@types/testing-library__jest-dom": { "version": "5.14.9", "resolved": "https://registry.npmjs.org/@types/testing-library__jest-dom/-/testing-library__jest-dom-5.14.9.tgz", diff --git a/chainforge/react-server/package.json b/chainforge/react-server/package.json index dd95f9e..6e26b93 100644 --- a/chainforge/react-server/package.json +++ b/chainforge/react-server/package.json @@ -133,6 +133,7 @@ "@types/papaparse": "^5.3.14", "@types/react-beautiful-dnd": "^13.1.8", "@types/react-edit-text": "^5.0.4", + "@types/styled-components": "^5.1.34", "eslint": "^8.56.0", "eslint-config-prettier": "^9.1.0", "eslint-config-semistandard": "^17.0.0", diff --git a/chainforge/react-server/src/App.tsx b/chainforge/react-server/src/App.tsx index d282c56..e6d1d4f 100644 --- a/chainforge/react-server/src/App.tsx +++ b/chainforge/react-server/src/App.tsx @@ -37,8 +37,8 @@ import TabularDataNode from "./TabularDataNode"; import JoinNode from "./JoinNode"; import SplitNode from "./SplitNode"; import CommentNode from "./CommentNode"; -import GlobalSettingsModal from "./GlobalSettingsModal"; -import ExampleFlowsModal from "./ExampleFlowsModal"; +import GlobalSettingsModal, { GlobalSettingsModalRef } from "./GlobalSettingsModal"; +import ExampleFlowsModal, { ExampleFlowsModalRef } from "./ExampleFlowsModal"; import AreYouSureModal from "./AreYouSureModal"; import LLMEvaluatorNode from "./LLMEvalNode"; import SimpleEvalNode from "./SimpleEvalNode"; @@ -69,7 +69,7 @@ import { isEdgeChromium, isChromium, } from "react-device-detect"; -import { Dict, LLMSpec } from "./backend/typing"; +import { Dict, JSONCompatible, LLMSpec } from "./backend/typing"; const IS_ACCEPTED_BROWSER = (isChrome || isChromium || @@ -196,7 +196,7 @@ const MenuTooltip = ({ }; // const connectionLineStyle = { stroke: '#ddd' }; -const snapGrid = [16, 16]; +const snapGrid: [number, number] = [16, 16]; const App = () => { // Get nodes, edges, etc. state from the Zustand store: @@ -206,7 +206,7 @@ const App = () => { onNodesChange, onEdgesChange, onConnect, - addNode, + addNode: addNodeToStore, setNodes, setEdges, resetLLMColors, @@ -216,17 +216,17 @@ const App = () => { // For saving / loading const [rfInstance, setRfInstance] = useState(null); - const [autosavingInterval, setAutosavingInterval] = useState(null); + const [autosavingInterval, setAutosavingInterval] = useState(undefined); // For 'share' button const clipboard = useClipboard({ timeout: 1500 }); const [waitingForShare, setWaitingForShare] = useState(false); // For modal popup to set global settings like API keys - const settingsModal = useRef(null); + const settingsModal = useRef(null); // For modal popup of example flows - const examplesModal = useRef(null); + const examplesModal = useRef(null); // For an info pop-up that welcomes new users // const [welcomeModalOpened, { open: openWelcomeModal, close: closeWelcomeModal }] = useDisclosure(false); @@ -265,150 +265,44 @@ const App = () => { return { x: -(x / zoom) + centerX / zoom, y: -(y / zoom) + centerY / zoom }; }; - const addTextFieldsNode = () => { + const addNode = (id: string, type?: string, data?: Dict, offsetX?: number, offsetY?: number) => { const { x, y } = getViewportCenter(); - addNode({ - id: "textFieldsNode-" + Date.now(), - type: "textfields", - data: {}, - position: { x: x - 200, y: y - 100 }, - }); - }; - const addPromptNode = () => { - const { x, y } = getViewportCenter(); - addNode({ - id: "promptNode-" + Date.now(), - type: "prompt", - data: { prompt: "" }, - position: { x: x - 200, y: y - 100 }, - }); - }; - const addChatTurnNode = () => { - const { x, y } = getViewportCenter(); - addNode({ - id: "chatTurn-" + Date.now(), - type: "chat", - data: { prompt: "" }, - position: { x: x - 200, y: y - 100 }, - }); - }; - const addSimpleEvalNode = () => { - const { x, y } = getViewportCenter(); - addNode({ - id: "simpleEval-" + Date.now(), - type: "simpleval", - data: {}, - position: { x: x - 200, y: y - 100 }, + addNodeToStore({ + id: `${id}-` + Date.now(), + type: type ?? id, + data: data ?? {}, + position: { x: x - 200 + (offsetX ? offsetX : 0), y: y - 100 + (offsetY ? offsetY : 0)}, }); }; + + const addTextFieldsNode = () => addNode("textFieldsNode", "textfields"); + const addPromptNode = () => addNode("promptNode", "prompt", { prompt: "" }); + const addChatTurnNode = () => addNode("chatTurn", "chat", { prompt: "" }); + const addSimpleEvalNode = () => addNode("simpleEval", "simpleval"); const addEvalNode = (progLang: string) => { - const { x, y } = getViewportCenter(); let code = ""; if (progLang === "python") code = "def evaluate(response):\n return len(response.text)"; else if (progLang === "javascript") code = "function evaluate(response) {\n return response.text.length;\n}"; - addNode({ - id: "evalNode-" + Date.now(), - type: "evaluator", - data: { language: progLang, code }, - position: { x: x - 200, y: y - 100 }, - }); - }; - const addVisNode = () => { - const { x, y } = getViewportCenter(); - addNode({ - id: "visNode-" + Date.now(), - type: "vis", - data: {}, - position: { x: x - 200, y: y - 100 }, - }); - }; - const addInspectNode = () => { - const { x, y } = getViewportCenter(); - addNode({ - id: "inspectNode-" + Date.now(), - type: "inspect", - data: {}, - position: { x: x - 200, y: y - 100 }, - }); - }; - const addScriptNode = () => { - const { x, y } = getViewportCenter(); - addNode({ - id: "scriptNode-" + Date.now(), - type: "script", - data: {}, - position: { x: x - 200, y: y - 100 }, - }); - }; - const addItemsNode = () => { - const { x, y } = getViewportCenter(); - addNode({ - id: "csvNode-" + Date.now(), - type: "csv", - data: {}, - position: { x: x - 200, y: y - 100 }, - }); - }; - const addTabularDataNode = () => { - const { x, y } = getViewportCenter(); - addNode({ - id: "table-" + Date.now(), - type: "table", - data: {}, - position: { x: x - 200, y: y - 100 }, - }); - }; - const addCommentNode = () => { - const { x, y } = getViewportCenter(); - addNode({ - id: "comment-" + Date.now(), - type: "comment", - data: {}, - position: { x: x - 200, y: y - 100 }, - }); - }; - const addLLMEvalNode = () => { - const { x, y } = getViewportCenter(); - addNode({ - id: "llmeval-" + Date.now(), - type: "llmeval", - data: {}, - position: { x: x - 200, y: y - 100 }, - }); - }; - const addJoinNode = () => { - const { x, y } = getViewportCenter(); - addNode({ - id: "join-" + Date.now(), - type: "join", - data: {}, - position: { x: x - 200, y: y - 100 }, - }); - }; - const addSplitNode = () => { - const { x, y } = getViewportCenter(); - addNode({ - id: "split-" + Date.now(), - type: "split", - data: {}, - position: { x: x - 200, y: y - 100 }, - }); + addNode("evalNode", "evaluator", { language: progLang, code }); }; + const addVisNode = () => addNode("visNode", "vis", {}); + const addInspectNode = () => addNode("inspectNode", "inspect"); + const addScriptNode = () => addNode("scriptNode", "script"); + const addItemsNode = () => addNode("csvNode", "csv"); + const addTabularDataNode = () => addNode("table"); + const addCommentNode = () => addNode("comment"); + const addLLMEvalNode = () => addNode("llmeval"); + const addJoinNode = () => addNode("join"); + const addSplitNode = () => addNode("split"); const addProcessorNode = (progLang: string) => { - const { x, y } = getViewportCenter(); let code = ""; if (progLang === "python") code = "def process(response):\n return response.text;"; else if (progLang === "javascript") code = "function process(response) {\n return response.text;\n}"; - addNode({ - id: "process-" + Date.now(), - type: "processor", - data: { language: progLang, code }, - position: { x: x - 200, y: y - 100 }, - }); + addNode("process", "processor", { language: progLang, code }); }; const onClickExamples = () => { @@ -429,7 +323,7 @@ const App = () => { /** * SAVING / LOADING, IMPORT / EXPORT (from JSON) */ - const downloadJSON = (jsonData, filename) => { + const downloadJSON = (jsonData: JSONCompatible, filename: string) => { // Convert JSON object to JSON string const jsonString = JSON.stringify(jsonData, null, 2); @@ -547,7 +441,7 @@ const App = () => { const saved_flow = StorageCache.loadFromLocalStorage( "chainforge-flow", false, - ); + ) as Dict; if (saved_flow) { StorageCache.loadFromLocalStorage("chainforge-state"); importGlobalStateFromCache(); @@ -718,7 +612,7 @@ const App = () => { }; // Load flow from examples modal - const onSelectExampleFlow = (name: string, example_category: string) => { + const onSelectExampleFlow = (name: string, example_category?: string) => { // Trigger the 'loading' modal setIsLoading(true); @@ -872,9 +766,9 @@ const App = () => { ]); // Initialize auto-saving - const initAutosaving = (rf_inst) => { - if (autosavingInterval !== null) return; // autosaving interval already set - console.log("Init autosaving!"); + const initAutosaving = (rf_inst: ReactFlowInstance) => { + if (autosavingInterval !== undefined) return; // autosaving interval already set + console.log("Init autosaving"); // Autosave the flow to localStorage every minute: const interv = setInterval(() => { @@ -898,7 +792,7 @@ const App = () => { "Autosaving disabled. The time required to save to localStorage exceeds 1 second. This can happen when there's a lot of data in your flow. Make sure to export frequently to save your work.", ); clearInterval(interv); - setAutosavingInterval(null); + setAutosavingInterval(undefined); } }, 60000); // 60000 milliseconds = 1 minute setAutosavingInterval(interv); @@ -938,23 +832,16 @@ const App = () => { if (!response || response.startsWith("Error")) { // Error encountered during the query; alert the user // with the error message: - handleError(new Error(response || "Unknown error")); - return; + throw new Error(response || "Unknown error"); } // Attempt to parse the response as a compressed flow + import it: - try { - const cforge_json = JSON.parse( - LZString.decompressFromUTF16(response), - ); - importFlowFromJSON(cforge_json, rf_inst); - } catch (err) { - handleError(err); - } + const cforge_json = JSON.parse( + LZString.decompressFromUTF16(response), + ); + importFlowFromJSON(cforge_json, rf_inst); }) - .catch((err) => { - handleError(err); - }); + .catch(handleError); } catch (err) { // Soft fail setIsLoading(false); @@ -1064,7 +951,9 @@ const App = () => { onConnect={onConnect} nodes={nodes} edges={edges} + // @ts-expect-error Node types won't perfectly fit unless we explicitly extend from RF's types; ignoring this for now. nodeTypes={nodeTypes} + // @ts-expect-error Edge types won't perfectly fit unless we explicitly extend from RF's types; ignoring this for now. edgeTypes={edgeTypes} zoomOnPinch={false} zoomOnScroll={false} diff --git a/chainforge/react-server/src/CodeEvaluatorNode.tsx b/chainforge/react-server/src/CodeEvaluatorNode.tsx index ea831be..34b8d0b 100644 --- a/chainforge/react-server/src/CodeEvaluatorNode.tsx +++ b/chainforge/react-server/src/CodeEvaluatorNode.tsx @@ -7,7 +7,7 @@ import React, { forwardRef, useImperativeHandle, } from "react"; -import { Handle, Position } from "reactflow"; +import { Handle, NodeProps, Position } from "reactflow"; import { Code, Modal, diff --git a/chainforge/react-server/src/ExampleFlowsModal.tsx b/chainforge/react-server/src/ExampleFlowsModal.tsx index 3dca73b..bc40b96 100644 --- a/chainforge/react-server/src/ExampleFlowsModal.tsx +++ b/chainforge/react-server/src/ExampleFlowsModal.tsx @@ -357,16 +357,16 @@ const ExampleFlowCard: React.FC = ({ ); }; -interface ExampleFlowsModalHandles { +export interface ExampleFlowsModalRef { trigger: () => void; } -interface ExampleFlowsModalProps { +export interface ExampleFlowsModalProps { handleOnSelect: (filename: string, category?: string) => void; } const ExampleFlowsModal = forwardRef< - ExampleFlowsModalHandles, + ExampleFlowsModalRef, ExampleFlowsModalProps >(function ExampleFlowsModal({ handleOnSelect }, ref) { // Mantine modal popover for alerts diff --git a/chainforge/react-server/src/GlobalSettingsModal.js b/chainforge/react-server/src/GlobalSettingsModal.tsx similarity index 88% rename from chainforge/react-server/src/GlobalSettingsModal.js rename to chainforge/react-server/src/GlobalSettingsModal.tsx index 301e95a..88c5de3 100644 --- a/chainforge/react-server/src/GlobalSettingsModal.js +++ b/chainforge/react-server/src/GlobalSettingsModal.tsx @@ -31,12 +31,15 @@ import { IconX, IconSparkles, } from "@tabler/icons-react"; -import { Dropzone } from "@mantine/dropzone"; +import { Dropzone, FileWithPath } from "@mantine/dropzone"; import useStore from "./store"; import { APP_IS_RUNNING_LOCALLY } from "./backend/utils"; import fetch_from_backend from "./fetch_from_backend"; import { setCustomProviders } from "./ModelSettingSchemas"; import { getAIFeaturesModelProviders } from "./backend/ai"; +import { CustomLLMProviderSpec, Dict } from "./backend/typing"; +import { initCustomProvider, loadCachedCustomProviders, removeCustomProvider } from "./backend/backend"; +import { AlertModalRef } from "./AlertModal"; const _LINK_STYLE = { color: "#1E90FF", textDecoration: "none" }; @@ -44,11 +47,11 @@ const _LINK_STYLE = { color: "#1E90FF", textDecoration: "none" }; let LOADED_CUSTOM_PROVIDERS = false; // Read a file as text and pass the text to a cb (callback) function -const read_file = (file, cb) => { +const read_file = (file: FileWithPath, cb: (contents: string | ArrayBuffer | null) => void) => { const reader = new window.FileReader(); reader.onload = function (event) { - const fileContent = event.target.result; - cb(fileContent); + const fileContent = event.target?.result; + cb(fileContent ?? null); }; reader.onerror = function (event) { console.error("Error reading file:", event); @@ -56,10 +59,15 @@ const read_file = (file, cb) => { reader.readAsText(file); }; +interface CustomProviderScriptDropzoneProps { + onError: (err: string | Error) => void; + onSetProviders: (providers: CustomLLMProviderSpec[]) => void; +} + /** A Dropzone to load a Python `.py` script that registers a `CustomModelProvider` in the Flask backend. * If successful, the list of custom model providers in the ChainForge UI dropdown is updated. * */ -const CustomProviderScriptDropzone = ({ onError, onSetProviders }) => { +const CustomProviderScriptDropzone: React.FC = ({ onError, onSetProviders }) => { const theme = useMantineTheme(); const [isLoading, setIsLoading] = useState(false); @@ -69,23 +77,20 @@ const CustomProviderScriptDropzone = ({ onError, onSetProviders }) => { onDrop={(files) => { if (files.length === 1) { setIsLoading(true); - read_file(files[0], (content) => { + read_file(files[0], (content: string | ArrayBuffer | null) => { + if (typeof content !== "string") { + console.error("File unreadable: Contents are not text."); + return; + } // Read the file into text and then send it to backend - fetch_from_backend("initCustomProvider", { - code: content, - }) - .then((response) => { + initCustomProvider(content) + .then((providers) => { setIsLoading(false); - - if (response.error || !response.providers) { - onError(response.error); - return; - } // Successfully loaded custom providers in backend, // now load them into the ChainForge UI: - console.log(response.providers); - setCustomProviders(response.providers); - onSetProviders(response.providers); + console.log(providers); + setCustomProviders(providers); + onSetProviders(providers); }) .catch((err) => { setIsLoading(false); @@ -102,8 +107,6 @@ const CustomProviderScriptDropzone = ({ onError, onSetProviders }) => { maxSize={3 * 1024 ** 2} >
@@ -144,7 +147,15 @@ const CustomProviderScriptDropzone = ({ onError, onSetProviders }) => { ); }; -const GlobalSettingsModal = forwardRef( +export interface GlobalSettingsModalRef { + trigger: () => void; +} + +export interface GlobalSettingsModalProps { + alertModal?: React.RefObject; +} + +const GlobalSettingsModal = forwardRef( function GlobalSettingsModal(props, ref) { const [opened, { open, close }] = useDisclosure(false); const setAPIKeys = useStore((state) => state.setAPIKeys); @@ -160,11 +171,15 @@ const GlobalSettingsModal = forwardRef( ); const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider); - const [aiSupportActive, setAISupportActive] = useState( - getFlag("aiSupport"), + const [aiSupportActive, setAISupportActive] = useState( + getFlag("aiSupport") as boolean, ); + const [aiAutocompleteActive, setAIAutocompleteActive] = useState( + getFlag("aiAutocomplete") as boolean, + ); + const handleAISupportChecked = useCallback( - (e) => { + (e: React.ChangeEvent) => { const checked = e.currentTarget.checked; setAISupportActive(checked); setFlag("aiSupport", checked); @@ -177,11 +192,8 @@ const GlobalSettingsModal = forwardRef( [setFlag, setAISupportActive], ); - const [aiAutocompleteActive, setAIAutocompleteActive] = useState( - getFlag("aiAutocomplete"), - ); const handleAIAutocompleteChecked = useCallback( - (e) => { + (e: React.ChangeEvent) => { const checked = e.currentTarget.checked; setAIAutocompleteActive(checked); setFlag("aiAutocomplete", checked); @@ -190,13 +202,14 @@ const GlobalSettingsModal = forwardRef( ); const handleError = useCallback( - (msg) => { + (err: string | Error) => { + const msg = typeof err === "string" ? err : err.message; if (alertModal && alertModal.current) alertModal.current.trigger(msg); }, [alertModal], ); - const [customProviders, setLocalCustomProviders] = useState([]); + const [customProviders, setLocalCustomProviders] = useState([]); const refreshLLMProviderLists = useCallback(() => { // We unfortunately have to force all prompt/chat nodes to refresh their LLM lists, bc // apparently the update to the AvailableLLMs list is not immediately propagated to them. @@ -208,16 +221,9 @@ const GlobalSettingsModal = forwardRef( ); }, [nodes, setDataPropsForNode]); - const removeCustomProvider = useCallback( - (name) => { - fetch_from_backend("removeCustomProvider", { - name, - }) - .then((response) => { - if (response.error || !response.success) { - handleError(response.error); - return; - } + const handleRemoveCustomProvider = useCallback( + (name: string) => { + removeCustomProvider(name).then(() => { // Successfully deleted the custom provider from backend; // now updated the front-end UI to reflect this: setAvailableLLMs(AvailableLLMs.filter((p) => p.name !== name)); @@ -226,7 +232,7 @@ const GlobalSettingsModal = forwardRef( ); refreshLLMProviderLists(); }) - .catch((err) => handleError(err.message)); + .catch(handleError); }, [customProviders, handleError, AvailableLLMs, refreshLLMProviderLists], ); @@ -237,20 +243,13 @@ const GlobalSettingsModal = forwardRef( LOADED_CUSTOM_PROVIDERS = true; // Is running locally; try to load any custom providers. // Soft fails if it encounters error: - fetch_from_backend("loadCachedCustomProviders", {}, console.error).then( - (json) => { - if (json?.error || json?.providers === undefined) { - console.error( - json?.error || - "Could not load custom provider scripts: Error contacting backend.", - ); - return; - } + loadCachedCustomProviders().then( + (providers) => { // Success; pass custom providers list to store: - setCustomProviders(json.providers); - setLocalCustomProviders(json.providers); + setCustomProviders(providers); + setLocalCustomProviders(providers); }, - ); + ).catch(console.error); } }, []); @@ -276,7 +275,7 @@ const GlobalSettingsModal = forwardRef( }); // When the API settings form is submitted - const onSubmit = (values) => { + const onSubmit = (values: Dict) => { setAPIKeys(values); close(); }; @@ -524,7 +523,7 @@ const GlobalSettingsModal = forwardRef( )}