mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-15 08:40:48 +00:00
Prompt template chaining
This commit is contained in:
parent
775b61e89c
commit
31ebc91942
@ -45,11 +45,15 @@ const InspectorNode = ({ data, id }) => {
|
|||||||
// Bucket responses by LLM:
|
// Bucket responses by LLM:
|
||||||
const responses_by_llm = bucketResponsesByLLM(json.responses);
|
const responses_by_llm = bucketResponsesByLLM(json.responses);
|
||||||
|
|
||||||
// Get the var names across responses
|
// // Get the var names across all responses, as a set
|
||||||
// NOTE: This assumes only a single prompt node output as input
|
// let tempvarnames = new Set();
|
||||||
// (all response vars have the exact same keys).
|
// json.responses.forEach(r => {
|
||||||
let tempvars = {};
|
// if (!r.vars) return;
|
||||||
Object.keys(json.responses[0].vars).forEach(v => {tempvars[v] = new Set();});
|
// Object.keys(r.vars).forEach(tempvarnames.add);
|
||||||
|
// });
|
||||||
|
|
||||||
|
// // Create a dict version
|
||||||
|
// let tempvars = {};
|
||||||
|
|
||||||
const vars_to_str = (vars) => {
|
const vars_to_str = (vars) => {
|
||||||
const pairs = Object.keys(vars).map(varname => {
|
const pairs = Object.keys(vars).map(varname => {
|
||||||
@ -67,7 +71,7 @@ const InspectorNode = ({ data, id }) => {
|
|||||||
const ps = res_obj.responses.map((r, idx) =>
|
const ps = res_obj.responses.map((r, idx) =>
|
||||||
(<pre className="small-response" key={idx}>{r}</pre>)
|
(<pre className="small-response" key={idx}>{r}</pre>)
|
||||||
);
|
);
|
||||||
Object.keys(res_obj.vars).forEach(v => {tempvars[v].add(res_obj.vars[v])});
|
// Object.keys(res_obj.vars).forEach(v => {tempvars[v].add(res_obj.vars[v])});
|
||||||
const vars = vars_to_str(res_obj.vars);
|
const vars = vars_to_str(res_obj.vars);
|
||||||
return (
|
return (
|
||||||
<div key={"r"+res_idx} className="response-box" style={{ backgroundColor: colors[llm_idx % colors.length] }}>
|
<div key={"r"+res_idx} className="response-box" style={{ backgroundColor: colors[llm_idx % colors.length] }}>
|
||||||
@ -84,19 +88,19 @@ const InspectorNode = ({ data, id }) => {
|
|||||||
);
|
);
|
||||||
}));
|
}));
|
||||||
|
|
||||||
setVarSelects(Object.keys(tempvars).map(v => {
|
// setVarSelects(Object.keys(tempvars).map(v => {
|
||||||
const options = Array.from(tempvars[v]).map((val, idx) => (
|
// const options = Array.from(tempvars[v]).map((val, idx) => (
|
||||||
<option value={val} key={idx}>{val}</option>
|
// <option value={val} key={idx}>{val}</option>
|
||||||
));
|
// ));
|
||||||
return (
|
// return (
|
||||||
<div key={v}>
|
// <div key={v}>
|
||||||
<label htmlFor={v}>{v}: </label>
|
// <label htmlFor={v}>{v}: </label>
|
||||||
<select name={v} id={v} onChange={handleVarValueSelect}>
|
// <select name={v} id={v} onChange={handleVarValueSelect}>
|
||||||
{options}
|
// {options}
|
||||||
</select>
|
// </select>
|
||||||
</div>
|
// </div>
|
||||||
);
|
// );
|
||||||
}));
|
// }));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -36,6 +36,7 @@ const PromptNode = ({ data, id }) => {
|
|||||||
const edges = useStore((state) => state.edges);
|
const edges = useStore((state) => state.edges);
|
||||||
const output = useStore((state) => state.output);
|
const output = useStore((state) => state.output);
|
||||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||||
|
const getNode = useStore((state) => state.getNode);
|
||||||
|
|
||||||
const [hovered, setHovered] = useState(false);
|
const [hovered, setHovered] = useState(false);
|
||||||
const [templateVars, setTemplateVars] = useState(data.vars || []);
|
const [templateVars, setTemplateVars] = useState(data.vars || []);
|
||||||
@ -105,24 +106,39 @@ const PromptNode = ({ data, id }) => {
|
|||||||
|
|
||||||
// Pull data from each source:
|
// Pull data from each source:
|
||||||
const pulled_data = {};
|
const pulled_data = {};
|
||||||
templateVars.forEach(varname => {
|
const get_outputs = (varnames, nodeId) => {
|
||||||
// Find the relevant edge (breaking once we've found it):
|
console.log(varnames);
|
||||||
for (let i = 0; i < edges.length; i++) {
|
varnames.forEach(varname => {
|
||||||
const e = edges[i];
|
// Find the relevant edge(s):
|
||||||
if (e.target == id && e.targetHandle == varname) {
|
edges.forEach(e => {
|
||||||
// Get the data output for that handle on the source node:
|
if (e.target == nodeId && e.targetHandle == varname) {
|
||||||
let out = output(e.source, e.sourceHandle);
|
// Get the immediate output:
|
||||||
if (!Array.isArray(out)) out = [out];
|
let out = output(e.source, e.sourceHandle);
|
||||||
if (varname in pulled_data)
|
|
||||||
pulled_data[varname] = pulled_data[varname].concat(out);
|
// Save the var data from the pulled output
|
||||||
else
|
if (varname in pulled_data)
|
||||||
pulled_data[varname] = out;
|
pulled_data[varname] = pulled_data[varname].concat(out);
|
||||||
}
|
else
|
||||||
}
|
pulled_data[varname] = out;
|
||||||
});
|
|
||||||
|
// Get any vars that the output depends on, and recursively collect those outputs as well:
|
||||||
|
const n_vars = getNode(e.source).data.vars;
|
||||||
|
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
|
||||||
|
get_outputs(n_vars, e.source);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
};
|
||||||
|
get_outputs(templateVars, id);
|
||||||
|
|
||||||
// Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
|
// Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
|
||||||
const py_prompt_template = promptText.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
|
const to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
|
||||||
|
const py_prompt_template = to_py_template_format(promptText);
|
||||||
|
|
||||||
|
// Do the same for the vars, since vars can themselves be prompt templates:
|
||||||
|
Object.keys(pulled_data).forEach(varname => {
|
||||||
|
pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
|
||||||
|
});
|
||||||
|
|
||||||
// Run all prompt permutations through the LLM to generate + cache responses:
|
// Run all prompt permutations through the LLM to generate + cache responses:
|
||||||
fetch('http://localhost:5000/queryllm', {
|
fetch('http://localhost:5000/queryllm', {
|
||||||
|
@ -11,6 +11,17 @@ const union = (setA, setB) => {
|
|||||||
}
|
}
|
||||||
return _union;
|
return _union;
|
||||||
}
|
}
|
||||||
|
const setsAreEqual = (setA, setB) => {
|
||||||
|
if (setA.size !== setB.size) return false;
|
||||||
|
let equal = true;
|
||||||
|
for (const item of setA) {
|
||||||
|
if (!setB.has(item)) {
|
||||||
|
equal = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return equal;
|
||||||
|
}
|
||||||
|
|
||||||
const TextFieldsNode = ({ data, id }) => {
|
const TextFieldsNode = ({ data, id }) => {
|
||||||
|
|
||||||
@ -22,7 +33,6 @@ const TextFieldsNode = ({ data, id }) => {
|
|||||||
// Update the data for this text fields' id.
|
// Update the data for this text fields' id.
|
||||||
let new_data = { 'fields': {...data.fields} };
|
let new_data = { 'fields': {...data.fields} };
|
||||||
new_data.fields[event.target.id] = event.target.value;
|
new_data.fields[event.target.id] = event.target.value;
|
||||||
setDataPropsForNode(id, new_data);
|
|
||||||
|
|
||||||
// TODO: Optimize this check.
|
// TODO: Optimize this check.
|
||||||
let all_found_vars = new Set();
|
let all_found_vars = new Set();
|
||||||
@ -37,9 +47,14 @@ const TextFieldsNode = ({ data, id }) => {
|
|||||||
|
|
||||||
// Update template var fields + handles, if there's a change in sets
|
// Update template var fields + handles, if there's a change in sets
|
||||||
const past_vars = new Set(templateVars);
|
const past_vars = new Set(templateVars);
|
||||||
if (all_found_vars !== past_vars) {
|
if (!setsAreEqual(all_found_vars, past_vars)) {
|
||||||
setTemplateVars(Array.from(all_found_vars));
|
console.log('set vars');
|
||||||
|
const new_vars_arr = Array.from(all_found_vars);
|
||||||
|
new_data.vars = new_vars_arr;
|
||||||
|
setTemplateVars(new_vars_arr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setDataPropsForNode(id, new_data);
|
||||||
}, [data, id, setDataPropsForNode, templateVars]);
|
}, [data, id, setDataPropsForNode, templateVars]);
|
||||||
|
|
||||||
// Initialize fields (run once at init)
|
// Initialize fields (run once at init)
|
||||||
|
@ -109,8 +109,7 @@ class PromptPermutationGenerator:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if param is None:
|
if param is None:
|
||||||
print("Did not find any more params left to fill in current template. Returning empty list...")
|
return [template]
|
||||||
return []
|
|
||||||
|
|
||||||
# Generate new prompts by filling in its value(s) into the PromptTemplate
|
# Generate new prompts by filling in its value(s) into the PromptTemplate
|
||||||
val = paramDict[param]
|
val = paramDict[param]
|
||||||
@ -136,9 +135,9 @@ class PromptPermutationGenerator:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for p in self._gen_perm(self.template, list(paramDict.keys()), paramDict):
|
for p in self._gen_perm(self.template, list(paramDict.keys()), paramDict):
|
||||||
|
print(p)
|
||||||
yield p
|
yield p
|
||||||
|
|
||||||
|
|
||||||
# Test cases
|
# Test cases
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Single template
|
# Single template
|
||||||
|
Loading…
x
Reference in New Issue
Block a user