diff --git a/chain-forge/src/InspectorNode.js b/chain-forge/src/InspectorNode.js index 910d100..0828f40 100644 --- a/chain-forge/src/InspectorNode.js +++ b/chain-forge/src/InspectorNode.js @@ -1,4 +1,4 @@ -import React, { useState, useEffect } from 'react'; +import React, { useState, useEffect, useRef } from 'react'; import { Handle } from 'react-flow-renderer'; import { Badge, MultiSelect } from '@mantine/core'; import useStore from './store'; @@ -11,7 +11,14 @@ const truncStr = (s, maxLen) => { return s.substring(0, maxLen) + '...' else return s; -} +}; +const filterDict = (dict, keyFilterFunc) => { + return Object.keys(dict).reduce((acc, key) => { + if (keyFilterFunc(key) === true) + acc[key] = dict[key]; + return acc; + }, {}); +}; const vars_to_str = (vars) => { const pairs = Object.keys(vars).map(varname => { const s = truncStr(vars[varname].trim(), 12); @@ -36,6 +43,7 @@ const groupResponsesBy = (responses, keyFunc) => { const InspectorNode = ({ data, id }) => { const [responses, setResponses] = useState([]); + const [jsonResponses, setJSONResponses] = useState(null); const [pastInputs, setPastInputs] = useState([]); const inputEdgesForNode = useStore((state) => state.inputEdgesForNode); const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); @@ -44,6 +52,123 @@ const InspectorNode = ({ data, id }) => { const [multiSelectVars, setMultiSelectVars] = useState(data.vars || []); const [multiSelectValue, setMultiSelectValue] = useState(data.selected_vars || []); + // Update the visualization when the MultiSelect values change: + useEffect(() => { + if (!jsonResponses || (Array.isArray(jsonResponses) && jsonResponses.length === 0)) + return; + + const responses = jsonResponses; + const selected_vars = multiSelectValue; + + // Find all LLMs in responses and store as array + let found_llms = new Set(); + responses.forEach(res_obj => + found_llms.add(res_obj.llm)); + found_llms = Array.from(found_llms); + + // Assign a color to each LLM in responses + const llm_colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1']; + const color_for_llm = (llm) => llm_colors[found_llms.indexOf(llm) % llm_colors.length]; + const response_box_colors = ['#ddd', '#eee', '#ddd', '#eee']; + const rgroup_color = (depth) => response_box_colors[depth % response_box_colors.length]; + + const getHeaderBadge = (key, val) => { + if (val) { + const s = truncStr(val.trim(), 12); + const txt = `${key} = '${s}'`; + return ({txt}); + } else { + return ({`(unspecified ${key})`}); + } + }; + + // Now we need to perform groupings by each var in the selected vars list, + // nesting the groupings (preferrably with custom divs) and sorting within + // each group by value of that group's var (so all same values are clumped together). + // :: For instance, for varnames = ['LLM', '$var1', '$var2'] we should get back + // :: nested divs first grouped by LLM (first level), then by var1, then var2 (deepest level). + const groupByVars = (resps, varnames, eatenvars, header) => { + if (resps.length === 0) return []; + if (varnames.length === 0) { + // Base case. Display n response(s) to each single prompt, back-to-back: + const resp_boxes = resps.map((res_obj, res_idx) => { + // Spans for actual individual response texts + const ps = res_obj.responses.map((r, idx) => + (
{r}
) + ); + + // At the deepest level, there may still be some vars left over. We want to display these + // as tags, too, so we need to display only the ones that weren't 'eaten' during the recursive call: + // (e.g., the vars that weren't part of the initial 'varnames' list that form the groupings) + const unused_vars = filterDict(res_obj.vars, v => !eatenvars.includes(v)); + const vars = vars_to_str(unused_vars); + const var_tags = vars.map((v) => + ({v}) + ); + return ( +
+ {var_tags} + {eatenvars.includes('LLM') ? + ps + : (
+ {ps} +

{res_obj.llm}

+
) + } +
+ ); + }); + const className = eatenvars.length > 0 ? "response-group" : ""; + const boxesClassName = eatenvars.length > 0 ? "response-boxes-wrapper" : ""; + return ( +
+ {header} +
+ {resp_boxes} +
+
+ ); + } + + // Bucket responses by the first var in the list, where + // we also bucket any 'leftover' responses that didn't have the requested variable (a kind of 'soft fail') + const group_name = varnames[0]; + const [grouped_resps, leftover_resps] = (group_name === 'LLM') + ? groupResponsesBy(resps, (r => r.llm)) + : groupResponsesBy(resps, (r => ((group_name in r.vars) ? r.vars[group_name] : null))); + const get_header = (group_name === 'LLM') + ? ((key, val) => ({val})) + : ((key, val) => getHeaderBadge(key, val)); + + // Now produce nested divs corresponding to the groups + const remaining_vars = varnames.slice(1); + const updated_eatenvars = eatenvars.concat([group_name]); + const grouped_resps_divs = Object.keys(grouped_resps).map(g => groupByVars(grouped_resps[g], remaining_vars, updated_eatenvars, get_header(group_name, g))); + const leftover_resps_divs = leftover_resps.length > 0 ? groupByVars(leftover_resps, remaining_vars, updated_eatenvars, get_header(group_name, undefined)) : []; + + return (<> + {header ? + (
+ {header} +
+ {grouped_resps_divs} +
+
) + :
{grouped_resps_divs}
} + {leftover_resps_divs.length === 0 ? (<>) : ( +
+ {leftover_resps_divs} +
+ )} + ); + }; + + // Produce DIV elements grouped by selected vars + const divs = groupByVars(responses, selected_vars, [], null); + setResponses(divs); + + }, [multiSelectValue, multiSelectVars]); + const handleOnConnect = () => { // Get the ids from the connected input nodes: const input_node_ids = inputEdgesForNode(id).map(e => e.source); @@ -62,11 +187,10 @@ const InspectorNode = ({ data, id }) => { }).then(function(json) { console.log(json); if (json.responses && json.responses.length > 0) { - const responses = json.responses; - // Find all vars in response + // Find all vars in responses let found_vars = new Set(); - responses.forEach(res_obj => { + json.responses.forEach(res_obj => { Object.keys(res_obj.vars).forEach(v => { found_vars.add(v); }); @@ -86,108 +210,7 @@ const InspectorNode = ({ data, id }) => { selected_vars = ['LLM']; } - // Now we need to perform groupings by each var in the selected vars list, - // nesting the groupings (preferrably with custom divs) and sorting within - // each group by value of that group's var (so all same values are clumped together). - // :: For instance, for varnames = ['LLM', '$var1', '$var2'] we should get back - // :: nested divs first grouped by LLM (first level), then by var1, then var2 (deepest level). - /** - const groupByVars = (resps, varnames, eatenvars) => { - if (resps.length === 0) return []; - if (varnames.length === 0) { - // Base case. Display n response(s) to each single prompt, back-to-back: - return resps.map((res_obj, res_idx) => { - // Spans for actual individual response texts - const ps = res_obj.responses.map((r, idx) => - (
{r}
) - ); - - // At the deepest level, there may still be some vars left over. We want to display these - // as tags, too, so we need to display only the ones that weren't 'eaten' during the recursive call: - // (e.g., the vars that weren't part of the initial 'varnames' list that form the groupings) - const vars = vars_to_str(res_obj.vars.filter(v => !eatenvars.includes(v))); - const var_tags = vars.map((v) => - ({v}) - ); - return ( -
- {var_tags} - {ps} -
- ); - }); - } - - // Bucket responses by the first var in the list, where - // we also bucket any 'leftover' responses that didn't have the requested variable (a kind of 'soft fail') - const group_name = varnames[0]; - const [grouped_resps, leftover_resps] = (group_name === 'LLM') - ? groupResponsesBy(resps, (r => r.llm)) - : groupResponsesBy(resps, (r => ((group_name in r.vars) ? r.vars[group_name] : null))); - // Now produce nested divs corresponding to the groups - const remaining_vars = varnames.slice(1); - const updated_eatenvars = eatenvars.concat([group_name]); - const grouped_resps_divs = grouped_resps.map(g => groupByVars(g, remaining_vars, updated_eatenvars)); - const leftover_resps_divs = leftover_resps.length > 0 ? groupByVars(leftover_resps, remaining_vars, updated_eatenvars) : []; - - return (<> -
-

{group_name}

- {grouped_resps_divs} -
- {leftover_resps_divs.length === 0 ? (<>) : ( -
- {leftover_resps_divs} -
- )} - ); - }; - - // Produce DIV elements grouped by selected vars - groupByVars(responses, selected_vars, []); - **/ - - // Bucket responses by LLM: - const responses_by_llm = groupResponsesBy(responses, (r => r.llm)); - - const colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1']; - setResponses(Object.keys(responses_by_llm).map((llm, llm_idx) => { - const res_divs = responses_by_llm[llm].map((res_obj, res_idx) => { - const ps = res_obj.responses.map((r, idx) => - (
{r}
) - ); - const vars = vars_to_str(res_obj.vars); - const var_tags = vars.map((v) => ( - {v} - )); - return ( -
- {var_tags} - {ps} -
- ); - }); - return ( -
-

{llm}

- {res_divs} -
- ); - })); - - // setVarSelects(Object.keys(tempvars).map(v => { - // const options = Array.from(tempvars[v]).map((val, idx) => ( - // - // )); - // return ( - //
- // - // - //
- // ); - // })); + setJSONResponses(json.responses); } }); } @@ -206,20 +229,33 @@ const InspectorNode = ({ data, id }) => { setDataPropsForNode(id, { refresh: false }); handleOnConnect(); } -}, [data, id, handleOnConnect, setDataPropsForNode]); + }, [data, id, handleOnConnect, setDataPropsForNode]); + + // When the user clicks an item in the drop-down, + // we want to autoclose the multiselect drop-down: + const multiSelectRef = useRef(null); + const handleMultiSelectValueChange = (new_val) => { + if (multiSelectRef) { + multiSelectRef.current.blur(); + } + setMultiSelectValue(new_val); + }; return (
- Group responses by (order matters):} data={multiSelectVars} placeholder="Pick vars to group responses, in order of importance" size="xs" value={multiSelectValue} - searchable /> + clearSearchOnChange={true} + clearSearchOnBlur={true} />
{responses}
diff --git a/chain-forge/src/text-fields-node.css b/chain-forge/src/text-fields-node.css index 2a3f71f..96f1efe 100644 --- a/chain-forge/src/text-fields-node.css +++ b/chain-forge/src/text-fields-node.css @@ -165,9 +165,9 @@ border-radius: 5px; } .inspect-response-container { - overflow-y: auto; + overflow-y: scroll; min-width: 150px; - max-width: 450px; + max-width: 650px; max-height: 650px; resize: both; } @@ -194,6 +194,42 @@ padding-bottom: 0px; color: #222; } + + .llm-group-header { + font-weight: 400; + font-size: 10pt; + margin: 6px 8px 4px 8px; + padding-top: 2px; + padding-bottom: 0px; + color: #222; + } + + .response-group { + margin: 2px 0px 8px 0px; + padding: 2px 2px 2px 2px; + /* border-radius: 7px; */ + } + .response-boxes-wrapper { + margin-top: 4px; + padding-left: 10px; + border-left-width: 2px; + border-left-style: solid; + border-left-color: #bbb; + } + .response-item-llm-name-wrapper { + padding-bottom: 0px; + } + .response-item-llm-name-wrapper h1 { + font-size: 8pt; + font-weight: bold; + font-style: italic; + color: #000; + opacity: 0.7; + text-align: right; + padding-right: 4px; + margin: 0px; + } + .response-preview-container { margin: 10px -9px -9px -9px; max-height: 100px;