From d8e734d778e790c4dfe9564739f6876241255cf6 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Fri, 13 Oct 2023 13:13:51 -0400 Subject: [PATCH] wip join node UI --- chainforge/react-server/src/App.js | 25 +++- chainforge/react-server/src/JoinNode.js | 136 ++++++++++++++++++ .../react-server/src/text-fields-node.css | 4 + 3 files changed, 158 insertions(+), 7 deletions(-) create mode 100644 chainforge/react-server/src/JoinNode.js diff --git a/chainforge/react-server/src/App.js b/chainforge/react-server/src/App.js index b799109..f791ba6 100644 --- a/chainforge/react-server/src/App.js +++ b/chainforge/react-server/src/App.js @@ -8,7 +8,7 @@ import ReactFlow, { } from 'reactflow'; import { Button, Menu, LoadingOverlay, Text, Box, List, Loader, Tooltip } from '@mantine/core'; import { useClipboard } from '@mantine/hooks'; -import { IconSettings, IconTextPlus, IconTerminal, IconCsv, IconSettingsAutomation, IconFileSymlink, IconRobot, IconRuler2 } from '@tabler/icons-react'; +import { IconSettings, IconTextPlus, IconTerminal, IconCsv, IconSettingsAutomation, IconFileSymlink, IconRobot, IconRuler2, IconArrowMerge } from '@tabler/icons-react'; import RemoveEdge from './RemoveEdge'; import TextFieldsNode from './TextFieldsNode'; // Import a custom node import PromptNode from './PromptNode'; @@ -19,6 +19,7 @@ import ScriptNode from './ScriptNode'; import AlertModal from './AlertModal'; import CsvNode from './CsvNode'; import TabularDataNode from './TabularDataNode'; +import JoinNode from './JoinNode'; import CommentNode from './CommentNode'; import GlobalSettingsModal from './GlobalSettingsModal'; import ExampleFlowsModal from './ExampleFlowsModal'; @@ -87,6 +88,7 @@ const nodeTypes = { csv: CsvNode, table: TabularDataNode, comment: CommentNode, + join: JoinNode, }; const edgeTypes = { @@ -197,27 +199,27 @@ const App = () => { code = "function evaluate(response) {\n return response.text.length;\n}"; addNode({ id: 'evalNode-'+Date.now(), type: 'evaluator', data: { language: progLang, code: code }, position: {x: x-200, y:y-100} }); }; - const addVisNode = (event) => { + const addVisNode = () => { const { x, y } = getViewportCenter(); addNode({ id: 'visNode-'+Date.now(), type: 'vis', data: {}, position: {x: x-200, y:y-100} }); }; - const addInspectNode = (event) => { + const addInspectNode = () => { const { x, y } = getViewportCenter(); addNode({ id: 'inspectNode-'+Date.now(), type: 'inspect', data: {}, position: {x: x-200, y:y-100} }); }; - const addScriptNode = (event) => { + const addScriptNode = () => { const { x, y } = getViewportCenter(); addNode({ id: 'scriptNode-'+Date.now(), type: 'script', data: {}, position: {x: x-200, y:y-100} }); }; - const addCsvNode = (event) => { + const addCsvNode = () => { const { x, y } = getViewportCenter(); addNode({ id: 'csvNode-'+Date.now(), type: 'csv', data: {}, position: {x: x-200, y:y-100} }); }; - const addTabularDataNode = (event) => { + const addTabularDataNode = () => { const { x, y } = getViewportCenter(); addNode({ id: 'table-'+Date.now(), type: 'table', data: {}, position: {x: x-200, y:y-100} }); }; - const addCommentNode = (event) => { + const addCommentNode = () => { const { x, y } = getViewportCenter(); addNode({ id: 'comment-'+Date.now(), type: 'comment', data: {}, position: {x: x-200, y:y-100} }); }; @@ -225,6 +227,10 @@ const App = () => { const { x, y } = getViewportCenter(); addNode({ id: 'llmeval-'+Date.now(), type: 'llmeval', data: {}, position: {x: x-200, y:y-100} }); }; + const addJoinNode = () => { + const { x, y } = getViewportCenter(); + addNode({ id: 'join-'+Date.now(), type: 'join', data: {}, position: {x: x-200, y:y-100} }); + }; const onClickExamples = () => { if (examplesModal && examplesModal.current) @@ -768,6 +774,11 @@ const App = () => { Inspect Node + Processors + + }> Join Node + + Misc Comment Node diff --git a/chainforge/react-server/src/JoinNode.js b/chainforge/react-server/src/JoinNode.js new file mode 100644 index 0000000..7a95d42 --- /dev/null +++ b/chainforge/react-server/src/JoinNode.js @@ -0,0 +1,136 @@ +import React, { useState, useEffect } from 'react'; +import { Handle } from 'reactflow'; +import useStore from './store'; +import NodeLabel from './NodeLabelComponent'; +import fetch_from_backend from './fetch_from_backend'; +import { IconArrowMerge } from '@tabler/icons-react'; +import { Divider, NativeSelect, Text } from '@mantine/core'; + +const JoinNode = ({ data, id }) => { + + let is_fetching = false; + + const [jsonResponses, setJSONResponses] = useState(null); + + const [pastInputs, setPastInputs] = useState([]); + const inputEdgesForNode = useStore((state) => state.inputEdgesForNode); + const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); + + const handleOnConnect = () => { + // For some reason, 'on connect' is called twice upon connection. + // We detect when an inspector node is already fetching, and disable the second call: + if (is_fetching) return; + + // Get the ids from the connected input nodes: + const input_node_ids = inputEdgesForNode(id).map(e => e.source); + + is_fetching = true; + + // Grab responses associated with those ids: + fetch_from_backend('grabResponses', { + 'responses': input_node_ids + }).then(function(json) { + if (json.responses && json.responses.length > 0) { + setJSONResponses(json.responses); + } + is_fetching = false; + }).catch(() => { + is_fetching = false; + }); + } + + const [groupByVar, setGroupByVar] = useState("all text"); + const handleChangeGroupByVar = (new_val) => { + setGroupByVar(new_val.target.value); + }; + + const [groupByLLM, setGroupByLLM] = useState("within"); + const handleChangeGroupByLLM = (new_val) => { + setGroupByLLM(new_val.target.value); + }; + + const [responsesPerPrompt, setResponsesPerPrompt] = useState("all"); + const handleChangeResponsesPerPrompt = (new_val) => { + setResponsesPerPrompt(new_val.target.value); + }; + + if (data.input) { + // If there's a change in inputs... + if (data.input != pastInputs) { + setPastInputs(data.input); + handleOnConnect(); + } + } + + useEffect(() => { + if (data.refresh && data.refresh === true) { + // Recreate the visualization: + setDataPropsForNode(id, { refresh: false }); + handleOnConnect(); + } + }, [data, id, handleOnConnect, setDataPropsForNode]); + + return ( +
+ } + /> +
+ Join + +
+
+ + LLM(s) +
+
+ take + + resp / prompt +
+ + + + +
); +}; + +export default JoinNode; \ No newline at end of file diff --git a/chainforge/react-server/src/text-fields-node.css b/chainforge/react-server/src/text-fields-node.css index c5fd444..cf15e75 100644 --- a/chainforge/react-server/src/text-fields-node.css +++ b/chainforge/react-server/src/text-fields-node.css @@ -531,6 +531,10 @@ border-color: #222; } + .join-node { + min-width: 200px; + } + .tabular-data-node { min-width: 280px; }