From b448e300c5a86714b4f6171aae2a5b11a6e13b72 Mon Sep 17 00:00:00 2001 From: ianarawjo Date: Mon, 23 Oct 2023 15:12:13 -0400 Subject: [PATCH] Add Join node (#144) * Add Join node * Bug fix chat histories with undefined content in messages * Slightly decrease TF width --- chainforge/react-server/src/App.js | 25 +- chainforge/react-server/src/JoinNode.js | 392 ++++++++++++++++++ .../react-server/src/LLMResponseInspector.js | 2 +- chainforge/react-server/src/PromptNode.js | 71 +--- chainforge/react-server/src/RemoveEdge.js | 2 +- .../react-server/src/backend/backend.ts | 4 +- chainforge/react-server/src/store.js | 52 ++- .../react-server/src/text-fields-node.css | 9 +- 8 files changed, 494 insertions(+), 63 deletions(-) create mode 100644 chainforge/react-server/src/JoinNode.js diff --git a/chainforge/react-server/src/App.js b/chainforge/react-server/src/App.js index b799109..f791ba6 100644 --- a/chainforge/react-server/src/App.js +++ b/chainforge/react-server/src/App.js @@ -8,7 +8,7 @@ import ReactFlow, { } from 'reactflow'; import { Button, Menu, LoadingOverlay, Text, Box, List, Loader, Tooltip } from '@mantine/core'; import { useClipboard } from '@mantine/hooks'; -import { IconSettings, IconTextPlus, IconTerminal, IconCsv, IconSettingsAutomation, IconFileSymlink, IconRobot, IconRuler2 } from '@tabler/icons-react'; +import { IconSettings, IconTextPlus, IconTerminal, IconCsv, IconSettingsAutomation, IconFileSymlink, IconRobot, IconRuler2, IconArrowMerge } from '@tabler/icons-react'; import RemoveEdge from './RemoveEdge'; import TextFieldsNode from './TextFieldsNode'; // Import a custom node import PromptNode from './PromptNode'; @@ -19,6 +19,7 @@ import ScriptNode from './ScriptNode'; import AlertModal from './AlertModal'; import CsvNode from './CsvNode'; import TabularDataNode from './TabularDataNode'; +import JoinNode from './JoinNode'; import CommentNode from './CommentNode'; import GlobalSettingsModal from './GlobalSettingsModal'; import ExampleFlowsModal from './ExampleFlowsModal'; @@ -87,6 +88,7 @@ const nodeTypes = { csv: CsvNode, table: TabularDataNode, comment: CommentNode, + join: JoinNode, }; const edgeTypes = { @@ -197,27 +199,27 @@ const App = () => { code = "function evaluate(response) {\n return response.text.length;\n}"; addNode({ id: 'evalNode-'+Date.now(), type: 'evaluator', data: { language: progLang, code: code }, position: {x: x-200, y:y-100} }); }; - const addVisNode = (event) => { + const addVisNode = () => { const { x, y } = getViewportCenter(); addNode({ id: 'visNode-'+Date.now(), type: 'vis', data: {}, position: {x: x-200, y:y-100} }); }; - const addInspectNode = (event) => { + const addInspectNode = () => { const { x, y } = getViewportCenter(); addNode({ id: 'inspectNode-'+Date.now(), type: 'inspect', data: {}, position: {x: x-200, y:y-100} }); }; - const addScriptNode = (event) => { + const addScriptNode = () => { const { x, y } = getViewportCenter(); addNode({ id: 'scriptNode-'+Date.now(), type: 'script', data: {}, position: {x: x-200, y:y-100} }); }; - const addCsvNode = (event) => { + const addCsvNode = () => { const { x, y } = getViewportCenter(); addNode({ id: 'csvNode-'+Date.now(), type: 'csv', data: {}, position: {x: x-200, y:y-100} }); }; - const addTabularDataNode = (event) => { + const addTabularDataNode = () => { const { x, y } = getViewportCenter(); addNode({ id: 'table-'+Date.now(), type: 'table', data: {}, position: {x: x-200, y:y-100} }); }; - const addCommentNode = (event) => { + const addCommentNode = () => { const { x, y } = getViewportCenter(); addNode({ id: 'comment-'+Date.now(), type: 'comment', data: {}, position: {x: x-200, y:y-100} }); }; @@ -225,6 +227,10 @@ const App = () => { const { x, y } = getViewportCenter(); addNode({ id: 'llmeval-'+Date.now(), type: 'llmeval', data: {}, position: {x: x-200, y:y-100} }); }; + const addJoinNode = () => { + const { x, y } = getViewportCenter(); + addNode({ id: 'join-'+Date.now(), type: 'join', data: {}, position: {x: x-200, y:y-100} }); + }; const onClickExamples = () => { if (examplesModal && examplesModal.current) @@ -768,6 +774,11 @@ const App = () => { Inspect Node + Processors + + }> Join Node + + Misc Comment Node diff --git a/chainforge/react-server/src/JoinNode.js b/chainforge/react-server/src/JoinNode.js new file mode 100644 index 0000000..676f827 --- /dev/null +++ b/chainforge/react-server/src/JoinNode.js @@ -0,0 +1,392 @@ +import React, { useState, useEffect, useCallback } from 'react'; +import { Handle } from 'reactflow'; +import useStore from './store'; +import NodeLabel from './NodeLabelComponent'; +import fetch_from_backend from './fetch_from_backend'; +import { IconArrowMerge, IconList } from '@tabler/icons-react'; +import { Divider, NativeSelect, Text, Popover, Tooltip, Center, Modal, Box } from '@mantine/core'; +import { useDisclosure } from '@mantine/hooks'; + +const formattingOptions = [ + {value: "\n\n", label:"double newline \\n\\n"}, + {value: "\n", label:"newline \\n"}, + {value: "-", label:"- dashed list"}, + {value: "1.", label:"1. numbered list"}, + {value: "[]", label:'["list", "of", "strings"]'} +]; + +const joinTexts = (texts, formatting) => { + if (formatting === "\n\n" || formatting === "\n") + return texts.join(formatting); + else if (formatting === "-") + return texts.map((t) => ('- ' + t)).join("\n"); + else if (formatting === "1.") + return texts.map((t, i) => (`${i+1}. ${t}`)).join("\n"); + else if (formatting === '[]') + return JSON.stringify(texts); + + console.error(`Could not join: Unknown formatting option: ${formatting}`); + return texts; +}; + +const getVarsAndMetavars = (input_data) => { + // Find all vars and metavars in the input data (if any): + let varnames = new Set(); + let metavars = new Set(); + Object.entries(input_data).forEach(([key, obj]) => { + if (key !== '__input') varnames.add(key); // A "var" can also be other properties on input_data + obj.forEach(resp_obj => { + if (typeof resp_obj === "string") return; + Object.keys(resp_obj.fill_history).forEach(v => varnames.add(v)); + if (resp_obj.metavars) Object.keys(resp_obj.metavars).forEach(v => metavars.add(v)); + }); + }); + varnames = Array.from(varnames); + metavars = Array.from(metavars); + return { + vars: varnames, + metavars: metavars, + }; +}; + +const countNumLLMs = (resp_objs_or_dict) => { + const resp_objs = Array.isArray(resp_objs_or_dict) ? resp_objs_or_dict : Object.values(resp_objs_or_dict).flat(); + return (new Set(resp_objs.filter(r => typeof r !== "string" && r.llm !== undefined).map(r => r.llm?.key || r.llm))).size; +}; + +const tagMetadataWithLLM = (input_data) => { + let new_data = {}; + Object.entries(input_data).forEach(([varname, resp_objs]) => { + new_data[varname] = resp_objs.map(r => { + if (!r || typeof r === 'string' || !r?.llm?.key) return r; + let r_copy = JSON.parse(JSON.stringify(r)); + r_copy.metavars["__LLM_key"] = r.llm.key; + return r_copy; + }); + }); + return new_data; +}; +const extractLLMLookup = (input_data) => { + let llm_lookup = {}; + Object.entries(input_data).forEach(([varname, resp_objs]) => { + resp_objs.forEach(r => { + if (typeof r === 'string' || !r?.llm?.key || r.llm.key in llm_lookup) return; + llm_lookup[r.llm.key] = r.llm; + }); + }); + return llm_lookup; +}; +const removeLLMTagFromMetadata = (metavars) => { + if (!('__LLM_key' in metavars)) + return metavars; + let mcopy = JSON.parse(JSON.stringify(metavars)); + delete metavars['__LLM_key']; + return mcopy; +}; + +const truncStr = (s, maxLen) => { + if (s.length > maxLen) // Cut the name short if it's long + return s.substring(0, maxLen) + '...' + else + return s; +}; +const groupResponsesBy = (responses, keyFunc) => { + let responses_by_key = {}; + let unspecified_group = []; + responses.forEach(item => { + const key = keyFunc(item); + const d = key !== null ? responses_by_key : unspecified_group; + if (key in d) + d[key].push(item); + else + d[key] = [item]; + }); + return [responses_by_key, unspecified_group]; +}; + +const DEFAULT_GROUPBY_VAR_ALL = { label: "all text", value: "A" }; + +const displayJoinedTexts = (textInfos, getColorForLLM) => { + const color_for_llm = (llm) => (getColorForLLM(llm) + '99'); + return textInfos.map((info, idx) => { + + const vars = info.fill_history; + let var_tags = vars === undefined ? [] : Object.keys(vars).map((varname) => { + const v = truncStr(vars[varname].trim(), 72); + return (
+ {varname} = {v} +
); + }); + + const ps = (
{info.text || info}
); + + return ( +
+
+ {var_tags} +
+ {info.llm === undefined ? + ps + : (
+

{info.llm?.name}

+ {ps} +
) + } +
+ ); + }); +}; + +const JoinedTextsPopover = ({ textInfos, onHover, onClick, getColorForLLM }) => { + const [opened, { close, open }] = useDisclosure(false); + + const _onHover = useCallback(() => { + onHover(); + open(); + }, [onHover, open]); + + return ( + + + + + + + +
Preview of joined inputs ({textInfos?.length} total)
+ {displayJoinedTexts(textInfos, getColorForLLM)} +
+
+ ); +}; + + +const JoinNode = ({ data, id }) => { + + const [joinedTexts, setJoinedTexts] = useState([]); + + // For an info pop-up that previews all the joined inputs + const [infoModalOpened, { open: openInfoModal, close: closeInfoModal }] = useDisclosure(false); + + const [pastInputs, setPastInputs] = useState([]); + const pullInputData = useStore((state) => state.pullInputData); + const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); + + // Global lookup for what color to use per LLM + const getColorForLLMAndSetIfNotFound = useStore((state) => state.getColorForLLMAndSetIfNotFound); + + const [inputHasLLMs, setInputHasLLMs] = useState(false); + + const [groupByVars, setGroupByVars] = useState([DEFAULT_GROUPBY_VAR_ALL]); + const [groupByVar, setGroupByVar] = useState("A"); + + const [groupByLLM, setGroupByLLM] = useState("within"); + const [formatting, setFormatting] = useState(formattingOptions[0].value); + + const handleOnConnect = useCallback(() => { + let input_data = pullInputData(["__input"], id); + if (!input_data?.__input) { + console.warn('Join Node: No input data detected.'); + return; + } + + // Find all vars and metavars in the input data (if any): + let {vars, metavars} = getVarsAndMetavars(input_data); + + // Create lookup table for LLMs in input, indexed by llm key + const llm_lookup = extractLLMLookup(input_data); + + // Refresh the dropdown list with available vars/metavars: + setGroupByVars([DEFAULT_GROUPBY_VAR_ALL].concat( + vars.map(varname => ({label: `by ${varname}`, value: `V${varname}`}))) + .concat( + metavars.filter(varname => !varname.startsWith('LLM_')).map(varname => ({label: `by ${varname} (meta)`, value: `M${varname}`}))) + ); + + // Check whether more than one LLM is present in the inputs: + const numLLMs = countNumLLMs(input_data); + setInputHasLLMs(numLLMs > 1); + + // Tag all response objects in the input data with a metavar for their LLM (using the llm key as a uid) + input_data = tagMetadataWithLLM(input_data); + + // A function to group the input (an array of texts/resp_objs) by the selected var + // and then join the texts within the groups + const joinByVar = (input) => { + const varname = groupByVar.substring(1); + const isMetavar = groupByVar[0] === 'M'; + const [groupedResps, unspecGroup] = groupResponsesBy(input, + isMetavar ? + (r) => (r.metavars ? r.metavars[varname] : undefined) : + (r) => (r.fill_history ? r.fill_history[varname] : undefined) + ); + + // Now join texts within each group: + // (NOTE: We can do this directly here as response texts can't be templates themselves) + let joined_texts = Object.entries(groupedResps).map(([var_val, resp_objs]) => { + if (resp_objs.length === 0) return ""; + const llm = (countNumLLMs(resp_objs) > 1) ? undefined : resp_objs[0].llm; + let vars = {}; + if (groupByVar !== 'A') + vars[varname] = var_val; + return { + text: joinTexts(resp_objs.map(r => r.text !== undefined ? r.text : r), formatting), + fill_history: isMetavar ? {} : vars, + metavars: isMetavar ? vars : {}, + llm: llm, + // NOTE: We lose all other metadata here, because we could've joined across other vars or metavars values. + }; + }); + + // Add any data from unspecified group + if (unspecGroup.length > 0) { + const llm = (countNumLLMs(unspecGroup) > 1) ? undefined : unspecGroup[0].llm; + joined_texts.push({ + text: joinTexts(unspecGroup.map(u => u.text !== undefined ? u.text : u), formatting), + fill_history: {}, + metavars: {}, + llm: llm, + }); + } + + return joined_texts; + }; + + // Generate (flatten) the inputs, which could be recursively chained templates + // and a mix of LLM resp objects, templates, and strings. + // (We tagged each object with its LLM key so that we can use built-in features to keep track of the LLM associated with each response object) + fetch_from_backend('generatePrompts', { + prompt: "{__input}", + vars: input_data, + }).then(promptTemplates => { + + // Convert the templates into response objects + let resp_objs = promptTemplates.map(p => ({ + text: p.toString(), + fill_history: p.fill_history, + llm: "__LLM_key" in p.metavars ? llm_lookup[p.metavars['__LLM_key']] : undefined, + metavars: removeLLMTagFromMetadata(p.metavars), + })); + + // If there's multiple LLMs and groupByLLM is 'within', we need to + // first group by the LLMs (and a possible 'undefined' group): + if (numLLMs > 1 && groupByLLM === 'within') { + let joined_texts = []; + const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(resp_objs, r => r.llm?.key || r.llm); + Object.entries(groupedRespsByLLM).map(([llm_key, resp_objs]) => { + // Group only within the LLM + joined_texts = joined_texts.concat(joinByVar(resp_objs)); + }); + + if (nonLLMRespGroup.length > 0) + joined_texts.push(joinTexts(nonLLMRespGroup, formatting)); + + setJoinedTexts(joined_texts); + setDataPropsForNode(id, { fields: joined_texts }); + } else { + // Join across LLMs (join irrespective of LLM): + if (groupByVar !== 'A') { + // If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first: + const joined_texts = joinByVar(resp_objs); + setJoinedTexts(joined_texts); + setDataPropsForNode(id, { fields: joined_texts }); + } else { + let joined_texts = joinTexts(resp_objs.map(r => ((typeof r === 'string') ? r : r.text)), formatting); + + // If there is exactly 1 LLM and it's present across all inputs, keep track of it: + if (numLLMs === 1 && resp_objs.every((r) => r.llm !== undefined)) + joined_texts = {text: joined_texts, fill_history: {}, llm: resp_objs[0].llm}; + + setJoinedTexts([joined_texts]); + setDataPropsForNode(id, { fields: [joined_texts] }); + } + } + }); + + }, [formatting, pullInputData, groupByVar, groupByLLM]); + + if (data.input) { + // If there's a change in inputs... + if (data.input != pastInputs) { + setPastInputs(data.input); + handleOnConnect(); + } + } + + // Refresh join output anytime the dropdowns change + useEffect(() => { + handleOnConnect(); + }, [groupByVar, groupByLLM, formatting]) + + useEffect(() => { + if (data.refresh && data.refresh === true) { + // Recreate the visualization: + setDataPropsForNode(id, { refresh: false }); + handleOnConnect(); + } + }, [data, id, handleOnConnect, setDataPropsForNode]); + + return ( +
+ } + customButtons={[ + + ]} /> + + + {displayJoinedTexts(joinedTexts, getColorForLLMAndSetIfNotFound)} + + +
+ Join + setGroupByVar(e.target.value)} + className='nodrag nowheel' + data={groupByVars} + size="xs" + value={groupByVar} + miw='80px' + mr='xs' /> +
+ {inputHasLLMs ? +
+ setGroupByLLM(e.target.value)} + className='nodrag nowheel' + data={["within", "across"]} + size="xs" + value={groupByLLM} + maw='80px' + mr='xs' + ml='40px' /> + LLMs +
+ : <>} + + setFormatting(e.target.value)} + className='nodrag nowheel' + data={formattingOptions} + size="xs" + value={formatting} + miw='80px' /> + + +
); +}; + +export default JoinNode; \ No newline at end of file diff --git a/chainforge/react-server/src/LLMResponseInspector.js b/chainforge/react-server/src/LLMResponseInspector.js index 97570d6..797d44f 100644 --- a/chainforge/react-server/src/LLMResponseInspector.js +++ b/chainforge/react-server/src/LLMResponseInspector.js @@ -475,7 +475,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { value={multiSelectValue} clearSearchOnChange={true} clearSearchOnBlur={true} - w='80%' /> + w={wideFormat ? '80%' : '100%'} /> setOnlyShowScores(e.currentTarget.checked)} diff --git a/chainforge/react-server/src/PromptNode.js b/chainforge/react-server/src/PromptNode.js index e787e3a..0ad13ab 100644 --- a/chainforge/react-server/src/PromptNode.js +++ b/chainforge/react-server/src/PromptNode.js @@ -78,10 +78,9 @@ const PromptNode = ({ data, id, type: node_type }) => { // Get state from the Zustand store: const edges = useStore((state) => state.edges); - const output = useStore((state) => state.output); + const pullInputData = useStore((state) => state.pullInputData); const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); const pingOutputNodes = useStore((state) => state.pingOutputNodes); - const getNode = useStore((state) => state.getNode); // API Keys (set by user in popup GlobalSettingsModal) const apiKeys = useStore((state) => state.apiKeys); @@ -208,52 +207,10 @@ const PromptNode = ({ data, id, type: node_type }) => { } }, [data]); - // Pull all inputs needed to request responses. - // Returns [prompt, vars dict] - const pullInputData = (_targetHandles) => { - // Pull data from each source recursively: - const pulled_data = {}; - const store_data = (_texts, _varname, _data) => { - if (_varname in _data) - _data[_varname] = _data[_varname].concat(_texts); - else - _data[_varname] = _texts; - }; - const get_outputs = (varnames, nodeId) => { - varnames.forEach(varname => { - // Find the relevant edge(s): - edges.forEach(e => { - if (e.target == nodeId && e.targetHandle == varname) { - // Get the immediate output: - let out = output(e.source, e.sourceHandle); - if (!out || !Array.isArray(out) || out.length === 0) return; - - // Check the format of the output. Can be str or dict with 'text' and more attrs: - if (typeof out[0] === 'object') { - out.forEach(obj => store_data([obj], varname, pulled_data)); - } - else { - // Save the list of strings from the pulled output under the var 'varname' - store_data(out, varname, pulled_data); - } - - // Get any vars that the output depends on, and recursively collect those outputs as well: - const n_vars = getNode(e.source).data.vars; - if (n_vars && Array.isArray(n_vars) && n_vars.length > 0) - get_outputs(n_vars, e.source); - } - }); - }); - }; - get_outputs(_targetHandles, id); - - return pulled_data; - }; - // Chat nodes only. Pulls input data attached to the 'past conversations' handle. // Returns a tuple (past_chat_llms, __past_chats), where both are undefined if nothing is connected. const pullInputChats = () => { - const pulled_data = pullInputData(['__past_chats']); + const pulled_data = pullInputData(['__past_chats'], id); if (!('__past_chats' in pulled_data)) return [undefined, undefined]; // For storing the unique LLMs in past_chats: @@ -313,12 +270,12 @@ const PromptNode = ({ data, id, type: node_type }) => { const [promptPreviews, setPromptPreviews] = useState([]); const handlePreviewHover = () => { // Pull input data and prompt - const pulled_vars = pullInputData(templateVars); + const pulled_vars = pullInputData(templateVars, id); fetch_from_backend('generatePrompts', { prompt: promptText, vars: pulled_vars, }).then(prompts => { - setPromptPreviews(prompts.map(p => (new PromptInfo(p)))); + setPromptPreviews(prompts.map(p => (new PromptInfo(p.toString())))); }); pullInputChats(); @@ -352,7 +309,7 @@ const PromptNode = ({ data, id, type: node_type }) => { } // Pull the input data - const pulled_vars = pullInputData(templateVars); + const pulled_vars = pullInputData(templateVars, id); const llms = _llmItemsCurrState.map(item => item.model); const num_llms = llms.length; @@ -442,6 +399,20 @@ const PromptNode = ({ data, id, type: node_type }) => { return; } + // Check if pulled chats includes undefined content. + // This could happen with Join nodes, where there is no longer a single "prompt" (user prompt) + // of the chat provenance. Instead of blocking this behavior, we replace undefined with a blank string, + // and output a warning to the console. + if (!pulled_chats.every(c => c.messages.every(m => m.content !== undefined))) { + console.warn("Chat history contains undefined content. This can happen if a Join Node was used, \ + as there is no longer a single prompt as the provenance of the conversation. \ + Soft failing by replacing undefined with empty strings."); + pulled_chats.forEach(c => {c.messages = c.messages.map(m => { + if (m.content !== undefined) return m; + else return {...m, content: " "}; // the string contains a single space since PaLM2 refuses to answer with empty strings + })}); + } + // Override LLM list with the past llm info (unique LLMs in prior responses) _llmItemsCurrState = past_chat_llms; @@ -462,13 +433,13 @@ const PromptNode = ({ data, id, type: node_type }) => { setProgressAnimated(true); // Pull the data to fill in template input variables, if any - const pulled_data = pullInputData(templateVars); + const pulled_data = pullInputData(templateVars, id); const prompt_template = promptText; const rejected = (err) => { setStatus('error'); setContChatToggleDisabled(false); - triggerAlert(err.message); + triggerAlert(err.message || err); }; // Fetch info about the number of queries we'll need to make diff --git a/chainforge/react-server/src/RemoveEdge.js b/chainforge/react-server/src/RemoveEdge.js index 885b87a..1245d46 100644 --- a/chainforge/react-server/src/RemoveEdge.js +++ b/chainforge/react-server/src/RemoveEdge.js @@ -45,7 +45,7 @@ export default function CustomEdge({ // Thanks in part to oshanley https://github.com/wbkd/react-flow/issues/1211#issuecomment-1585032930 return ( - setHovering(true)} onPointerLeave={()=>setHovering(false)} onClick={()=>console.log('click')}> + setHovering(true)} onPointerLeave={()=>setHovering(false)}>
any, responses: A * @param vars a dict of the template variables to fill the prompt template with, by name. (See countQueries docstring for more info). * @returns An array of strings representing the prompts that will be sent out. Note that this could include unfilled template vars. */ -export async function generatePrompts(root_prompt: string, vars: Dict): Promise { +export async function generatePrompts(root_prompt: string, vars: Dict): Promise { const gen_prompts = new PromptPermutationGenerator(root_prompt); - const all_prompt_permutations = Array.from(gen_prompts.generate(vars)).map(p => p.toString()); + const all_prompt_permutations = Array.from(gen_prompts.generate(vars)); return all_prompt_permutations; } diff --git a/chainforge/react-server/src/store.js b/chainforge/react-server/src/store.js index a4c2f85..fad3e0e 100644 --- a/chainforge/react-server/src/store.js +++ b/chainforge/react-server/src/store.js @@ -26,7 +26,7 @@ export const colorPalettes = { var: varColorPalette, } -const refreshableOutputNodeTypes = new Set(['evaluator', 'prompt', 'inspect', 'vis', 'llmeval', 'textfields', 'chat', 'simpleval']); +const refreshableOutputNodeTypes = new Set(['evaluator', 'prompt', 'inspect', 'vis', 'llmeval', 'textfields', 'chat', 'simpleval', 'join']); export let initLLMProviders = [ { name: "GPT3.5", emoji: "🤖", model: "gpt-3.5-turbo", base_model: "gpt-3.5-turbo", temp: 1.0 }, // The base_model designates what settings form will be used, and must be unique. @@ -204,6 +204,56 @@ const useStore = create((set, get) => ({ return null; } }, + + // Pull all inputs needed to request responses. + // Returns [prompt, vars dict] + pullInputData: (_targetHandles, node_id) => { + // Functions/data from the store: + const getNode = get().getNode; + const output = get().output; + const edges = get().edges; + + // Helper function to store collected data in dict: + const store_data = (_texts, _varname, _data) => { + if (_varname in _data) + _data[_varname] = _data[_varname].concat(_texts); + else + _data[_varname] = _texts; + }; + + // Pull data from each source recursively: + const pulled_data = {}; + const get_outputs = (varnames, nodeId) => { + varnames.forEach(varname => { + // Find the relevant edge(s): + edges.forEach(e => { + if (e.target == nodeId && e.targetHandle == varname) { + // Get the immediate output: + let out = output(e.source, e.sourceHandle); + if (!out || !Array.isArray(out) || out.length === 0) return; + + // Check the format of the output. Can be str or dict with 'text' and more attrs: + if (typeof out[0] === 'object') { + out.forEach(obj => store_data([obj], varname, pulled_data)); + } + else { + // Save the list of strings from the pulled output under the var 'varname' + store_data(out, varname, pulled_data); + } + + // Get any vars that the output depends on, and recursively collect those outputs as well: + const n_vars = getNode(e.source).data.vars; + if (n_vars && Array.isArray(n_vars) && n_vars.length > 0) + get_outputs(n_vars, e.source); + } + }); + }); + }; + get_outputs(_targetHandles, node_id); + + return pulled_data; + }, + setDataPropsForNode: (id, data_props) => { set({ nodes: (nds => diff --git a/chainforge/react-server/src/text-fields-node.css b/chainforge/react-server/src/text-fields-node.css index c5fd444..c9a0079 100644 --- a/chainforge/react-server/src/text-fields-node.css +++ b/chainforge/react-server/src/text-fields-node.css @@ -409,6 +409,9 @@ color: #444; white-space: pre-wrap; } + .join-text-preview { + margin: 0px 0px 10px 0px; + } .small-response { font-size: 8pt; @@ -531,6 +534,10 @@ border-color: #222; } + .join-node { + min-width: 200px; + } + .tabular-data-node { min-width: 280px; } @@ -652,7 +659,7 @@ .text-field-fixed .mantine-Textarea-wrapper textarea { resize: vertical; overflow-y: auto; - width: 280px; + width: 260px; padding: calc(0.5rem / 3); font-size: 10pt; font-family: monospace;