diff --git a/chain-forge/src/InspectorNode.js b/chain-forge/src/InspectorNode.js
index 910d100..0828f40 100644
--- a/chain-forge/src/InspectorNode.js
+++ b/chain-forge/src/InspectorNode.js
@@ -1,4 +1,4 @@
-import React, { useState, useEffect } from 'react';
+import React, { useState, useEffect, useRef } from 'react';
import { Handle } from 'react-flow-renderer';
import { Badge, MultiSelect } from '@mantine/core';
import useStore from './store';
@@ -11,7 +11,14 @@ const truncStr = (s, maxLen) => {
return s.substring(0, maxLen) + '...'
else
return s;
-}
+};
+const filterDict = (dict, keyFilterFunc) => {
+ return Object.keys(dict).reduce((acc, key) => {
+ if (keyFilterFunc(key) === true)
+ acc[key] = dict[key];
+ return acc;
+ }, {});
+};
const vars_to_str = (vars) => {
const pairs = Object.keys(vars).map(varname => {
const s = truncStr(vars[varname].trim(), 12);
@@ -36,6 +43,7 @@ const groupResponsesBy = (responses, keyFunc) => {
const InspectorNode = ({ data, id }) => {
const [responses, setResponses] = useState([]);
+ const [jsonResponses, setJSONResponses] = useState(null);
const [pastInputs, setPastInputs] = useState([]);
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
@@ -44,6 +52,123 @@ const InspectorNode = ({ data, id }) => {
const [multiSelectVars, setMultiSelectVars] = useState(data.vars || []);
const [multiSelectValue, setMultiSelectValue] = useState(data.selected_vars || []);
+ // Update the visualization when the MultiSelect values change:
+ useEffect(() => {
+ if (!jsonResponses || (Array.isArray(jsonResponses) && jsonResponses.length === 0))
+ return;
+
+ const responses = jsonResponses;
+ const selected_vars = multiSelectValue;
+
+ // Find all LLMs in responses and store as array
+ let found_llms = new Set();
+ responses.forEach(res_obj =>
+ found_llms.add(res_obj.llm));
+ found_llms = Array.from(found_llms);
+
+ // Assign a color to each LLM in responses
+ const llm_colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1'];
+ const color_for_llm = (llm) => llm_colors[found_llms.indexOf(llm) % llm_colors.length];
+ const response_box_colors = ['#ddd', '#eee', '#ddd', '#eee'];
+ const rgroup_color = (depth) => response_box_colors[depth % response_box_colors.length];
+
+ const getHeaderBadge = (key, val) => {
+ if (val) {
+ const s = truncStr(val.trim(), 12);
+ const txt = `${key} = '${s}'`;
+ return ({txt});
+ } else {
+ return ({`(unspecified ${key})`});
+ }
+ };
+
+ // Now we need to perform groupings by each var in the selected vars list,
+ // nesting the groupings (preferrably with custom divs) and sorting within
+ // each group by value of that group's var (so all same values are clumped together).
+ // :: For instance, for varnames = ['LLM', '$var1', '$var2'] we should get back
+ // :: nested divs first grouped by LLM (first level), then by var1, then var2 (deepest level).
+ const groupByVars = (resps, varnames, eatenvars, header) => {
+ if (resps.length === 0) return [];
+ if (varnames.length === 0) {
+ // Base case. Display n response(s) to each single prompt, back-to-back:
+ const resp_boxes = resps.map((res_obj, res_idx) => {
+ // Spans for actual individual response texts
+ const ps = res_obj.responses.map((r, idx) =>
+ (
{r}
)
+ );
+
+ // At the deepest level, there may still be some vars left over. We want to display these
+ // as tags, too, so we need to display only the ones that weren't 'eaten' during the recursive call:
+ // (e.g., the vars that weren't part of the initial 'varnames' list that form the groupings)
+ const unused_vars = filterDict(res_obj.vars, v => !eatenvars.includes(v));
+ const vars = vars_to_str(unused_vars);
+ const var_tags = vars.map((v) =>
+ ({v})
+ );
+ return (
+
+ )}
+ >);
+ };
+
+ // Produce DIV elements grouped by selected vars
+ const divs = groupByVars(responses, selected_vars, [], null);
+ setResponses(divs);
+
+ }, [multiSelectValue, multiSelectVars]);
+
const handleOnConnect = () => {
// Get the ids from the connected input nodes:
const input_node_ids = inputEdgesForNode(id).map(e => e.source);
@@ -62,11 +187,10 @@ const InspectorNode = ({ data, id }) => {
}).then(function(json) {
console.log(json);
if (json.responses && json.responses.length > 0) {
- const responses = json.responses;
- // Find all vars in response
+ // Find all vars in responses
let found_vars = new Set();
- responses.forEach(res_obj => {
+ json.responses.forEach(res_obj => {
Object.keys(res_obj.vars).forEach(v => {
found_vars.add(v);
});
@@ -86,108 +210,7 @@ const InspectorNode = ({ data, id }) => {
selected_vars = ['LLM'];
}
- // Now we need to perform groupings by each var in the selected vars list,
- // nesting the groupings (preferrably with custom divs) and sorting within
- // each group by value of that group's var (so all same values are clumped together).
- // :: For instance, for varnames = ['LLM', '$var1', '$var2'] we should get back
- // :: nested divs first grouped by LLM (first level), then by var1, then var2 (deepest level).
- /**
- const groupByVars = (resps, varnames, eatenvars) => {
- if (resps.length === 0) return [];
- if (varnames.length === 0) {
- // Base case. Display n response(s) to each single prompt, back-to-back:
- return resps.map((res_obj, res_idx) => {
- // Spans for actual individual response texts
- const ps = res_obj.responses.map((r, idx) =>
- (
{r}
)
- );
-
- // At the deepest level, there may still be some vars left over. We want to display these
- // as tags, too, so we need to display only the ones that weren't 'eaten' during the recursive call:
- // (e.g., the vars that weren't part of the initial 'varnames' list that form the groupings)
- const vars = vars_to_str(res_obj.vars.filter(v => !eatenvars.includes(v)));
- const var_tags = vars.map((v) =>
- ({v})
- );
- return (
-
- {var_tags}
- {ps}
-
- );
- });
- }
-
- // Bucket responses by the first var in the list, where
- // we also bucket any 'leftover' responses that didn't have the requested variable (a kind of 'soft fail')
- const group_name = varnames[0];
- const [grouped_resps, leftover_resps] = (group_name === 'LLM')
- ? groupResponsesBy(resps, (r => r.llm))
- : groupResponsesBy(resps, (r => ((group_name in r.vars) ? r.vars[group_name] : null)));
- // Now produce nested divs corresponding to the groups
- const remaining_vars = varnames.slice(1);
- const updated_eatenvars = eatenvars.concat([group_name]);
- const grouped_resps_divs = grouped_resps.map(g => groupByVars(g, remaining_vars, updated_eatenvars));
- const leftover_resps_divs = leftover_resps.length > 0 ? groupByVars(leftover_resps, remaining_vars, updated_eatenvars) : [];
-
- return (<>
-
- // );
- // }));
+ setJSONResponses(json.responses);
}
});
}
@@ -206,20 +229,33 @@ const InspectorNode = ({ data, id }) => {
setDataPropsForNode(id, { refresh: false });
handleOnConnect();
}
-}, [data, id, handleOnConnect, setDataPropsForNode]);
+ }, [data, id, handleOnConnect, setDataPropsForNode]);
+
+ // When the user clicks an item in the drop-down,
+ // we want to autoclose the multiselect drop-down:
+ const multiSelectRef = useRef(null);
+ const handleMultiSelectValueChange = (new_val) => {
+ if (multiSelectRef) {
+ multiSelectRef.current.blur();
+ }
+ setMultiSelectValue(new_val);
+ };
return (
- Group responses by (order matters):}
data={multiSelectVars}
placeholder="Pick vars to group responses, in order of importance"
size="xs"
value={multiSelectValue}
- searchable />
+ clearSearchOnChange={true}
+ clearSearchOnBlur={true} />