Prompt template chaining

This commit is contained in:
Ian Arawjo 2023-05-02 09:32:35 -04:00
parent 775b61e89c
commit 31ebc91942
4 changed files with 75 additions and 41 deletions
chain-forge/src
python-backend/promptengine

@ -45,11 +45,15 @@ const InspectorNode = ({ data, id }) => {
// Bucket responses by LLM: // Bucket responses by LLM:
const responses_by_llm = bucketResponsesByLLM(json.responses); const responses_by_llm = bucketResponsesByLLM(json.responses);
// Get the var names across responses // // Get the var names across all responses, as a set
// NOTE: This assumes only a single prompt node output as input // let tempvarnames = new Set();
// (all response vars have the exact same keys). // json.responses.forEach(r => {
let tempvars = {}; // if (!r.vars) return;
Object.keys(json.responses[0].vars).forEach(v => {tempvars[v] = new Set();}); // Object.keys(r.vars).forEach(tempvarnames.add);
// });
// // Create a dict version
// let tempvars = {};
const vars_to_str = (vars) => { const vars_to_str = (vars) => {
const pairs = Object.keys(vars).map(varname => { const pairs = Object.keys(vars).map(varname => {
@ -67,7 +71,7 @@ const InspectorNode = ({ data, id }) => {
const ps = res_obj.responses.map((r, idx) => const ps = res_obj.responses.map((r, idx) =>
(<pre className="small-response" key={idx}>{r}</pre>) (<pre className="small-response" key={idx}>{r}</pre>)
); );
Object.keys(res_obj.vars).forEach(v => {tempvars[v].add(res_obj.vars[v])}); // Object.keys(res_obj.vars).forEach(v => {tempvars[v].add(res_obj.vars[v])});
const vars = vars_to_str(res_obj.vars); const vars = vars_to_str(res_obj.vars);
return ( return (
<div key={"r"+res_idx} className="response-box" style={{ backgroundColor: colors[llm_idx % colors.length] }}> <div key={"r"+res_idx} className="response-box" style={{ backgroundColor: colors[llm_idx % colors.length] }}>
@ -84,19 +88,19 @@ const InspectorNode = ({ data, id }) => {
); );
})); }));
setVarSelects(Object.keys(tempvars).map(v => { // setVarSelects(Object.keys(tempvars).map(v => {
const options = Array.from(tempvars[v]).map((val, idx) => ( // const options = Array.from(tempvars[v]).map((val, idx) => (
<option value={val} key={idx}>{val}</option> // <option value={val} key={idx}>{val}</option>
)); // ));
return ( // return (
<div key={v}> // <div key={v}>
<label htmlFor={v}>{v}: </label> // <label htmlFor={v}>{v}: </label>
<select name={v} id={v} onChange={handleVarValueSelect}> // <select name={v} id={v} onChange={handleVarValueSelect}>
{options} // {options}
</select> // </select>
</div> // </div>
); // );
})); // }));
} }
}); });
} }

@ -36,6 +36,7 @@ const PromptNode = ({ data, id }) => {
const edges = useStore((state) => state.edges); const edges = useStore((state) => state.edges);
const output = useStore((state) => state.output); const output = useStore((state) => state.output);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const getNode = useStore((state) => state.getNode);
const [hovered, setHovered] = useState(false); const [hovered, setHovered] = useState(false);
const [templateVars, setTemplateVars] = useState(data.vars || []); const [templateVars, setTemplateVars] = useState(data.vars || []);
@ -105,24 +106,39 @@ const PromptNode = ({ data, id }) => {
// Pull data from each source: // Pull data from each source:
const pulled_data = {}; const pulled_data = {};
templateVars.forEach(varname => { const get_outputs = (varnames, nodeId) => {
// Find the relevant edge (breaking once we've found it): console.log(varnames);
for (let i = 0; i < edges.length; i++) { varnames.forEach(varname => {
const e = edges[i]; // Find the relevant edge(s):
if (e.target == id && e.targetHandle == varname) { edges.forEach(e => {
// Get the data output for that handle on the source node: if (e.target == nodeId && e.targetHandle == varname) {
let out = output(e.source, e.sourceHandle); // Get the immediate output:
if (!Array.isArray(out)) out = [out]; let out = output(e.source, e.sourceHandle);
if (varname in pulled_data)
pulled_data[varname] = pulled_data[varname].concat(out); // Save the var data from the pulled output
else if (varname in pulled_data)
pulled_data[varname] = out; pulled_data[varname] = pulled_data[varname].concat(out);
} else
} pulled_data[varname] = out;
});
// Get any vars that the output depends on, and recursively collect those outputs as well:
const n_vars = getNode(e.source).data.vars;
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
get_outputs(n_vars, e.source);
}
});
});
};
get_outputs(templateVars, id);
// Get Pythonic version of the prompt, by adding a $ before any template variables in braces: // Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
const py_prompt_template = promptText.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}") const to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
const py_prompt_template = to_py_template_format(promptText);
// Do the same for the vars, since vars can themselves be prompt templates:
Object.keys(pulled_data).forEach(varname => {
pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
});
// Run all prompt permutations through the LLM to generate + cache responses: // Run all prompt permutations through the LLM to generate + cache responses:
fetch('http://localhost:5000/queryllm', { fetch('http://localhost:5000/queryllm', {

@ -11,6 +11,17 @@ const union = (setA, setB) => {
} }
return _union; return _union;
} }
const setsAreEqual = (setA, setB) => {
if (setA.size !== setB.size) return false;
let equal = true;
for (const item of setA) {
if (!setB.has(item)) {
equal = false;
break;
}
}
return equal;
}
const TextFieldsNode = ({ data, id }) => { const TextFieldsNode = ({ data, id }) => {
@ -22,7 +33,6 @@ const TextFieldsNode = ({ data, id }) => {
// Update the data for this text fields' id. // Update the data for this text fields' id.
let new_data = { 'fields': {...data.fields} }; let new_data = { 'fields': {...data.fields} };
new_data.fields[event.target.id] = event.target.value; new_data.fields[event.target.id] = event.target.value;
setDataPropsForNode(id, new_data);
// TODO: Optimize this check. // TODO: Optimize this check.
let all_found_vars = new Set(); let all_found_vars = new Set();
@ -37,9 +47,14 @@ const TextFieldsNode = ({ data, id }) => {
// Update template var fields + handles, if there's a change in sets // Update template var fields + handles, if there's a change in sets
const past_vars = new Set(templateVars); const past_vars = new Set(templateVars);
if (all_found_vars !== past_vars) { if (!setsAreEqual(all_found_vars, past_vars)) {
setTemplateVars(Array.from(all_found_vars)); console.log('set vars');
const new_vars_arr = Array.from(all_found_vars);
new_data.vars = new_vars_arr;
setTemplateVars(new_vars_arr);
} }
setDataPropsForNode(id, new_data);
}, [data, id, setDataPropsForNode, templateVars]); }, [data, id, setDataPropsForNode, templateVars]);
// Initialize fields (run once at init) // Initialize fields (run once at init)

@ -109,8 +109,7 @@ class PromptPermutationGenerator:
break break
if param is None: if param is None:
print("Did not find any more params left to fill in current template. Returning empty list...") return [template]
return []
# Generate new prompts by filling in its value(s) into the PromptTemplate # Generate new prompts by filling in its value(s) into the PromptTemplate
val = paramDict[param] val = paramDict[param]
@ -136,9 +135,9 @@ class PromptPermutationGenerator:
return return
for p in self._gen_perm(self.template, list(paramDict.keys()), paramDict): for p in self._gen_perm(self.template, list(paramDict.keys()), paramDict):
print(p)
yield p yield p
# Test cases # Test cases
if __name__ == '__main__': if __name__ == '__main__':
# Single template # Single template