diff --git a/chain-forge/src/InspectorNode.js b/chain-forge/src/InspectorNode.js
index d1d38da..e00ebc3 100644
--- a/chain-forge/src/InspectorNode.js
+++ b/chain-forge/src/InspectorNode.js
@@ -45,11 +45,15 @@ const InspectorNode = ({ data, id }) => {
// Bucket responses by LLM:
const responses_by_llm = bucketResponsesByLLM(json.responses);
- // Get the var names across responses
- // NOTE: This assumes only a single prompt node output as input
- // (all response vars have the exact same keys).
- let tempvars = {};
- Object.keys(json.responses[0].vars).forEach(v => {tempvars[v] = new Set();});
+ // // Get the var names across all responses, as a set
+ // let tempvarnames = new Set();
+ // json.responses.forEach(r => {
+ // if (!r.vars) return;
+ // Object.keys(r.vars).forEach(tempvarnames.add);
+ // });
+
+ // // Create a dict version
+ // let tempvars = {};
const vars_to_str = (vars) => {
const pairs = Object.keys(vars).map(varname => {
@@ -67,7 +71,7 @@ const InspectorNode = ({ data, id }) => {
const ps = res_obj.responses.map((r, idx) =>
(
@@ -84,19 +88,19 @@ const InspectorNode = ({ data, id }) => {
);
}));
- setVarSelects(Object.keys(tempvars).map(v => {
- const options = Array.from(tempvars[v]).map((val, idx) => (
-
{val}
- ));
- return (
-
- {v}:
-
- {options}
-
-
- );
- }));
+ // setVarSelects(Object.keys(tempvars).map(v => {
+ // const options = Array.from(tempvars[v]).map((val, idx) => (
+ //
{val}
+ // ));
+ // return (
+ //
+ // {v}:
+ //
+ // {options}
+ //
+ //
+ // );
+ // }));
}
});
}
diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js
index 98263ad..4e31961 100644
--- a/chain-forge/src/PromptNode.js
+++ b/chain-forge/src/PromptNode.js
@@ -36,6 +36,7 @@ const PromptNode = ({ data, id }) => {
const edges = useStore((state) => state.edges);
const output = useStore((state) => state.output);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
+ const getNode = useStore((state) => state.getNode);
const [hovered, setHovered] = useState(false);
const [templateVars, setTemplateVars] = useState(data.vars || []);
@@ -105,24 +106,39 @@ const PromptNode = ({ data, id }) => {
// Pull data from each source:
const pulled_data = {};
- templateVars.forEach(varname => {
- // Find the relevant edge (breaking once we've found it):
- for (let i = 0; i < edges.length; i++) {
- const e = edges[i];
- if (e.target == id && e.targetHandle == varname) {
- // Get the data output for that handle on the source node:
- let out = output(e.source, e.sourceHandle);
- if (!Array.isArray(out)) out = [out];
- if (varname in pulled_data)
- pulled_data[varname] = pulled_data[varname].concat(out);
- else
- pulled_data[varname] = out;
- }
- }
- });
+ const get_outputs = (varnames, nodeId) => {
+ console.log(varnames);
+ varnames.forEach(varname => {
+ // Find the relevant edge(s):
+ edges.forEach(e => {
+ if (e.target == nodeId && e.targetHandle == varname) {
+ // Get the immediate output:
+ let out = output(e.source, e.sourceHandle);
+
+ // Save the var data from the pulled output
+ if (varname in pulled_data)
+ pulled_data[varname] = pulled_data[varname].concat(out);
+ else
+ pulled_data[varname] = out;
+
+ // Get any vars that the output depends on, and recursively collect those outputs as well:
+ const n_vars = getNode(e.source).data.vars;
+ if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
+ get_outputs(n_vars, e.source);
+ }
+ });
+ });
+ };
+ get_outputs(templateVars, id);
// Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
- const py_prompt_template = promptText.replace(/(? str.replace(/(? {
+ pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
+ });
// Run all prompt permutations through the LLM to generate + cache responses:
fetch('http://localhost:5000/queryllm', {
diff --git a/chain-forge/src/TextFieldsNode.js b/chain-forge/src/TextFieldsNode.js
index 474ad98..5a4119f 100644
--- a/chain-forge/src/TextFieldsNode.js
+++ b/chain-forge/src/TextFieldsNode.js
@@ -11,6 +11,17 @@ const union = (setA, setB) => {
}
return _union;
}
+const setsAreEqual = (setA, setB) => {
+ if (setA.size !== setB.size) return false;
+ let equal = true;
+ for (const item of setA) {
+ if (!setB.has(item)) {
+ equal = false;
+ break;
+ }
+ }
+ return equal;
+}
const TextFieldsNode = ({ data, id }) => {
@@ -22,7 +33,6 @@ const TextFieldsNode = ({ data, id }) => {
// Update the data for this text fields' id.
let new_data = { 'fields': {...data.fields} };
new_data.fields[event.target.id] = event.target.value;
- setDataPropsForNode(id, new_data);
// TODO: Optimize this check.
let all_found_vars = new Set();
@@ -37,9 +47,14 @@ const TextFieldsNode = ({ data, id }) => {
// Update template var fields + handles, if there's a change in sets
const past_vars = new Set(templateVars);
- if (all_found_vars !== past_vars) {
- setTemplateVars(Array.from(all_found_vars));
+ if (!setsAreEqual(all_found_vars, past_vars)) {
+ console.log('set vars');
+ const new_vars_arr = Array.from(all_found_vars);
+ new_data.vars = new_vars_arr;
+ setTemplateVars(new_vars_arr);
}
+
+ setDataPropsForNode(id, new_data);
}, [data, id, setDataPropsForNode, templateVars]);
// Initialize fields (run once at init)
diff --git a/python-backend/promptengine/template.py b/python-backend/promptengine/template.py
index 41218d0..100d896 100644
--- a/python-backend/promptengine/template.py
+++ b/python-backend/promptengine/template.py
@@ -109,8 +109,7 @@ class PromptPermutationGenerator:
break
if param is None:
- print("Did not find any more params left to fill in current template. Returning empty list...")
- return []
+ return [template]
# Generate new prompts by filling in its value(s) into the PromptTemplate
val = paramDict[param]
@@ -136,9 +135,9 @@ class PromptPermutationGenerator:
return
for p in self._gen_perm(self.template, list(paramDict.keys()), paramDict):
+ print(p)
yield p
-
# Test cases
if __name__ == '__main__':
# Single template