From 667518942405a0643fdd7111fef7a1f1a4dbb0b9 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Mon, 16 Oct 2023 14:29:05 -0400 Subject: [PATCH] edge cases --- chainforge/react-server/src/JoinNode.js | 134 +++++++++++++++++------- 1 file changed, 96 insertions(+), 38 deletions(-) diff --git a/chainforge/react-server/src/JoinNode.js b/chainforge/react-server/src/JoinNode.js index f90c205..5cabdac 100644 --- a/chainforge/react-server/src/JoinNode.js +++ b/chainforge/react-server/src/JoinNode.js @@ -43,8 +43,8 @@ const getVarsAndMetavars = (input_data) => { }; } -const containsMultipleLLMs = (resp_objs) => { - return (new Set(resp_objs.map(r => r.llm.key || r.llm))).length > 1; +const countNumLLMs = (resp_objs) => { + return (new Set(resp_objs.filter(r => typeof r !== "string").map(r => r.llm.key || r.llm))).size; } const truncStr = (s, maxLen) => { @@ -69,26 +69,38 @@ const groupResponsesBy = (responses, keyFunc) => { const DEFAULT_GROUPBY_VAR_ALL = { label: "all text", value: "A" }; -const displayJoinedTexts = (textInfos) => - textInfos.map((info, idx) => { +const displayJoinedTexts = (textInfos, getColorForLLM) => { + const color_for_llm = (llm) => (getColorForLLM(llm) + '99'); + return textInfos.map((info, idx) => { + const vars = info.fill_history; - const var_tags = vars === undefined ? [] : Object.keys(vars).map((varname) => { + let var_tags = vars === undefined ? [] : Object.keys(vars).map((varname) => { const v = truncStr(vars[varname].trim(), 72); return (
{varname} = {v}
); }); + + const ps = (
{info.text || info}
); + return ( -
- {var_tags} -
-          {info.text || info}
-        
+
+
+ {var_tags} +
+ {info.llm === undefined ? + ps + : (
+

{info.llm?.name}

+ {ps} +
) + }
); }); +}; -const JoinedTextsPopover = ({ textInfos, onHover, onClick }) => { +const JoinedTextsPopover = ({ textInfos, onHover, onClick, getColorForLLM }) => { const [opened, { close, open }] = useDisclosure(false); const _onHover = useCallback(() => { @@ -107,7 +119,7 @@ const JoinedTextsPopover = ({ textInfos, onHover, onClick }) => {
Preview of joined inputs ({textInfos?.length} total)
- {displayJoinedTexts(textInfos)} + {displayJoinedTexts(textInfos, getColorForLLM)}
); @@ -128,6 +140,9 @@ const JoinNode = ({ data, id }) => { const inputEdgesForNode = useStore((state) => state.inputEdgesForNode); const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); + // Global lookup for what color to use per LLM + const getColorForLLMAndSetIfNotFound = useStore((state) => state.getColorForLLMAndSetIfNotFound); + const [inputHasLLMs, setInputHasLLMs] = useState(false); const [inputHasMultiRespPerLLM, setInputHasMultiRespPerLLM] = useState(false); @@ -150,6 +165,10 @@ const JoinNode = ({ data, id }) => { // Find all vars and metavars in the input data (if any): const {vars, metavars} = getVarsAndMetavars(input_data.__input); + // Check whether more than one LLM is present in the inputs: + const numLLMs = countNumLLMs(input_data.__input); + setInputHasLLMs(numLLMs > 1); + // Refresh the dropdown list with available vars/metavars: setGroupByVars([DEFAULT_GROUPBY_VAR_ALL].concat( vars.map(varname => ({label: `within ${varname}`, value: `V${varname}`}))) @@ -157,13 +176,12 @@ const JoinNode = ({ data, id }) => { metavars.filter(varname => !varname.startsWith('LLM_')).map(varname => ({label: `within ${varname} (meta)`, value: `M${varname}`}))) ); - // If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first: - if (groupByVar !== 'A') { + const joinByVars = (input) => { const varname = groupByVar.substring(1); - const [groupedResps, unspecGroup] = groupResponsesBy(input_data.__input, + const [groupedResps, unspecGroup] = groupResponsesBy(input, (groupByVar[0] === 'V') ? - (r) => r.fill_history[varname] : - (r) => r.metavars[varname] + (r) => (r.fill_history ? r.fill_history[varname] : undefined) : + (r) => (r.metavars ? r.metavars[varname] : undefined) ); console.log(groupedResps); @@ -171,36 +189,76 @@ const JoinNode = ({ data, id }) => { // (NOTE: We can do this directly here as response texts can't be templates themselves) let joined_texts = Object.entries(groupedResps).map(([var_val, resp_objs]) => { if (resp_objs.length === 0) return ""; - let llm = containsMultipleLLMs(resp_objs) ? undefined : resp_objs[0].llm; + const llm = (countNumLLMs(resp_objs) > 1) ? undefined : resp_objs[0].llm; let vars = {}; - vars[varname] = var_val; + if (groupByVar !== 'A') + vars[varname] = var_val; return { - text: joinTexts(resp_objs.map(r => r.text), formatting), + text: joinTexts(resp_objs.map(r => r.text !== undefined ? r.text : r), formatting), fill_history: vars, llm: llm, // NOTE: We lose all other metadata here, because we could've joined across other vars or metavars values. }; }); + + // Add any data from unspecified group + if (unspecGroup.length > 0) { + const llm = (countNumLLMs(unspecGroup) > 1) ? undefined : unspecGroup[0].llm; + joined_texts.push({ + text: joinTexts(unspecGroup.map(u => u.text !== undefined ? u.text : u), formatting), + fill_history: {}, + llm: llm, + }); + } + + return joined_texts; + }; + + // If there's multiple LLMs and groupByLLM is 'within', we need to + // first group by the LLMs (and a possible 'undefined' group): + if (numLLMs > 1 && groupByLLM === 'within') { + let joined_texts = []; + const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(input_data.__input, r => r.llm?.key || r.llm); + Object.entries(groupedRespsByLLM).map(([llm_key, resp_objs]) => { + // Group only within the LLM + joined_texts = joined_texts.concat(joinByVars(resp_objs)); + }); + + if (nonLLMRespGroup.length > 0) + joined_texts.push(joinTexts(nonLLMRespGroup, formatting)); + setJoinedTexts(joined_texts); console.log(joined_texts); - } - else { - // Since templates could be chained, we need to run this - // through the prompt generator: - fetch_from_backend('generatePrompts', { - prompt: "{__input}", - vars: input_data, - }).then(promptTemplates => { - const texts = promptTemplates.map(p => p.toString()); - console.log(texts); + } else { + // Join across LLMs (join irrespective of LLM): + if (groupByVar !== 'A') { + // If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first: + const joined_texts = joinByVars(input_data.__input); + setJoinedTexts(joined_texts); + } else { + // Since templates could be chained, we need to run this + // through the prompt generator: + fetch_from_backend('generatePrompts', { + prompt: "{__input}", + vars: input_data, + }).then(promptTemplates => { + const texts = promptTemplates.map(p => p.toString()); + console.log(texts); - const joined_texts = joinTexts(texts, formatting); - setJoinedTexts([joined_texts]); - console.log(joined_texts); - }); + let joined_texts = joinTexts(texts, formatting); + + // If there is exactly 1 LLM and it's present across all inputs, keep track of it: + if (numLLMs === 1 && input_data.__input.every((r) => r.llm !== undefined)) + joined_texts = {text: joined_texts, fill_history: {}, llm: input_data.__input[0].llm}; + + setJoinedTexts([joined_texts]); + console.log(joined_texts); + }); + } } - }, [formatting, pullInputData, groupByVar]); + + }, [formatting, pullInputData, groupByVar, groupByLLM]); if (data.input) { // If there's a change in inputs... @@ -224,11 +282,11 @@ const JoinNode = ({ data, id }) => { nodeId={id} icon={} customButtons={[ - + ]} /> - {displayJoinedTexts(joinedTexts)} + {displayJoinedTexts(joinedTexts, getColorForLLMAndSetIfNotFound)}
@@ -251,7 +309,7 @@ const JoinNode = ({ data, id }) => { maw='80px' mr='xs' ml='40px' /> - LLM(s) + LLMs
: <>} {inputHasMultiRespPerLLM ?