From 30a099bf2090bcada66a69909d3f7b52a342f8d9 Mon Sep 17 00:00:00 2001
From: Ian Arawjo
Date: Thu, 19 Oct 2023 12:37:50 -0400
Subject: [PATCH] Reorganize join node logic

---
 chainforge/react-server/src/JoinNode.js | 148 +++++++++++++++++-------
 1 file changed, 103 insertions(+), 45 deletions(-)

diff --git a/chainforge/react-server/src/JoinNode.js b/chainforge/react-server/src/JoinNode.js
index ebcd86f..7d4a028 100644
--- a/chainforge/react-server/src/JoinNode.js
+++ b/chainforge/react-server/src/JoinNode.js
@@ -30,10 +30,13 @@ const getVarsAndMetavars = (input_data) => {
   // Find all vars and metavars in the input data (if any):
   let varnames = new Set();
   let metavars = new Set();
-  input_data.forEach(resp_obj => {
-    if (typeof resp_obj === "string") return;
-    Object.keys(resp_obj.fill_history).forEach(v => varnames.add(v));
-    if (resp_obj.metavars) Object.keys(resp_obj.metavars).forEach(v => metavars.add(v));
+  Object.entries(input_data).forEach(([key, obj]) => {
+    if (key !== '__input') varnames.add(key); // A "var" can also be other properties on input_data
+    obj.forEach(resp_obj => {
+      if (typeof resp_obj === "string") return;
+      Object.keys(resp_obj.fill_history).forEach(v => varnames.add(v));
+      if (resp_obj.metavars) Object.keys(resp_obj.metavars).forEach(v => metavars.add(v));
+    });
   });
   varnames = Array.from(varnames);
   metavars = Array.from(metavars);
@@ -43,10 +46,50 @@ const getVarsAndMetavars = (input_data) => {
   };
 }
 
-const countNumLLMs = (resp_objs) => {
-  return (new Set(resp_objs.filter(r => typeof r !== "string").map(r => r.llm.key || r.llm))).size;
+const countNumLLMs = (resp_objs_or_dict) => {
+  const resp_objs = Array.isArray(resp_objs_or_dict) ? resp_objs_or_dict : Object.values(resp_objs_or_dict).flat();
+  return (new Set(resp_objs.filter(r => typeof r !== "string" && r.llm !== undefined).map(r => r.llm?.key || r.llm))).size;
 }
+const tagMetadataWithLLM = (input_data) => {
+  let new_data = {};
+  Object.entries(input_data).forEach(([varname, resp_objs]) => {
+    new_data[varname] = resp_objs.map(r => {
+      if (!r || typeof r === 'string' || !r?.llm?.key) return r;
+      let r_copy = JSON.parse(JSON.stringify(r));
+      r_copy.metavars["__LLM_key"] = r.llm.key;
+      return r_copy;
+    });
+  });
+  return new_data;
+};
+const extractLLMLookup = (input_data) => {
+  let llm_lookup = {};
+  Object.entries(input_data).forEach(([varname, resp_objs]) => {
+    resp_objs.forEach(r => {
+      if (typeof r === 'string' || !r?.llm?.key || r.llm.key in llm_lookup) return;
+      llm_lookup[r.llm.key] = r.llm;
+    });
+  });
+  return llm_lookup;
+};
+const removeLLMTagFromMetadata = (metavars) => {
+  if (!('__LLM_key' in metavars))
+    return metavars;
+  let mcopy = JSON.parse(JSON.stringify(metavars));
+  delete metavars['__LLM_key'];
+  return mcopy;
+};
+const removeLLMTagsFromMetadata = (input_data) => {
+  Object.values(input_data).forEach((resp_objs) => {
+    resp_objs.forEach(r => {
+      if ("__LLM_key" in r?.metavars)
+        delete r.metavars["__LLM_key"];
+    });
+  });
+  return input_data;
+};
+
 const truncStr = (s, maxLen) => {
   if (s.length > maxLen) // Cut the name short if it's long
     return s.substring(0, maxLen) + '...'
@@ -154,7 +197,7 @@ const JoinNode = ({ data, id }) => {
   const [formatting, setFormatting] = useState(formattingOptions[0].value);
 
   const handleOnConnect = useCallback(() => {
-    const input_data = pullInputData(["__input"], id);
+    let input_data = pullInputData(["__input"], id);
     if (!input_data?.__input) {
       console.warn('Join Node: No input data detected.');
       return;
@@ -163,11 +206,10 @@ const JoinNode = ({ data, id }) => {
     console.log(input_data);
 
     // Find all vars and metavars in the input data (if any):
-    const {vars, metavars} = getVarsAndMetavars(input_data.__input);
+    let {vars, metavars} = getVarsAndMetavars(input_data);
 
-    // Check whether more than one LLM is present in the inputs:
-    const numLLMs = countNumLLMs(input_data.__input);
-    setInputHasLLMs(numLLMs > 1);
+    // Create lookup table for LLMs in input, indexed by llm key
+    const llm_lookup = extractLLMLookup(input_data);
 
     // Refresh the dropdown list with available vars/metavars:
     setGroupByVars([DEFAULT_GROUPBY_VAR_ALL].concat(
@@ -176,7 +218,16 @@ const JoinNode = ({ data, id }) => {
       metavars.filter(varname => !varname.startsWith('LLM_')).map(varname => ({label: `by ${varname} (meta)`, value: `M${varname}`})))
     );
 
-    const joinByVars = (input) => {
+    // Check whether more than one LLM is present in the inputs:
+    const numLLMs = countNumLLMs(input_data);
+    setInputHasLLMs(numLLMs > 1);
+
+    // Tag all response objects in the input data with a metavar for their LLM (using the llm key as a uid)
+    input_data = tagMetadataWithLLM(input_data);
+
+    // A function to group the input (an array of texts/resp_objs) by the selected var
+    // and then join the texts within the groups
+    const joinByVar = (input) => {
      const varname = groupByVar.substring(1);
      const isMetavar = groupByVar[0] === 'M';
      const [groupedResps, unspecGroup] = groupResponsesBy(input,
@@ -217,51 +268,58 @@ const JoinNode = ({ data, id }) => {
       return joined_texts;
     };
 
-    // If there's multiple LLMs and groupByLLM is 'within', we need to
-    // first group by the LLMs (and a possible 'undefined' group):
-    if (numLLMs > 1 && groupByLLM === 'within') {
-      let joined_texts = [];
-      const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(input_data.__input, r => r.llm?.key || r.llm);
-      Object.entries(groupedRespsByLLM).map(([llm_key, resp_objs]) => {
-        // Group only within the LLM
-        joined_texts = joined_texts.concat(joinByVars(resp_objs));
-      });
+    // Generate (flatten) the inputs, which could be recursively chained templates
+    // and a mix of LLM resp objects, templates, and strings.
+    // (We tagged each object with its LLM key so that we can use built-in features to keep track of the LLM associated with each response object)
+    fetch_from_backend('generatePrompts', {
+      prompt: "{__input}",
+      vars: input_data,
+    }).then(promptTemplates => {
 
-      if (nonLLMRespGroup.length > 0)
-        joined_texts.push(joinTexts(nonLLMRespGroup, formatting));
+      // Convert the templates into response objects
+      let resp_objs = promptTemplates.map(p => ({
+        text: p.toString(),
+        fill_history: p.fill_history,
+        llm: "__LLM_key" in p.metavars ? llm_lookup[p.metavars['__LLM_key']] : undefined,
+        metavars: removeLLMTagFromMetadata(p.metavars),
+      }));
+
+      // If there's multiple LLMs and groupByLLM is 'within', we need to
+      // first group by the LLMs (and a possible 'undefined' group):
+      if (numLLMs > 1 && groupByLLM === 'within') {
+        let joined_texts = [];
+        const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(resp_objs, r => r.llm?.key || r.llm);
+        Object.entries(groupedRespsByLLM).map(([llm_key, resp_objs]) => {
+          // Group only within the LLM
+          joined_texts = joined_texts.concat(joinByVar(resp_objs));
+        });
+
+        if (nonLLMRespGroup.length > 0)
+          joined_texts.push(joinTexts(nonLLMRespGroup, formatting));
 
-      setJoinedTexts(joined_texts);
-      setDataPropsForNode(id, { fields: joined_texts });
-      console.log(joined_texts);
-    } else {
-      // Join across LLMs (join irrespective of LLM):
-      if (groupByVar !== 'A') {
-        // If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first:
-        const joined_texts = joinByVars(input_data.__input);
         setJoinedTexts(joined_texts);
         setDataPropsForNode(id, { fields: joined_texts });
+        console.log(joined_texts);
       } else {
-        // Since templates could be chained, we need to run this
-        // through the prompt generator:
-        fetch_from_backend('generatePrompts', {
-          prompt: "{__input}",
-          vars: input_data,
-        }).then(promptTemplates => {
-          const texts = promptTemplates.map(p => p.toString());
-          console.log(texts);
-
-          let joined_texts = joinTexts(texts, formatting);
+        // Join across LLMs (join irrespective of LLM):
+        if (groupByVar !== 'A') {
+          // If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first:
+          const joined_texts = joinByVar(resp_objs);
+          setJoinedTexts(joined_texts);
+          setDataPropsForNode(id, { fields: joined_texts });
+        } else {
+          let joined_texts = joinTexts(resp_objs.map(r => ((typeof r === 'string') ? r : r.text)), formatting);
 
           // If there is exactly 1 LLM and it's present across all inputs, keep track of it:
-          if (numLLMs === 1 && input_data.__input.every((r) => r.llm !== undefined))
-            joined_texts = {text: joined_texts, fill_history: {}, llm: input_data.__input[0].llm};
+          if (numLLMs === 1 && resp_objs.every((r) => r.llm !== undefined))
+            joined_texts = {text: joined_texts, fill_history: {}, llm: resp_objs[0].llm};
 
           setJoinedTexts([joined_texts]);
           setDataPropsForNode(id, { fields: [joined_texts] });
           console.log(joined_texts);
-        });
+        }
       }
-    }
+    });
   }, [formatting, pullInputData, groupByVar, groupByLLM]);
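
For reference, a minimal standalone sketch of the __LLM_key round-trip the new helpers implement: stash each response object's LLM key in its metavars before flattening, keep a key-to-LLM lookup, then restore the LLM object and drop the tag afterwards. The sample response object and the "llm-abc123" key are invented for illustration; this is not code from the commit.

// Sketch of the tag -> flatten -> restore round-trip (illustrative only).
const llm = { key: "llm-abc123", name: "SomeModel" };  // hypothetical LLM spec
const resp = { text: "hi", fill_history: {}, metavars: {}, llm };

// Tag: stash the LLM key in metavars and remember key -> LLM object.
const llm_lookup = { [llm.key]: llm };
const tagged = { ...resp, metavars: { ...resp.metavars, __LLM_key: llm.key } };

// ...a flattening step (like generatePrompts) that only preserves text,
// fill_history, and metavars would otherwise lose the llm field...
const flattened = { text: tagged.text, fill_history: tagged.fill_history, metavars: tagged.metavars };

// Restore: look the LLM back up from the tag, then strip the tag from a copy.
const { __LLM_key, ...cleanMetavars } = flattened.metavars;
const restored = {
  ...flattened,
  metavars: cleanMetavars,
  llm: __LLM_key !== undefined ? llm_lookup[__LLM_key] : undefined,
};
console.log(restored.llm.name); // "SomeModel"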