diff --git a/chainforge/react-server/src/JoinNode.js b/chainforge/react-server/src/JoinNode.js
index f90c205..5cabdac 100644
--- a/chainforge/react-server/src/JoinNode.js
+++ b/chainforge/react-server/src/JoinNode.js
@@ -43,8 +43,8 @@ const getVarsAndMetavars = (input_data) => {
};
}
-const containsMultipleLLMs = (resp_objs) => {
- return (new Set(resp_objs.map(r => r.llm.key || r.llm))).length > 1;
+const countNumLLMs = (resp_objs) => {
+ return (new Set(resp_objs.filter(r => typeof r !== "string").map(r => r.llm.key || r.llm))).size;
}
const truncStr = (s, maxLen) => {
@@ -69,26 +69,38 @@ const groupResponsesBy = (responses, keyFunc) => {
const DEFAULT_GROUPBY_VAR_ALL = { label: "all text", value: "A" };
-const displayJoinedTexts = (textInfos) =>
- textInfos.map((info, idx) => {
+const displayJoinedTexts = (textInfos, getColorForLLM) => {
+ const color_for_llm = (llm) => (getColorForLLM(llm) + '99');
+ return textInfos.map((info, idx) => {
+
const vars = info.fill_history;
- const var_tags = vars === undefined ? [] : Object.keys(vars).map((varname) => {
+ let var_tags = vars === undefined ? [] : Object.keys(vars).map((varname) => {
const v = truncStr(vars[varname].trim(), 72);
return (
- {var_tags}
-
- {info.text || info}
-
+
+
+ {var_tags}
+
+ {info.llm === undefined ?
+ ps
+ : (
+
{info.llm?.name}
+ {ps}
+ )
+ }
);
});
+};
-const JoinedTextsPopover = ({ textInfos, onHover, onClick }) => {
+const JoinedTextsPopover = ({ textInfos, onHover, onClick, getColorForLLM }) => {
const [opened, { close, open }] = useDisclosure(false);
const _onHover = useCallback(() => {
@@ -107,7 +119,7 @@ const JoinedTextsPopover = ({ textInfos, onHover, onClick }) => {
Preview of joined inputs ({textInfos?.length} total)
- {displayJoinedTexts(textInfos)}
+ {displayJoinedTexts(textInfos, getColorForLLM)}
);
@@ -128,6 +140,9 @@ const JoinNode = ({ data, id }) => {
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
+ // Global lookup for what color to use per LLM
+ const getColorForLLMAndSetIfNotFound = useStore((state) => state.getColorForLLMAndSetIfNotFound);
+
const [inputHasLLMs, setInputHasLLMs] = useState(false);
const [inputHasMultiRespPerLLM, setInputHasMultiRespPerLLM] = useState(false);
@@ -150,6 +165,10 @@ const JoinNode = ({ data, id }) => {
// Find all vars and metavars in the input data (if any):
const {vars, metavars} = getVarsAndMetavars(input_data.__input);
+ // Check whether more than one LLM is present in the inputs:
+ const numLLMs = countNumLLMs(input_data.__input);
+ setInputHasLLMs(numLLMs > 1);
+
// Refresh the dropdown list with available vars/metavars:
setGroupByVars([DEFAULT_GROUPBY_VAR_ALL].concat(
vars.map(varname => ({label: `within ${varname}`, value: `V${varname}`})))
@@ -157,13 +176,12 @@ const JoinNode = ({ data, id }) => {
metavars.filter(varname => !varname.startsWith('LLM_')).map(varname => ({label: `within ${varname} (meta)`, value: `M${varname}`})))
);
- // If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first:
- if (groupByVar !== 'A') {
+ const joinByVars = (input) => {
const varname = groupByVar.substring(1);
- const [groupedResps, unspecGroup] = groupResponsesBy(input_data.__input,
+ const [groupedResps, unspecGroup] = groupResponsesBy(input,
(groupByVar[0] === 'V') ?
- (r) => r.fill_history[varname] :
- (r) => r.metavars[varname]
+ (r) => (r.fill_history ? r.fill_history[varname] : undefined) :
+ (r) => (r.metavars ? r.metavars[varname] : undefined)
);
console.log(groupedResps);
@@ -171,36 +189,76 @@ const JoinNode = ({ data, id }) => {
// (NOTE: We can do this directly here as response texts can't be templates themselves)
let joined_texts = Object.entries(groupedResps).map(([var_val, resp_objs]) => {
if (resp_objs.length === 0) return "";
- let llm = containsMultipleLLMs(resp_objs) ? undefined : resp_objs[0].llm;
+ const llm = (countNumLLMs(resp_objs) > 1) ? undefined : resp_objs[0].llm;
let vars = {};
- vars[varname] = var_val;
+ if (groupByVar !== 'A')
+ vars[varname] = var_val;
return {
- text: joinTexts(resp_objs.map(r => r.text), formatting),
+ text: joinTexts(resp_objs.map(r => r.text !== undefined ? r.text : r), formatting),
fill_history: vars,
llm: llm,
// NOTE: We lose all other metadata here, because we could've joined across other vars or metavars values.
};
});
+
+ // Add any data from unspecified group
+ if (unspecGroup.length > 0) {
+ const llm = (countNumLLMs(unspecGroup) > 1) ? undefined : unspecGroup[0].llm;
+ joined_texts.push({
+ text: joinTexts(unspecGroup.map(u => u.text !== undefined ? u.text : u), formatting),
+ fill_history: {},
+ llm: llm,
+ });
+ }
+
+ return joined_texts;
+ };
+
+ // If there's multiple LLMs and groupByLLM is 'within', we need to
+ // first group by the LLMs (and a possible 'undefined' group):
+ if (numLLMs > 1 && groupByLLM === 'within') {
+ let joined_texts = [];
+ const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(input_data.__input, r => r.llm?.key || r.llm);
+ Object.entries(groupedRespsByLLM).map(([llm_key, resp_objs]) => {
+ // Group only within the LLM
+ joined_texts = joined_texts.concat(joinByVars(resp_objs));
+ });
+
+ if (nonLLMRespGroup.length > 0)
+ joined_texts.push(joinTexts(nonLLMRespGroup, formatting));
+
setJoinedTexts(joined_texts);
console.log(joined_texts);
- }
- else {
- // Since templates could be chained, we need to run this
- // through the prompt generator:
- fetch_from_backend('generatePrompts', {
- prompt: "{__input}",
- vars: input_data,
- }).then(promptTemplates => {
- const texts = promptTemplates.map(p => p.toString());
- console.log(texts);
+ } else {
+ // Join across LLMs (join irrespective of LLM):
+ if (groupByVar !== 'A') {
+ // If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first:
+ const joined_texts = joinByVars(input_data.__input);
+ setJoinedTexts(joined_texts);
+ } else {
+ // Since templates could be chained, we need to run this
+ // through the prompt generator:
+ fetch_from_backend('generatePrompts', {
+ prompt: "{__input}",
+ vars: input_data,
+ }).then(promptTemplates => {
+ const texts = promptTemplates.map(p => p.toString());
+ console.log(texts);
- const joined_texts = joinTexts(texts, formatting);
- setJoinedTexts([joined_texts]);
- console.log(joined_texts);
- });
+ let joined_texts = joinTexts(texts, formatting);
+
+ // If there is exactly 1 LLM and it's present across all inputs, keep track of it:
+ if (numLLMs === 1 && input_data.__input.every((r) => r.llm !== undefined))
+ joined_texts = {text: joined_texts, fill_history: {}, llm: input_data.__input[0].llm};
+
+ setJoinedTexts([joined_texts]);
+ console.log(joined_texts);
+ });
+ }
}
- }, [formatting, pullInputData, groupByVar]);
+
+ }, [formatting, pullInputData, groupByVar, groupByLLM]);
if (data.input) {
// If there's a change in inputs...
@@ -224,11 +282,11 @@ const JoinNode = ({ data, id }) => {
nodeId={id}
icon={
}
customButtons={[
-
+
]} />
- {displayJoinedTexts(joinedTexts)}
+ {displayJoinedTexts(joinedTexts, getColorForLLMAndSetIfNotFound)}
@@ -251,7 +309,7 @@ const JoinNode = ({ data, id }) => {
maw='80px'
mr='xs'
ml='40px' />
- LLM(s)
+ LLMs
: <>>}
{inputHasMultiRespPerLLM ?