mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 08:16:37 +00:00
Reorganize join node logic
This commit is contained in:
parent
9c7b36fa7e
commit
30a099bf20
148
chainforge/react-server/src/JoinNode.js
vendored
148
chainforge/react-server/src/JoinNode.js
vendored
@ -30,10 +30,13 @@ const getVarsAndMetavars = (input_data) => {
|
||||
// Find all vars and metavars in the input data (if any):
|
||||
let varnames = new Set();
|
||||
let metavars = new Set();
|
||||
input_data.forEach(resp_obj => {
|
||||
if (typeof resp_obj === "string") return;
|
||||
Object.keys(resp_obj.fill_history).forEach(v => varnames.add(v));
|
||||
if (resp_obj.metavars) Object.keys(resp_obj.metavars).forEach(v => metavars.add(v));
|
||||
Object.entries(input_data).forEach(([key, obj]) => {
|
||||
if (key !== '__input') varnames.add(key); // A "var" can also be other properties on input_data
|
||||
obj.forEach(resp_obj => {
|
||||
if (typeof resp_obj === "string") return;
|
||||
Object.keys(resp_obj.fill_history).forEach(v => varnames.add(v));
|
||||
if (resp_obj.metavars) Object.keys(resp_obj.metavars).forEach(v => metavars.add(v));
|
||||
});
|
||||
});
|
||||
varnames = Array.from(varnames);
|
||||
metavars = Array.from(metavars);
|
||||
@ -43,10 +46,50 @@ const getVarsAndMetavars = (input_data) => {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Count how many distinct LLMs appear in a collection of inputs.
 *
 * @param {Array|Object} resp_objs_or_dict - Either a flat array of inputs
 *   (strings and/or response objects), or a dict mapping names to such arrays.
 * @returns {number} The number of distinct LLMs found. An LLM is identified by
 *   `llm.key` when present, otherwise by the `llm` value itself.
 */
const countNumLLMs = (resp_objs_or_dict) => {
  // Normalize to a flat array; a dict's per-key arrays are merged together.
  const resp_objs = Array.isArray(resp_objs_or_dict)
    ? resp_objs_or_dict
    : Object.values(resp_objs_or_dict ?? {}).flat();

  // Skip plain strings, null/undefined entries, and objects with no LLM attached
  // (`!= null` rejects both `undefined` and `null`, so a null llm is not counted).
  const llm_ids = resp_objs
    .filter(r => r && typeof r !== "string" && r.llm != null)
    .map(r => r.llm?.key || r.llm);

  return new Set(llm_ids).size;
}
|
||||
|
||||
/**
 * Return a copy of the input dict in which every response object that carries
 * an LLM with a `key` is tagged with a `__LLM_key` metavar holding that key.
 *
 * Strings, null/undefined entries, and objects without `llm.key` are passed
 * through unchanged (same reference). Tagged objects are deep copies, so the
 * caller's data is never mutated.
 *
 * @param {Object} input_data - Dict mapping names to arrays of strings and/or
 *   response objects.
 * @returns {Object} A new dict with the same keys and tagged response objects.
 */
const tagMetadataWithLLM = (input_data) => {
  let new_data = {};
  Object.entries(input_data).forEach(([varname, resp_objs]) => {
    new_data[varname] = resp_objs.map(r => {
      if (!r || typeof r === 'string' || !r?.llm?.key) return r;
      let r_copy = JSON.parse(JSON.stringify(r));
      // A response object may have no metavars at all; create the dict
      // instead of throwing when assigning the tag.
      r_copy.metavars = { ...(r_copy.metavars ?? {}), __LLM_key: r.llm.key };
      return r_copy;
    });
  });
  return new_data;
};
|
||||
/**
 * Build a lookup table of the LLMs present in the input, indexed by `llm.key`.
 *
 * When the same key appears more than once, the first LLM object encountered
 * is kept. Strings, null/undefined entries, and objects without `llm.key`
 * are ignored.
 *
 * @param {Object} input_data - Dict mapping names to arrays of strings and/or
 *   response objects.
 * @returns {Object} Dict mapping each LLM key to its `llm` object.
 */
const extractLLMLookup = (input_data) => {
  let llm_lookup = {};
  Object.values(input_data).forEach(resp_objs => {
    resp_objs.forEach(r => {
      if (typeof r === 'string' || !r?.llm?.key) return;
      // Own-property check: the `in` operator would also match inherited names
      // such as "toString" or "constructor" and wrongly skip LLMs with those keys.
      if (Object.prototype.hasOwnProperty.call(llm_lookup, r.llm.key)) return;
      llm_lookup[r.llm.key] = r.llm;
    });
  });
  return llm_lookup;
};
|
||||
/**
 * Return the given metavars without the internal `__LLM_key` tag.
 *
 * The argument is never mutated: when the tag is present, a deep copy is made
 * and the tag is removed from the copy. When there is nothing to remove, the
 * original value is returned as-is.
 *
 * @param {Object} metavars - Metavars dict, possibly containing `__LLM_key`.
 * @returns {Object} Metavars without `__LLM_key`.
 */
const removeLLMTagFromMetadata = (metavars) => {
  if (!metavars || typeof metavars !== 'object' || !('__LLM_key' in metavars))
    return metavars;
  let mcopy = JSON.parse(JSON.stringify(metavars));
  // Delete from the copy (not from the caller's object), so the returned
  // value is actually untagged and the input stays intact.
  delete mcopy['__LLM_key'];
  return mcopy;
};
|
||||
/**
 * Strip the internal `__LLM_key` metavar from every response object in the
 * input dict. This operates IN PLACE on the response objects' `metavars`.
 *
 * Entries that are strings, null/undefined, or that have no `metavars` are
 * left alone.
 *
 * @param {Object} input_data - Dict mapping names to arrays of strings and/or
 *   response objects.
 * @returns {Object} The same `input_data` object that was passed in.
 */
const removeLLMTagsFromMetadata = (input_data) => {
  Object.values(input_data).forEach((resp_objs) => {
    resp_objs.forEach(r => {
      // `in` throws a TypeError on non-objects, so look up metavars defensively.
      const metavars = (r && typeof r === 'object') ? r.metavars : undefined;
      if (metavars && typeof metavars === 'object' && "__LLM_key" in metavars)
        delete metavars["__LLM_key"];
    });
  });
  return input_data;
};
|
||||
|
||||
const truncStr = (s, maxLen) => {
|
||||
if (s.length > maxLen) // Cut the name short if it's long
|
||||
return s.substring(0, maxLen) + '...'
|
||||
@ -154,7 +197,7 @@ const JoinNode = ({ data, id }) => {
|
||||
const [formatting, setFormatting] = useState(formattingOptions[0].value);
|
||||
|
||||
const handleOnConnect = useCallback(() => {
|
||||
const input_data = pullInputData(["__input"], id);
|
||||
let input_data = pullInputData(["__input"], id);
|
||||
if (!input_data?.__input) {
|
||||
console.warn('Join Node: No input data detected.');
|
||||
return;
|
||||
@ -163,11 +206,10 @@ const JoinNode = ({ data, id }) => {
|
||||
console.log(input_data);
|
||||
|
||||
// Find all vars and metavars in the input data (if any):
|
||||
const {vars, metavars} = getVarsAndMetavars(input_data.__input);
|
||||
let {vars, metavars} = getVarsAndMetavars(input_data);
|
||||
|
||||
// Check whether more than one LLM is present in the inputs:
|
||||
const numLLMs = countNumLLMs(input_data.__input);
|
||||
setInputHasLLMs(numLLMs > 1);
|
||||
// Create lookup table for LLMs in input, indexed by llm key
|
||||
const llm_lookup = extractLLMLookup(input_data);
|
||||
|
||||
// Refresh the dropdown list with available vars/metavars:
|
||||
setGroupByVars([DEFAULT_GROUPBY_VAR_ALL].concat(
|
||||
@ -176,7 +218,16 @@ const JoinNode = ({ data, id }) => {
|
||||
metavars.filter(varname => !varname.startsWith('LLM_')).map(varname => ({label: `by ${varname} (meta)`, value: `M${varname}`})))
|
||||
);
|
||||
|
||||
const joinByVars = (input) => {
|
||||
// Check whether more than one LLM is present in the inputs:
|
||||
const numLLMs = countNumLLMs(input_data);
|
||||
setInputHasLLMs(numLLMs > 1);
|
||||
|
||||
// Tag all response objects in the input data with a metavar for their LLM (using the llm key as a uid)
|
||||
input_data = tagMetadataWithLLM(input_data);
|
||||
|
||||
// A function to group the input (an array of texts/resp_objs) by the selected var
|
||||
// and then join the texts within the groups
|
||||
const joinByVar = (input) => {
|
||||
const varname = groupByVar.substring(1);
|
||||
const isMetavar = groupByVar[0] === 'M';
|
||||
const [groupedResps, unspecGroup] = groupResponsesBy(input,
|
||||
@ -217,51 +268,58 @@ const JoinNode = ({ data, id }) => {
|
||||
return joined_texts;
|
||||
};
|
||||
|
||||
// If there's multiple LLMs and groupByLLM is 'within', we need to
|
||||
// first group by the LLMs (and a possible 'undefined' group):
|
||||
if (numLLMs > 1 && groupByLLM === 'within') {
|
||||
let joined_texts = [];
|
||||
const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(input_data.__input, r => r.llm?.key || r.llm);
|
||||
Object.entries(groupedRespsByLLM).map(([llm_key, resp_objs]) => {
|
||||
// Group only within the LLM
|
||||
joined_texts = joined_texts.concat(joinByVars(resp_objs));
|
||||
});
|
||||
// Generate (flatten) the inputs, which could be recursively chained templates
|
||||
// and a mix of LLM resp objects, templates, and strings.
|
||||
// (We tagged each object with its LLM key so that we can use built-in features to keep track of the LLM associated with each response object)
|
||||
fetch_from_backend('generatePrompts', {
|
||||
prompt: "{__input}",
|
||||
vars: input_data,
|
||||
}).then(promptTemplates => {
|
||||
|
||||
if (nonLLMRespGroup.length > 0)
|
||||
joined_texts.push(joinTexts(nonLLMRespGroup, formatting));
|
||||
// Convert the templates into response objects
|
||||
let resp_objs = promptTemplates.map(p => ({
|
||||
text: p.toString(),
|
||||
fill_history: p.fill_history,
|
||||
llm: "__LLM_key" in p.metavars ? llm_lookup[p.metavars['__LLM_key']] : undefined,
|
||||
metavars: removeLLMTagFromMetadata(p.metavars),
|
||||
}));
|
||||
|
||||
// If there's multiple LLMs and groupByLLM is 'within', we need to
|
||||
// first group by the LLMs (and a possible 'undefined' group):
|
||||
if (numLLMs > 1 && groupByLLM === 'within') {
|
||||
let joined_texts = [];
|
||||
const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(resp_objs, r => r.llm?.key || r.llm);
|
||||
Object.entries(groupedRespsByLLM).map(([llm_key, resp_objs]) => {
|
||||
// Group only within the LLM
|
||||
joined_texts = joined_texts.concat(joinByVar(resp_objs));
|
||||
});
|
||||
|
||||
if (nonLLMRespGroup.length > 0)
|
||||
joined_texts.push(joinTexts(nonLLMRespGroup, formatting));
|
||||
|
||||
setJoinedTexts(joined_texts);
|
||||
setDataPropsForNode(id, { fields: joined_texts });
|
||||
console.log(joined_texts);
|
||||
} else {
|
||||
// Join across LLMs (join irrespective of LLM):
|
||||
if (groupByVar !== 'A') {
|
||||
// If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first:
|
||||
const joined_texts = joinByVars(input_data.__input);
|
||||
setJoinedTexts(joined_texts);
|
||||
setDataPropsForNode(id, { fields: joined_texts });
|
||||
console.log(joined_texts);
|
||||
} else {
|
||||
// Since templates could be chained, we need to run this
|
||||
// through the prompt generator:
|
||||
fetch_from_backend('generatePrompts', {
|
||||
prompt: "{__input}",
|
||||
vars: input_data,
|
||||
}).then(promptTemplates => {
|
||||
const texts = promptTemplates.map(p => p.toString());
|
||||
console.log(texts);
|
||||
|
||||
let joined_texts = joinTexts(texts, formatting);
|
||||
// Join across LLMs (join irrespective of LLM):
|
||||
if (groupByVar !== 'A') {
|
||||
// If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first:
|
||||
const joined_texts = joinByVar(resp_objs);
|
||||
setJoinedTexts(joined_texts);
|
||||
setDataPropsForNode(id, { fields: joined_texts });
|
||||
} else {
|
||||
let joined_texts = joinTexts(resp_objs.map(r => ((typeof r === 'string') ? r : r.text)), formatting);
|
||||
|
||||
// If there is exactly 1 LLM and it's present across all inputs, keep track of it:
|
||||
if (numLLMs === 1 && input_data.__input.every((r) => r.llm !== undefined))
|
||||
joined_texts = {text: joined_texts, fill_history: {}, llm: input_data.__input[0].llm};
|
||||
if (numLLMs === 1 && resp_objs.every((r) => r.llm !== undefined))
|
||||
joined_texts = {text: joined_texts, fill_history: {}, llm: resp_objs[0].llm};
|
||||
|
||||
setJoinedTexts([joined_texts]);
|
||||
setDataPropsForNode(id, { fields: [joined_texts] });
|
||||
console.log(joined_texts);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
}, [formatting, pullInputData, groupByVar, groupByLLM]);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user