diff --git a/chain-forge/src/EvaluatorNode.js b/chain-forge/src/EvaluatorNode.js index 662e14c..d4d0ba7 100644 --- a/chain-forge/src/EvaluatorNode.js +++ b/chain-forge/src/EvaluatorNode.js @@ -1,9 +1,7 @@ -import React, { useState, useEffect, useRef } from 'react'; +import React, { useState, useRef } from 'react'; import { Handle } from 'react-flow-renderer'; import useStore from './store'; -import StatusIndicator from './StatusIndicatorComponent' import NodeLabel from './NodeLabelComponent' -import AlertModal from './AlertModal' import { IconTerminal } from '@tabler/icons-react' import {BASE_URL} from './store'; @@ -27,9 +25,7 @@ const EvaluatorNode = ({ data, id }) => { const [codeText, setCodeText] = useState(data.code); const [codeTextOnLastRun, setCodeTextOnLastRun] = useState(false); - const [reduceMethod, setReduceMethod] = useState('none'); const [mapScope, setMapScope] = useState('response'); - const [reduceVars, setReduceVars] = useState([]); const handleCodeChange = (code) => { if (codeTextOnLastRun !== false) { @@ -82,7 +78,7 @@ const EvaluatorNode = ({ data, id }) => { code: codeTextOnRun, scope: mapScope, responses: input_node_ids, - reduce_vars: reduceMethod === 'avg' ? reduceVars : [], + reduce_vars: [], // reduceMethod === 'avg' ? reduceVars : [], script_paths: script_paths, // write an extra part here that takes in reduce func }), @@ -112,37 +108,9 @@ const EvaluatorNode = ({ data, id }) => { }, rejected); }; - const handleOnReduceMethodSelect = (event) => { - const method = event.target.value; - if (method === 'none') { - setReduceVars([]); - } - setReduceMethod(method); - }; - const handleOnMapScopeSelect = (event) => { setMapScope(event.target.value); }; - - const handleReduceVarsChange = (event) => { - // Split on commas, ignoring commas wrapped in double-quotes - const regex_csv = /,(?!(?<=(?:^|,)\s*\x22(?:[^\x22]|\x22\x22|\\\x22)*,)(?:[^\x22]|\x22\x22|\\\x22)*\x22\s*(?:,|$))/g; - setReduceVars(event.target.value.split(regex_csv).map(s => s.trim())); - }; - - // To get CM editor state every render, use this and add ref={cmRef} to CodeMirror component - // const cmRef = React.useRef({}); - // useEffect(() => { - // if (cmRef.current?.view) console.log('EditorView:', cmRef.current?.view); - // if (cmRef.current?.state) console.log('EditorState:', cmRef.current?.state); - // if (cmRef.current?.editor) { - // console.log('HTMLDivElement:', cmRef.current?.editor); - // } - // }, [cmRef.current]); - - // const initEditor = (view, state) => { - // console.log(view, state); - // } const hideStatusIndicator = () => { if (status !== 'none') { setStatus('none'); } @@ -200,34 +168,17 @@ const EvaluatorNode = ({ data, id }) => { }} /> - {/* */} -
+ {/*
Method to reduce across responses:
- {/* - */} -
+ */} ); }; diff --git a/chain-forge/src/InspectorNode.js b/chain-forge/src/InspectorNode.js index f107dc0..32ff9a5 100644 --- a/chain-forge/src/InspectorNode.js +++ b/chain-forge/src/InspectorNode.js @@ -1,6 +1,6 @@ import React, { useState, useEffect } from 'react'; import { Handle } from 'react-flow-renderer'; -import { Badge } from '@mantine/core'; +import { Badge, MultiSelect } from '@mantine/core'; import useStore from './store'; import NodeLabel from './NodeLabelComponent' import {BASE_URL} from './store'; @@ -33,13 +33,13 @@ const bucketResponsesByLLM = (responses) => { const InspectorNode = ({ data, id }) => { const [responses, setResponses] = useState([]); - const [varSelects, setVarSelects] = useState([]); const [pastInputs, setPastInputs] = useState([]); const inputEdgesForNode = useStore((state) => state.inputEdgesForNode); const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); - const handleVarValueSelect = () => { - } + // The MultiSelect so people can dynamically set what vars they care about + const [multiSelectVars, setMultiSelectVars] = useState(data.vars || []); + const [multiSelectValue, setMultiSelectValue] = useState(data.selected_vars || []); const handleOnConnect = () => { // Get the ids from the connected input nodes: @@ -59,19 +59,61 @@ const InspectorNode = ({ data, id }) => { }).then(function(json) { console.log(json); if (json.responses && json.responses.length > 0) { + const responses = json.responses; + // Find all vars in response + let found_vars = new Set(); + responses.forEach(res_obj => { + Object.keys(res_obj.vars).forEach(v => { + found_vars.add(v); + }); + }); + + // Set the variables accessible in the MultiSelect for 'group by' + setMultiSelectVars(Array.from(found_vars).map(name => ( + // We add a $ prefix to mark this as a prompt parameter, and so + // in the future we can add special types of variables without name collisions + {value: `${name}`, label: name} + )).concat({value: 'LLM', label: 'LLM'})); + + // If this is an initial run or the multi select value is empty, set to group by 'LLM' by default: + let selected_vars = multiSelectValue; + if (multiSelectValue.length === 0) { + setMultiSelectValue(['LLM']); + selected_vars = ['LLM']; + } + + // Now we need to perform groupings by each var in the selected vars list, + // nesting the groupings (preferrably with custom divs) and sorting within + // each group by value of that group's var (so all same values are clumped together). + /** + const groupBy = (resps, varnames) => { + if (varnames.length === 0) return []; + + const groupName = varnames[0]; + const groupedResponses = groupResponsesByVar(resps, groupName); + const groupedResponseDivs = groupedResponses.map(g => groupBy(g, varnames.slice(1))); + + return ( +
+ {groupName} + {groupedResponseDivs} +
+ ); + }; + + // Group by LLM + if (selected_vars.includes('LLM')) { + // ... + + // Group without LLM + } else { + // .. + } + */ + // Bucket responses by LLM: const responses_by_llm = bucketResponsesByLLM(json.responses); - - // // Get the var names across all responses, as a set - // let tempvarnames = new Set(); - // json.responses.forEach(r => { - // if (!r.vars) return; - // Object.keys(r.vars).forEach(tempvarnames.add); - // }); - - // // Create a dict version - // let tempvars = {}; const colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1']; setResponses(Object.keys(responses_by_llm).map((llm, llm_idx) => { @@ -136,11 +178,15 @@ const InspectorNode = ({ data, id }) => { - {/*
- {varSelects} -
*/} +
- {responses} + {responses}
{ const [templateVars, setTemplateVars] = useState(data.vars || []); const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); const delButtonId = 'del-'; - const [idCounter, setIDCounter] = useState(0); - // const [resizeObserver, setResizeObserver] = useState(null); - const get_id = () => { - setIDCounter(idCounter + 1); - return 'f' + idCounter.toString(); - } + const getUID = useCallback(() => { + if (data.fields) { + return 'f' + (1 + Object.keys(data.fields).reduce((acc, key) => ( + Math.max(acc, parseInt(key.slice(1))) + ), 0).toString()); + } else { + return 'f0'; + } + }, [data.fields]); // Handle a change in a text fields' input. const handleInputChange = useCallback((event) => { @@ -74,7 +77,7 @@ const TextFieldsNode = ({ data, id }) => { delete new_data.fields[item_id]; // if the new_data is empty, initialize it with one empty field if (Object.keys(new_data.fields).length === 0) { - new_data.fields[get_id()] = ''; + new_data.fields[getUID()] = ''; } setDataPropsForNode(id, new_data); }, [data, id, setDataPropsForNode]); @@ -83,14 +86,14 @@ const TextFieldsNode = ({ data, id }) => { const [fields, setFields] = useState([]); useEffect(() => { if (!data.fields) - setDataPropsForNode(id, { fields: {[get_id()]: ''}} ); + setDataPropsForNode(id, { fields: {[getUID()]: ''}} ); }, []); // Whenever 'data' changes, update the input fields to reflect the current state. useEffect(() => { const f = data.fields ? Object.keys(data.fields) : []; const num_fields = f.length; - setFields(f.map((i, idx) => { + setFields(f.map((i) => { const val = data.fields ? data.fields[i] : ''; return (
@@ -104,7 +107,7 @@ const TextFieldsNode = ({ data, id }) => { const handleAddField = useCallback(() => { // Update the data for this text fields' id. let new_data = { 'fields': {...data.fields} }; - new_data.fields[get_id()] = ""; + new_data.fields[getUID()] = ""; setDataPropsForNode(id, new_data); }, [data, id, setDataPropsForNode]); @@ -129,7 +132,6 @@ const TextFieldsNode = ({ data, id }) => { }); observer.observe(elem); - // setResizeObserver(observer); } ref.current = elem; }, [ref, hooksY]); diff --git a/chain-forge/src/VisNode.js b/chain-forge/src/VisNode.js index 3087e29..2d320cd 100644 --- a/chain-forge/src/VisNode.js +++ b/chain-forge/src/VisNode.js @@ -49,7 +49,7 @@ const VisNode = ({ data, id }) => { // Re-plot responses when anything changes useEffect(() => { - if (!responses || responses.length === 0 || !multiSelectValue || multiSelectValue.length === 0) return; + if (!responses || responses.length === 0 || !multiSelectValue) return; // Bucket responses by LLM: let responses_by_llm = {}; @@ -69,7 +69,7 @@ const VisNode = ({ data, id }) => { width: 420, height: 300, title: '', margin: { l: 105, r: 0, b: 20, t: 20, pad: 0 } - } + }; const plot_grouped_boxplot = (resp_to_x) => { llm_names.forEach((llm, idx) => { @@ -83,7 +83,7 @@ const VisNode = ({ data, id }) => { let text_items = []; for (const name of names) { rs.forEach(r => { - if (r.vars[varnames[0]].trim() !== name) return; + if (resp_to_x(r) !== name) return; x_items = x_items.concat(r.eval_res.items).flat(); text_items = text_items.concat(createHoverTexts(r.responses)).flat(); y_items = y_items.concat(Array(r.eval_res.items.length).fill(truncStr(name, 12))).flat();