Reorganize join node logic

This commit is contained in:
Ian Arawjo 2023-10-19 12:37:50 -04:00
parent 9c7b36fa7e
commit 30a099bf20

View File

@ -30,10 +30,13 @@ const getVarsAndMetavars = (input_data) => {
// Find all vars and metavars in the input data (if any):
let varnames = new Set();
let metavars = new Set();
input_data.forEach(resp_obj => {
if (typeof resp_obj === "string") return;
Object.keys(resp_obj.fill_history).forEach(v => varnames.add(v));
if (resp_obj.metavars) Object.keys(resp_obj.metavars).forEach(v => metavars.add(v));
Object.entries(input_data).forEach(([key, obj]) => {
if (key !== '__input') varnames.add(key); // A "var" can also be other properties on input_data
obj.forEach(resp_obj => {
if (typeof resp_obj === "string") return;
Object.keys(resp_obj.fill_history).forEach(v => varnames.add(v));
if (resp_obj.metavars) Object.keys(resp_obj.metavars).forEach(v => metavars.add(v));
});
});
varnames = Array.from(varnames);
metavars = Array.from(metavars);
@ -43,10 +46,50 @@ const getVarsAndMetavars = (input_data) => {
};
}
const countNumLLMs = (resp_objs) => {
return (new Set(resp_objs.filter(r => typeof r !== "string").map(r => r.llm.key || r.llm))).size;
const countNumLLMs = (resp_objs_or_dict) => {
const resp_objs = Array.isArray(resp_objs_or_dict) ? resp_objs_or_dict : Object.values(resp_objs_or_dict).flat();
return (new Set(resp_objs.filter(r => typeof r !== "string" && r.llm !== undefined).map(r => r.llm?.key || r.llm))).size;
}
const tagMetadataWithLLM = (input_data) => {
let new_data = {};
Object.entries(input_data).forEach(([varname, resp_objs]) => {
new_data[varname] = resp_objs.map(r => {
if (!r || typeof r === 'string' || !r?.llm?.key) return r;
let r_copy = JSON.parse(JSON.stringify(r));
r_copy.metavars["__LLM_key"] = r.llm.key;
return r_copy;
});
});
return new_data;
};
const extractLLMLookup = (input_data) => {
let llm_lookup = {};
Object.entries(input_data).forEach(([varname, resp_objs]) => {
resp_objs.forEach(r => {
if (typeof r === 'string' || !r?.llm?.key || r.llm.key in llm_lookup) return;
llm_lookup[r.llm.key] = r.llm;
});
});
return llm_lookup;
};
const removeLLMTagFromMetadata = (metavars) => {
if (!('__LLM_key' in metavars))
return metavars;
let mcopy = JSON.parse(JSON.stringify(metavars));
delete metavars['__LLM_key'];
return mcopy;
};
const removeLLMTagsFromMetadata = (input_data) => {
Object.values(input_data).forEach((resp_objs) => {
resp_objs.forEach(r => {
if ("__LLM_key" in r?.metavars)
delete r.metavars["__LLM_key"];
});
});
return input_data;
};
const truncStr = (s, maxLen) => {
if (s.length > maxLen) // Cut the name short if it's long
return s.substring(0, maxLen) + '...'
@ -154,7 +197,7 @@ const JoinNode = ({ data, id }) => {
const [formatting, setFormatting] = useState(formattingOptions[0].value);
const handleOnConnect = useCallback(() => {
const input_data = pullInputData(["__input"], id);
let input_data = pullInputData(["__input"], id);
if (!input_data?.__input) {
console.warn('Join Node: No input data detected.');
return;
@ -163,11 +206,10 @@ const JoinNode = ({ data, id }) => {
console.log(input_data);
// Find all vars and metavars in the input data (if any):
const {vars, metavars} = getVarsAndMetavars(input_data.__input);
let {vars, metavars} = getVarsAndMetavars(input_data);
// Check whether more than one LLM is present in the inputs:
const numLLMs = countNumLLMs(input_data.__input);
setInputHasLLMs(numLLMs > 1);
// Create lookup table for LLMs in input, indexed by llm key
const llm_lookup = extractLLMLookup(input_data);
// Refresh the dropdown list with available vars/metavars:
setGroupByVars([DEFAULT_GROUPBY_VAR_ALL].concat(
@ -176,7 +218,16 @@ const JoinNode = ({ data, id }) => {
metavars.filter(varname => !varname.startsWith('LLM_')).map(varname => ({label: `by ${varname} (meta)`, value: `M${varname}`})))
);
const joinByVars = (input) => {
// Check whether more than one LLM is present in the inputs:
const numLLMs = countNumLLMs(input_data);
setInputHasLLMs(numLLMs > 1);
// Tag all response objects in the input data with a metavar for their LLM (using the llm key as a uid)
input_data = tagMetadataWithLLM(input_data);
// A function to group the input (an array of texts/resp_objs) by the selected var
// and then join the texts within the groups
const joinByVar = (input) => {
const varname = groupByVar.substring(1);
const isMetavar = groupByVar[0] === 'M';
const [groupedResps, unspecGroup] = groupResponsesBy(input,
@ -217,51 +268,58 @@ const JoinNode = ({ data, id }) => {
return joined_texts;
};
// If there's multiple LLMs and groupByLLM is 'within', we need to
// first group by the LLMs (and a possible 'undefined' group):
if (numLLMs > 1 && groupByLLM === 'within') {
let joined_texts = [];
const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(input_data.__input, r => r.llm?.key || r.llm);
Object.entries(groupedRespsByLLM).map(([llm_key, resp_objs]) => {
// Group only within the LLM
joined_texts = joined_texts.concat(joinByVars(resp_objs));
});
// Generate (flatten) the inputs, which could be recursively chained templates
// and a mix of LLM resp objects, templates, and strings.
// (We tagged each object with its LLM key so that we can use built-in features to keep track of the LLM associated with each response object)
fetch_from_backend('generatePrompts', {
prompt: "{__input}",
vars: input_data,
}).then(promptTemplates => {
if (nonLLMRespGroup.length > 0)
joined_texts.push(joinTexts(nonLLMRespGroup, formatting));
// Convert the templates into response objects
let resp_objs = promptTemplates.map(p => ({
text: p.toString(),
fill_history: p.fill_history,
llm: "__LLM_key" in p.metavars ? llm_lookup[p.metavars['__LLM_key']] : undefined,
metavars: removeLLMTagFromMetadata(p.metavars),
}));
// If there's multiple LLMs and groupByLLM is 'within', we need to
// first group by the LLMs (and a possible 'undefined' group):
if (numLLMs > 1 && groupByLLM === 'within') {
let joined_texts = [];
const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(resp_objs, r => r.llm?.key || r.llm);
Object.entries(groupedRespsByLLM).map(([llm_key, resp_objs]) => {
// Group only within the LLM
joined_texts = joined_texts.concat(joinByVar(resp_objs));
});
if (nonLLMRespGroup.length > 0)
joined_texts.push(joinTexts(nonLLMRespGroup, formatting));
setJoinedTexts(joined_texts);
setDataPropsForNode(id, { fields: joined_texts });
console.log(joined_texts);
} else {
// Join across LLMs (join irrespective of LLM):
if (groupByVar !== 'A') {
// If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first:
const joined_texts = joinByVars(input_data.__input);
setJoinedTexts(joined_texts);
setDataPropsForNode(id, { fields: joined_texts });
console.log(joined_texts);
} else {
// Since templates could be chained, we need to run this
// through the prompt generator:
fetch_from_backend('generatePrompts', {
prompt: "{__input}",
vars: input_data,
}).then(promptTemplates => {
const texts = promptTemplates.map(p => p.toString());
console.log(texts);
let joined_texts = joinTexts(texts, formatting);
// Join across LLMs (join irrespective of LLM):
if (groupByVar !== 'A') {
// If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first:
const joined_texts = joinByVar(resp_objs);
setJoinedTexts(joined_texts);
setDataPropsForNode(id, { fields: joined_texts });
} else {
let joined_texts = joinTexts(resp_objs.map(r => ((typeof r === 'string') ? r : r.text)), formatting);
// If there is exactly 1 LLM and it's present across all inputs, keep track of it:
if (numLLMs === 1 && input_data.__input.every((r) => r.llm !== undefined))
joined_texts = {text: joined_texts, fill_history: {}, llm: input_data.__input[0].llm};
if (numLLMs === 1 && resp_objs.every((r) => r.llm !== undefined))
joined_texts = {text: joined_texts, fill_history: {}, llm: resp_objs[0].llm};
setJoinedTexts([joined_texts]);
setDataPropsForNode(id, { fields: [joined_texts] });
console.log(joined_texts);
});
}
}
}
});
}, [formatting, pullInputData, groupByVar, groupByLLM]);