diff --git a/chainforge/react-server/src/App.js b/chainforge/react-server/src/App.js
index b799109..f791ba6 100644
--- a/chainforge/react-server/src/App.js
+++ b/chainforge/react-server/src/App.js
@@ -8,7 +8,7 @@ import ReactFlow, {
} from 'reactflow';
import { Button, Menu, LoadingOverlay, Text, Box, List, Loader, Tooltip } from '@mantine/core';
import { useClipboard } from '@mantine/hooks';
-import { IconSettings, IconTextPlus, IconTerminal, IconCsv, IconSettingsAutomation, IconFileSymlink, IconRobot, IconRuler2 } from '@tabler/icons-react';
+import { IconSettings, IconTextPlus, IconTerminal, IconCsv, IconSettingsAutomation, IconFileSymlink, IconRobot, IconRuler2, IconArrowMerge } from '@tabler/icons-react';
import RemoveEdge from './RemoveEdge';
import TextFieldsNode from './TextFieldsNode'; // Import a custom node
import PromptNode from './PromptNode';
@@ -19,6 +19,7 @@ import ScriptNode from './ScriptNode';
import AlertModal from './AlertModal';
import CsvNode from './CsvNode';
import TabularDataNode from './TabularDataNode';
+import JoinNode from './JoinNode';
import CommentNode from './CommentNode';
import GlobalSettingsModal from './GlobalSettingsModal';
import ExampleFlowsModal from './ExampleFlowsModal';
@@ -87,6 +88,7 @@ const nodeTypes = {
csv: CsvNode,
table: TabularDataNode,
comment: CommentNode,
+ join: JoinNode,
};
const edgeTypes = {
@@ -197,27 +199,27 @@ const App = () => {
code = "function evaluate(response) {\n return response.text.length;\n}";
addNode({ id: 'evalNode-'+Date.now(), type: 'evaluator', data: { language: progLang, code: code }, position: {x: x-200, y:y-100} });
};
- const addVisNode = (event) => {
+ const addVisNode = () => {
const { x, y } = getViewportCenter();
addNode({ id: 'visNode-'+Date.now(), type: 'vis', data: {}, position: {x: x-200, y:y-100} });
};
- const addInspectNode = (event) => {
+ const addInspectNode = () => {
const { x, y } = getViewportCenter();
addNode({ id: 'inspectNode-'+Date.now(), type: 'inspect', data: {}, position: {x: x-200, y:y-100} });
};
- const addScriptNode = (event) => {
+ const addScriptNode = () => {
const { x, y } = getViewportCenter();
addNode({ id: 'scriptNode-'+Date.now(), type: 'script', data: {}, position: {x: x-200, y:y-100} });
};
- const addCsvNode = (event) => {
+ const addCsvNode = () => {
const { x, y } = getViewportCenter();
addNode({ id: 'csvNode-'+Date.now(), type: 'csv', data: {}, position: {x: x-200, y:y-100} });
};
- const addTabularDataNode = (event) => {
+ const addTabularDataNode = () => {
const { x, y } = getViewportCenter();
addNode({ id: 'table-'+Date.now(), type: 'table', data: {}, position: {x: x-200, y:y-100} });
};
- const addCommentNode = (event) => {
+ const addCommentNode = () => {
const { x, y } = getViewportCenter();
addNode({ id: 'comment-'+Date.now(), type: 'comment', data: {}, position: {x: x-200, y:y-100} });
};
@@ -225,6 +227,10 @@ const App = () => {
const { x, y } = getViewportCenter();
addNode({ id: 'llmeval-'+Date.now(), type: 'llmeval', data: {}, position: {x: x-200, y:y-100} });
};
+ const addJoinNode = () => {
+ const { x, y } = getViewportCenter();
+ addNode({ id: 'join-'+Date.now(), type: 'join', data: {}, position: {x: x-200, y:y-100} });
+ };
const onClickExamples = () => {
if (examplesModal && examplesModal.current)
@@ -768,6 +774,11 @@ const App = () => {
             <Menu.Item onClick={addInspectNode} icon={'🔍'}> Inspect Node </Menu.Item>
+          <Menu.Label>Processors</Menu.Label>
+          <MenuTooltip label="Concatenate responses or input data together before feeding into later parts of your flow.">
+            <Menu.Item onClick={addJoinNode} icon={<IconArrowMerge size='14pt' />}> Join Node </Menu.Item>
+          </MenuTooltip>
+          <Menu.Divider />
           <Menu.Label>Misc</Menu.Label>
             <Menu.Item onClick={addCommentNode} icon={'✏️'}> Comment Node </Menu.Item>
diff --git a/chainforge/react-server/src/JoinNode.js b/chainforge/react-server/src/JoinNode.js
new file mode 100644
index 0000000..676f827
--- /dev/null
+++ b/chainforge/react-server/src/JoinNode.js
@@ -0,0 +1,392 @@
+import React, { useState, useEffect, useCallback } from 'react';
+import { Handle } from 'reactflow';
+import useStore from './store';
+import NodeLabel from './NodeLabelComponent';
+import fetch_from_backend from './fetch_from_backend';
+import { IconArrowMerge, IconList } from '@tabler/icons-react';
+import { Divider, NativeSelect, Text, Popover, Tooltip, Center, Modal, Box } from '@mantine/core';
+import { useDisclosure } from '@mantine/hooks';
+
+const formattingOptions = [
+ {value: "\n\n", label:"double newline \\n\\n"},
+ {value: "\n", label:"newline \\n"},
+ {value: "-", label:"- dashed list"},
+ {value: "1.", label:"1. numbered list"},
+ {value: "[]", label:'["list", "of", "strings"]'}
+];
+
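+// Joins a list of texts into a single string according to the selected formatting option,
+// e.g. joinTexts(["a", "b"], "-") produces "- a\n- b".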
+const joinTexts = (texts, formatting) => {
+ if (formatting === "\n\n" || formatting === "\n")
+ return texts.join(formatting);
+ else if (formatting === "-")
+ return texts.map((t) => ('- ' + t)).join("\n");
+ else if (formatting === "1.")
+ return texts.map((t, i) => (`${i+1}. ${t}`)).join("\n");
+ else if (formatting === '[]')
+ return JSON.stringify(texts);
+
+ console.error(`Could not join: Unknown formatting option: ${formatting}`);
+ return texts;
+};
+
+const getVarsAndMetavars = (input_data) => {
+ // Find all vars and metavars in the input data (if any):
+ let varnames = new Set();
+ let metavars = new Set();
+ Object.entries(input_data).forEach(([key, obj]) => {
+ if (key !== '__input') varnames.add(key); // A "var" can also be other properties on input_data
+ obj.forEach(resp_obj => {
+ if (typeof resp_obj === "string") return;
+ Object.keys(resp_obj.fill_history).forEach(v => varnames.add(v));
+ if (resp_obj.metavars) Object.keys(resp_obj.metavars).forEach(v => metavars.add(v));
+ });
+ });
+ varnames = Array.from(varnames);
+ metavars = Array.from(metavars);
+ return {
+ vars: varnames,
+ metavars: metavars,
+ };
+};
+
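+// Counts the number of distinct LLMs across the given response objects
+// (or dict of response-object arrays), keyed by llm.key when present.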
+const countNumLLMs = (resp_objs_or_dict) => {
+ const resp_objs = Array.isArray(resp_objs_or_dict) ? resp_objs_or_dict : Object.values(resp_objs_or_dict).flat();
+ return (new Set(resp_objs.filter(r => typeof r !== "string" && r.llm !== undefined).map(r => r.llm?.key || r.llm))).size;
+};
+
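+// Tags each response object with its LLM key under metavars.__LLM_key,
+// so LLM provenance survives when inputs are flattened through the template generator.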
+const tagMetadataWithLLM = (input_data) => {
+ let new_data = {};
+ Object.entries(input_data).forEach(([varname, resp_objs]) => {
+ new_data[varname] = resp_objs.map(r => {
+ if (!r || typeof r === 'string' || !r?.llm?.key) return r;
+ let r_copy = JSON.parse(JSON.stringify(r));
+ r_copy.metavars["__LLM_key"] = r.llm.key;
+ return r_copy;
+ });
+ });
+ return new_data;
+};
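+// Builds a lookup table from LLM key to the full LLM spec object, across all inputs.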
+const extractLLMLookup = (input_data) => {
+ let llm_lookup = {};
+ Object.entries(input_data).forEach(([varname, resp_objs]) => {
+ resp_objs.forEach(r => {
+ if (typeof r === 'string' || !r?.llm?.key || r.llm.key in llm_lookup) return;
+ llm_lookup[r.llm.key] = r.llm;
+ });
+ });
+ return llm_lookup;
+};
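+// Strips the internal __LLM_key tag from a metavars dict, returning a copy if the tag is present.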
+const removeLLMTagFromMetadata = (metavars) => {
+  if (!('__LLM_key' in metavars))
+    return metavars;
+  let mcopy = JSON.parse(JSON.stringify(metavars));
+  delete mcopy['__LLM_key'];
+  return mcopy;
+};
+
+const truncStr = (s, maxLen) => {
+  if (s.length > maxLen) // Cut the name short if it's long
+    return s.substring(0, maxLen) + '...';
+  else
+    return s;
+};
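+// Groups responses by the key returned by keyFunc; responses for which no key
+// can be extracted go into a separate "unspecified" group.
+// Returns [responses_by_key, unspecified_group].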
+const groupResponsesBy = (responses, keyFunc) => {
+  let responses_by_key = {};
+  let unspecified_group = [];
+  responses.forEach(item => {
+    const key = keyFunc(item);
+    if (key === null || key === undefined) {
+      unspecified_group.push(item);
+      return;
+    }
+    if (key in responses_by_key)
+      responses_by_key[key].push(item);
+    else
+      responses_by_key[key] = [item];
+  });
+  return [responses_by_key, unspecified_group];
+};
+
+const DEFAULT_GROUPBY_VAR_ALL = { label: "all text", value: "A" };
+
+const displayJoinedTexts = (textInfos, getColorForLLM) => {
+  const color_for_llm = (llm) => (getColorForLLM(llm) + '99');
+  return textInfos.map((info, idx) => {
+
+    const vars = info.fill_history;
+    let var_tags = vars === undefined ? [] : Object.keys(vars).map((varname) => {
+      const v = truncStr(vars[varname].trim(), 72);
+      return (<div key={varname} className="response-var-inline">
+        <span className="response-var-name">{varname} = </span><span className="response-var-value">{v}</span>
+      </div>);
+    });
+
+    const ps = (<pre className='small-response'>{info.text || info}</pre>);
+
+    return (
+      <div key={'r'+idx} className="response-box" style={{ backgroundColor: (info.llm ? color_for_llm(info.llm?.name) : '#ddd'), width: '100%' }}>
+        <div className="response-var-inline-container">
+          {var_tags}
+        </div>
+        {info.llm === undefined ?
+            ps
+          : (<div className="response-item-llm-name-wrapper">
+              <h1>{info.llm?.name}</h1>
+              {ps}
+            </div>)
+        }
+      </div>
+    );
+  });
+};
+
+const JoinedTextsPopover = ({ textInfos, onHover, onClick, getColorForLLM }) => {
+  const [opened, { close, open }] = useDisclosure(false);
+
+  const _onHover = useCallback(() => {
+    onHover();
+    open();
+  }, [onHover, open]);
+
+  return (
+    <Popover position="right-start" withArrow withinPortal shadow="rgb(38, 57, 77) 0px 10px 30px -14px" key="query-info" opened={opened} styles={{dropdown: {maxHeight: '500px', maxWidth: '400px', overflowY: 'auto', backgroundColor: '#fff'}}}>
+      <Popover.Target>
+        <Tooltip label='Click to view all joined inputs' withArrow>
+          <button className='custom-button' onMouseEnter={_onHover} onMouseLeave={close} onClick={onClick} style={{border: 'none'}}>
+            <IconList size='12pt' color='gray' style={{marginBottom: '-4px'}} />
+          </button>
+        </Tooltip>
+      </Popover.Target>
+      <Popover.Dropdown sx={{ pointerEvents: 'none' }}>
+        <Center><Text size='xs' fw={500} color='#666'>Preview of joined inputs ({textInfos?.length} total)</Text></Center>
+        {displayJoinedTexts(textInfos, getColorForLLM)}
+      </Popover.Dropdown>
+    </Popover>
+  );
+};
+
+
+const JoinNode = ({ data, id }) => {
+
+ const [joinedTexts, setJoinedTexts] = useState([]);
+
+ // For an info pop-up that previews all the joined inputs
+ const [infoModalOpened, { open: openInfoModal, close: closeInfoModal }] = useDisclosure(false);
+
+ const [pastInputs, setPastInputs] = useState([]);
+ const pullInputData = useStore((state) => state.pullInputData);
+ const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
+
+ // Global lookup for what color to use per LLM
+ const getColorForLLMAndSetIfNotFound = useStore((state) => state.getColorForLLMAndSetIfNotFound);
+
+ const [inputHasLLMs, setInputHasLLMs] = useState(false);
+
+ const [groupByVars, setGroupByVars] = useState([DEFAULT_GROUPBY_VAR_ALL]);
+ const [groupByVar, setGroupByVar] = useState("A");
+
+ const [groupByLLM, setGroupByLLM] = useState("within");
+ const [formatting, setFormatting] = useState(formattingOptions[0].value);
+
+ const handleOnConnect = useCallback(() => {
+ let input_data = pullInputData(["__input"], id);
+ if (!input_data?.__input) {
+ console.warn('Join Node: No input data detected.');
+ return;
+ }
+
+ // Find all vars and metavars in the input data (if any):
+ let {vars, metavars} = getVarsAndMetavars(input_data);
+
+ // Create lookup table for LLMs in input, indexed by llm key
+ const llm_lookup = extractLLMLookup(input_data);
+
+ // Refresh the dropdown list with available vars/metavars:
+ setGroupByVars([DEFAULT_GROUPBY_VAR_ALL].concat(
+ vars.map(varname => ({label: `by ${varname}`, value: `V${varname}`})))
+ .concat(
+ metavars.filter(varname => !varname.startsWith('LLM_')).map(varname => ({label: `by ${varname} (meta)`, value: `M${varname}`})))
+ );
+
+ // Check whether more than one LLM is present in the inputs:
+ const numLLMs = countNumLLMs(input_data);
+ setInputHasLLMs(numLLMs > 1);
+
+ // Tag all response objects in the input data with a metavar for their LLM (using the llm key as a uid)
+ input_data = tagMetadataWithLLM(input_data);
+
+ // A function to group the input (an array of texts/resp_objs) by the selected var
+ // and then join the texts within the groups
+ const joinByVar = (input) => {
+ const varname = groupByVar.substring(1);
+ const isMetavar = groupByVar[0] === 'M';
+ const [groupedResps, unspecGroup] = groupResponsesBy(input,
+ isMetavar ?
+ (r) => (r.metavars ? r.metavars[varname] : undefined) :
+ (r) => (r.fill_history ? r.fill_history[varname] : undefined)
+ );
+
+ // Now join texts within each group:
+ // (NOTE: We can do this directly here as response texts can't be templates themselves)
+ let joined_texts = Object.entries(groupedResps).map(([var_val, resp_objs]) => {
+ if (resp_objs.length === 0) return "";
+ const llm = (countNumLLMs(resp_objs) > 1) ? undefined : resp_objs[0].llm;
+ let vars = {};
+ if (groupByVar !== 'A')
+ vars[varname] = var_val;
+ return {
+ text: joinTexts(resp_objs.map(r => r.text !== undefined ? r.text : r), formatting),
+ fill_history: isMetavar ? {} : vars,
+ metavars: isMetavar ? vars : {},
+ llm: llm,
+ // NOTE: We lose all other metadata here, because we could've joined across other vars or metavars values.
+ };
+ });
+
+ // Add any data from unspecified group
+ if (unspecGroup.length > 0) {
+ const llm = (countNumLLMs(unspecGroup) > 1) ? undefined : unspecGroup[0].llm;
+ joined_texts.push({
+ text: joinTexts(unspecGroup.map(u => u.text !== undefined ? u.text : u), formatting),
+ fill_history: {},
+ metavars: {},
+ llm: llm,
+ });
+ }
+
+ return joined_texts;
+ };
+
+ // Generate (flatten) the inputs, which could be recursively chained templates
+ // and a mix of LLM resp objects, templates, and strings.
+ // (We tagged each object with its LLM key so that we can use built-in features to keep track of the LLM associated with each response object)
+ fetch_from_backend('generatePrompts', {
+ prompt: "{__input}",
+ vars: input_data,
+ }).then(promptTemplates => {
+
+ // Convert the templates into response objects
+ let resp_objs = promptTemplates.map(p => ({
+ text: p.toString(),
+ fill_history: p.fill_history,
+ llm: "__LLM_key" in p.metavars ? llm_lookup[p.metavars['__LLM_key']] : undefined,
+ metavars: removeLLMTagFromMetadata(p.metavars),
+ }));
+
+ // If there's multiple LLMs and groupByLLM is 'within', we need to
+ // first group by the LLMs (and a possible 'undefined' group):
+ if (numLLMs > 1 && groupByLLM === 'within') {
+ let joined_texts = [];
+ const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(resp_objs, r => r.llm?.key || r.llm);
+      Object.values(groupedRespsByLLM).forEach((resp_objs) => {
+ // Group only within the LLM
+ joined_texts = joined_texts.concat(joinByVar(resp_objs));
+ });
+
+ if (nonLLMRespGroup.length > 0)
+        joined_texts.push(joinTexts(nonLLMRespGroup.map(r => r.text !== undefined ? r.text : r), formatting));
+
+ setJoinedTexts(joined_texts);
+ setDataPropsForNode(id, { fields: joined_texts });
+ } else {
+ // Join across LLMs (join irrespective of LLM):
+ if (groupByVar !== 'A') {
+ // If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first:
+ const joined_texts = joinByVar(resp_objs);
+ setJoinedTexts(joined_texts);
+ setDataPropsForNode(id, { fields: joined_texts });
+ } else {
+ let joined_texts = joinTexts(resp_objs.map(r => ((typeof r === 'string') ? r : r.text)), formatting);
+
+ // If there is exactly 1 LLM and it's present across all inputs, keep track of it:
+ if (numLLMs === 1 && resp_objs.every((r) => r.llm !== undefined))
+ joined_texts = {text: joined_texts, fill_history: {}, llm: resp_objs[0].llm};
+
+ setJoinedTexts([joined_texts]);
+ setDataPropsForNode(id, { fields: [joined_texts] });
+ }
+ }
+ });
+
+  }, [formatting, pullInputData, groupByVar, groupByLLM, id, setDataPropsForNode]);
+
+ if (data.input) {
+ // If there's a change in inputs...
+ if (data.input != pastInputs) {
+ setPastInputs(data.input);
+ handleOnConnect();
+ }
+ }
+
+ // Refresh join output anytime the dropdowns change
+ useEffect(() => {
+ handleOnConnect();
+  }, [groupByVar, groupByLLM, formatting]);
+
+ useEffect(() => {
+ if (data.refresh && data.refresh === true) {
+ // Recreate the visualization:
+ setDataPropsForNode(id, { refresh: false });
+ handleOnConnect();
+ }
+ }, [data, id, handleOnConnect, setDataPropsForNode]);
+
+  return (
+    <div className="join-node cfnode">
+      <NodeLabel title={data.title || 'Join Node'}
+                 nodeId={id}
+                 icon={<IconArrowMerge size='14pt'/>}
+                 customButtons={[
+                   <JoinedTextsPopover key='joined-text-previews' textInfos={joinedTexts} onHover={handleOnConnect} onClick={openInfoModal} getColorForLLM={getColorForLLMAndSetIfNotFound} />
+                 ]} />
+      <Modal title={'List of joined inputs (' + joinedTexts.length + ' total)'} size='xl' opened={infoModalOpened} onClose={closeInfoModal}>
+        <Box m='lg' mt='xl'>
+          {displayJoinedTexts(joinedTexts, getColorForLLMAndSetIfNotFound)}
+        </Box>
+      </Modal>
+      <div style={{display: 'flex', justifyContent: 'left', maxWidth: '100%', marginBottom: '10px'}}>
+        <Text mt='3px' mr='xs'>Join</Text>
+        <NativeSelect onChange={(e) => setGroupByVar(e.target.value)}
+                      className='nodrag nowheel'
+                      data={groupByVars}
+                      size="xs"
+                      value={groupByVar}
+                      miw='80px'
+                      mr='xs' />
+      </div>
+      {inputHasLLMs ?
+        <div style={{display: 'flex', justifyContent: 'left', maxWidth: '100%', marginBottom: '10px'}}>
+          <NativeSelect onChange={(e) => setGroupByLLM(e.target.value)}
+                        className='nodrag nowheel'
+                        data={["within", "across"]}
+                        size="xs"
+                        value={groupByLLM}
+                        maw='80px'
+                        mr='xs'
+                        ml='40px' />
+          <Text mt='3px'>LLMs</Text>
+        </div>
+        : <></>}
+      <Divider my="xs" label="formatting" labelPosition="center" />
+      <NativeSelect onChange={(e) => setFormatting(e.target.value)}
+                    className='nodrag nowheel'
+                    data={formattingOptions}
+                    size="xs"
+                    value={formatting}
+                    miw='80px' />
+      <Handle type="target" position="left" id="__input" className="grouped-handle" style={{ top: "50%" }} />
+      <Handle type="source" position="right" id="output" className="grouped-handle" style={{ top: "50%" }} />
+    </div>
+  );
+};
+
+export default JoinNode;
\ No newline at end of file
diff --git a/chainforge/react-server/src/LLMResponseInspector.js b/chainforge/react-server/src/LLMResponseInspector.js
index 97570d6..797d44f 100644
--- a/chainforge/react-server/src/LLMResponseInspector.js
+++ b/chainforge/react-server/src/LLMResponseInspector.js
@@ -475,7 +475,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
value={multiSelectValue}
clearSearchOnChange={true}
clearSearchOnBlur={true}
- w='80%' />
+ w={wideFormat ? '80%' : '100%'} />
            onChange={(e) => setOnlyShowScores(e.currentTarget.checked)}
diff --git a/chainforge/react-server/src/PromptNode.js b/chainforge/react-server/src/PromptNode.js
index e787e3a..0ad13ab 100644
--- a/chainforge/react-server/src/PromptNode.js
+++ b/chainforge/react-server/src/PromptNode.js
@@ -78,10 +78,9 @@ const PromptNode = ({ data, id, type: node_type }) => {
// Get state from the Zustand store:
const edges = useStore((state) => state.edges);
- const output = useStore((state) => state.output);
+ const pullInputData = useStore((state) => state.pullInputData);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const pingOutputNodes = useStore((state) => state.pingOutputNodes);
- const getNode = useStore((state) => state.getNode);
// API Keys (set by user in popup GlobalSettingsModal)
const apiKeys = useStore((state) => state.apiKeys);
@@ -208,52 +207,10 @@ const PromptNode = ({ data, id, type: node_type }) => {
}
}, [data]);
- // Pull all inputs needed to request responses.
- // Returns [prompt, vars dict]
- const pullInputData = (_targetHandles) => {
- // Pull data from each source recursively:
- const pulled_data = {};
- const store_data = (_texts, _varname, _data) => {
- if (_varname in _data)
- _data[_varname] = _data[_varname].concat(_texts);
- else
- _data[_varname] = _texts;
- };
- const get_outputs = (varnames, nodeId) => {
- varnames.forEach(varname => {
- // Find the relevant edge(s):
- edges.forEach(e => {
- if (e.target == nodeId && e.targetHandle == varname) {
- // Get the immediate output:
- let out = output(e.source, e.sourceHandle);
- if (!out || !Array.isArray(out) || out.length === 0) return;
-
- // Check the format of the output. Can be str or dict with 'text' and more attrs:
- if (typeof out[0] === 'object') {
- out.forEach(obj => store_data([obj], varname, pulled_data));
- }
- else {
- // Save the list of strings from the pulled output under the var 'varname'
- store_data(out, varname, pulled_data);
- }
-
- // Get any vars that the output depends on, and recursively collect those outputs as well:
- const n_vars = getNode(e.source).data.vars;
- if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
- get_outputs(n_vars, e.source);
- }
- });
- });
- };
- get_outputs(_targetHandles, id);
-
- return pulled_data;
- };
-
// Chat nodes only. Pulls input data attached to the 'past conversations' handle.
// Returns a tuple (past_chat_llms, __past_chats), where both are undefined if nothing is connected.
const pullInputChats = () => {
- const pulled_data = pullInputData(['__past_chats']);
+ const pulled_data = pullInputData(['__past_chats'], id);
if (!('__past_chats' in pulled_data)) return [undefined, undefined];
// For storing the unique LLMs in past_chats:
@@ -313,12 +270,12 @@ const PromptNode = ({ data, id, type: node_type }) => {
const [promptPreviews, setPromptPreviews] = useState([]);
const handlePreviewHover = () => {
// Pull input data and prompt
- const pulled_vars = pullInputData(templateVars);
+ const pulled_vars = pullInputData(templateVars, id);
fetch_from_backend('generatePrompts', {
prompt: promptText,
vars: pulled_vars,
}).then(prompts => {
- setPromptPreviews(prompts.map(p => (new PromptInfo(p))));
+ setPromptPreviews(prompts.map(p => (new PromptInfo(p.toString()))));
});
pullInputChats();
@@ -352,7 +309,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
}
// Pull the input data
- const pulled_vars = pullInputData(templateVars);
+ const pulled_vars = pullInputData(templateVars, id);
const llms = _llmItemsCurrState.map(item => item.model);
const num_llms = llms.length;
@@ -442,6 +399,20 @@ const PromptNode = ({ data, id, type: node_type }) => {
return;
}
+ // Check if pulled chats includes undefined content.
+ // This could happen with Join nodes, where there is no longer a single "prompt" (user prompt)
+ // of the chat provenance. Instead of blocking this behavior, we replace undefined with a blank string,
+ // and output a warning to the console.
+ if (!pulled_chats.every(c => c.messages.every(m => m.content !== undefined))) {
+ console.warn("Chat history contains undefined content. This can happen if a Join Node was used, \
+ as there is no longer a single prompt as the provenance of the conversation. \
+ Soft failing by replacing undefined with empty strings.");
+ pulled_chats.forEach(c => {c.messages = c.messages.map(m => {
+ if (m.content !== undefined) return m;
+ else return {...m, content: " "}; // the string contains a single space since PaLM2 refuses to answer with empty strings
+ })});
+ }
+
// Override LLM list with the past llm info (unique LLMs in prior responses)
_llmItemsCurrState = past_chat_llms;
@@ -462,13 +433,13 @@ const PromptNode = ({ data, id, type: node_type }) => {
setProgressAnimated(true);
// Pull the data to fill in template input variables, if any
- const pulled_data = pullInputData(templateVars);
+ const pulled_data = pullInputData(templateVars, id);
const prompt_template = promptText;
const rejected = (err) => {
setStatus('error');
setContChatToggleDisabled(false);
- triggerAlert(err.message);
+ triggerAlert(err.message || err);
};
// Fetch info about the number of queries we'll need to make
diff --git a/chainforge/react-server/src/RemoveEdge.js b/chainforge/react-server/src/RemoveEdge.js
index 885b87a..1245d46 100644
--- a/chainforge/react-server/src/RemoveEdge.js
+++ b/chainforge/react-server/src/RemoveEdge.js
@@ -45,7 +45,7 @@ export default function CustomEdge({
// Thanks in part to oshanley https://github.com/wbkd/react-flow/issues/1211#issuecomment-1585032930
return (
-    <g onPointerEnter={()=>setHovering(true)} onPointerLeave={()=>setHovering(false)} onClick={()=>console.log('click')}>
+    <g onPointerEnter={()=>setHovering(true)} onPointerLeave={()=>setHovering(false)}>
diff --git a/chainforge/react-server/src/backend/backend.ts b/chainforge/react-server/src/backend/backend.ts
--- a/chainforge/react-server/src/backend/backend.ts
+++ b/chainforge/react-server/src/backend/backend.ts
@@ ... @@ export async function generatePrompts
 * @param vars a dict of the template variables to fill the prompt template with, by name. (See countQueries docstring for more info).
 * @returns An array of strings representing the prompts that will be sent out. Note that this could include unfilled template vars.
 */
-export async function generatePrompts(root_prompt: string, vars: Dict): Promise<string[]> {
+export async function generatePrompts(root_prompt: string, vars: Dict): Promise<PromptTemplate[]> {
   const gen_prompts = new PromptPermutationGenerator(root_prompt);
-  const all_prompt_permutations = Array.from(gen_prompts.generate(vars)).map(p => p.toString());
+  const all_prompt_permutations = Array.from(gen_prompts.generate(vars));
   return all_prompt_permutations;
 }
diff --git a/chainforge/react-server/src/store.js b/chainforge/react-server/src/store.js
index a4c2f85..fad3e0e 100644
--- a/chainforge/react-server/src/store.js
+++ b/chainforge/react-server/src/store.js
@@ -26,7 +26,7 @@ export const colorPalettes = {
var: varColorPalette,
}
-const refreshableOutputNodeTypes = new Set(['evaluator', 'prompt', 'inspect', 'vis', 'llmeval', 'textfields', 'chat', 'simpleval']);
+const refreshableOutputNodeTypes = new Set(['evaluator', 'prompt', 'inspect', 'vis', 'llmeval', 'textfields', 'chat', 'simpleval', 'join']);
export let initLLMProviders = [
{ name: "GPT3.5", emoji: "🤖", model: "gpt-3.5-turbo", base_model: "gpt-3.5-turbo", temp: 1.0 }, // The base_model designates what settings form will be used, and must be unique.
@@ -204,6 +204,56 @@ const useStore = create((set, get) => ({
return null;
}
},
+
+ // Pull all inputs needed to request responses.
+  // Returns a dict mapping each pulled handle/variable name to an array of input data.
+ pullInputData: (_targetHandles, node_id) => {
+ // Functions/data from the store:
+ const getNode = get().getNode;
+ const output = get().output;
+ const edges = get().edges;
+
+ // Helper function to store collected data in dict:
+ const store_data = (_texts, _varname, _data) => {
+ if (_varname in _data)
+ _data[_varname] = _data[_varname].concat(_texts);
+ else
+ _data[_varname] = _texts;
+ };
+
+ // Pull data from each source recursively:
+ const pulled_data = {};
+ const get_outputs = (varnames, nodeId) => {
+ varnames.forEach(varname => {
+ // Find the relevant edge(s):
+ edges.forEach(e => {
+ if (e.target == nodeId && e.targetHandle == varname) {
+ // Get the immediate output:
+ let out = output(e.source, e.sourceHandle);
+ if (!out || !Array.isArray(out) || out.length === 0) return;
+
+ // Check the format of the output. Can be str or dict with 'text' and more attrs:
+ if (typeof out[0] === 'object') {
+ out.forEach(obj => store_data([obj], varname, pulled_data));
+ }
+ else {
+ // Save the list of strings from the pulled output under the var 'varname'
+ store_data(out, varname, pulled_data);
+ }
+
+ // Get any vars that the output depends on, and recursively collect those outputs as well:
+ const n_vars = getNode(e.source).data.vars;
+ if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
+ get_outputs(n_vars, e.source);
+ }
+ });
+ });
+ };
+ get_outputs(_targetHandles, node_id);
+
+ return pulled_data;
+ },
+
setDataPropsForNode: (id, data_props) => {
set({
nodes: (nds =>
diff --git a/chainforge/react-server/src/text-fields-node.css b/chainforge/react-server/src/text-fields-node.css
index c5fd444..c9a0079 100644
--- a/chainforge/react-server/src/text-fields-node.css
+++ b/chainforge/react-server/src/text-fields-node.css
@@ -409,6 +409,9 @@
color: #444;
white-space: pre-wrap;
}
+ .join-text-preview {
+ margin: 0px 0px 10px 0px;
+ }
.small-response {
font-size: 8pt;
@@ -531,6 +534,10 @@
border-color: #222;
}
+ .join-node {
+ min-width: 200px;
+ }
+
.tabular-data-node {
min-width: 280px;
}
@@ -652,7 +659,7 @@
.text-field-fixed .mantine-Textarea-wrapper textarea {
resize: vertical;
overflow-y: auto;
- width: 280px;
+ width: 260px;
padding: calc(0.5rem / 3);
font-size: 10pt;
font-family: monospace;