Fixed bug in TFNode. WIP better InspectorNode

This commit is contained in:
Ian Arawjo 2023-05-13 22:35:01 -04:00
parent 8c3b2c58aa
commit f7fb238a6d
4 changed files with 84 additions and 85 deletions

View File

@ -1,9 +1,7 @@
import React, { useState, useEffect, useRef } from 'react';
import React, { useState, useRef } from 'react';
import { Handle } from 'react-flow-renderer';
import useStore from './store';
import StatusIndicator from './StatusIndicatorComponent'
import NodeLabel from './NodeLabelComponent'
import AlertModal from './AlertModal'
import { IconTerminal } from '@tabler/icons-react'
import {BASE_URL} from './store';
@ -27,9 +25,7 @@ const EvaluatorNode = ({ data, id }) => {
const [codeText, setCodeText] = useState(data.code);
const [codeTextOnLastRun, setCodeTextOnLastRun] = useState(false);
const [reduceMethod, setReduceMethod] = useState('none');
const [mapScope, setMapScope] = useState('response');
const [reduceVars, setReduceVars] = useState([]);
const handleCodeChange = (code) => {
if (codeTextOnLastRun !== false) {
@ -82,7 +78,7 @@ const EvaluatorNode = ({ data, id }) => {
code: codeTextOnRun,
scope: mapScope,
responses: input_node_ids,
reduce_vars: reduceMethod === 'avg' ? reduceVars : [],
reduce_vars: [], // reduceMethod === 'avg' ? reduceVars : [],
script_paths: script_paths,
// write an extra part here that takes in reduce func
}),
@ -112,37 +108,9 @@ const EvaluatorNode = ({ data, id }) => {
}, rejected);
};
const handleOnReduceMethodSelect = (event) => {
const method = event.target.value;
if (method === 'none') {
setReduceVars([]);
}
setReduceMethod(method);
};
const handleOnMapScopeSelect = (event) => {
setMapScope(event.target.value);
};
const handleReduceVarsChange = (event) => {
// Split on commas, ignoring commas wrapped in double-quotes
const regex_csv = /,(?!(?<=(?:^|,)\s*\x22(?:[^\x22]|\x22\x22|\\\x22)*,)(?:[^\x22]|\x22\x22|\\\x22)*\x22\s*(?:,|$))/g;
setReduceVars(event.target.value.split(regex_csv).map(s => s.trim()));
};
// To get CM editor state every render, use this and add ref={cmRef} to CodeMirror component
// const cmRef = React.useRef({});
// useEffect(() => {
// if (cmRef.current?.view) console.log('EditorView:', cmRef.current?.view);
// if (cmRef.current?.state) console.log('EditorState:', cmRef.current?.state);
// if (cmRef.current?.editor) {
// console.log('HTMLDivElement:', cmRef.current?.editor);
// }
// }, [cmRef.current]);
// const initEditor = (view, state) => {
// console.log(view, state);
// }
const hideStatusIndicator = () => {
if (status !== 'none') { setStatus('none'); }
@ -200,34 +168,17 @@ const EvaluatorNode = ({ data, id }) => {
}}
/>
</div>
{/* <CodeMirror
// onCreateEditor={initEditor}
value={data.code}
height="200px"
width="400px"
theme={materialLight}
style={{cursor: 'text'}}
onChange={handleCodeChange}
extensions={[python(), indentUnit.of(" ")]}
/> */}
</div>
<hr/>
{/* <hr/>
<div>
<div className="code-mirror-field-header">Method to reduce across <span className="code-style">responses</span>:</div>
<select name="method" id="method" onChange={handleOnReduceMethodSelect} className="nodrag">
<option value="none">None</option>
<option value="avg">Average across</option>
{/* <option value="custom">Custom reducer</option> */}
</select>
<span> </span>
<input type="text" id="method-vars" name="method-vars" onChange={handleReduceVarsChange} disabled={reduceMethod === 'none'} className="nodrag" />
{/* <label for="avg">Average over: </label>
<select name="avg" id="avg">
<option value="mod">mod</option>
<option value="paragraph">paragraph</option>
<option value="_none">N/A</option>
</select> */}
</div>
</div> */}
</div>
);
};

View File

@ -1,6 +1,6 @@
import React, { useState, useEffect } from 'react';
import { Handle } from 'react-flow-renderer';
import { Badge } from '@mantine/core';
import { Badge, MultiSelect } from '@mantine/core';
import useStore from './store';
import NodeLabel from './NodeLabelComponent'
import {BASE_URL} from './store';
@ -33,13 +33,13 @@ const bucketResponsesByLLM = (responses) => {
const InspectorNode = ({ data, id }) => {
const [responses, setResponses] = useState([]);
const [varSelects, setVarSelects] = useState([]);
const [pastInputs, setPastInputs] = useState([]);
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const handleVarValueSelect = () => {
}
// The MultiSelect so people can dynamically set what vars they care about
const [multiSelectVars, setMultiSelectVars] = useState(data.vars || []);
const [multiSelectValue, setMultiSelectValue] = useState(data.selected_vars || []);
const handleOnConnect = () => {
// Get the ids from the connected input nodes:
@ -59,19 +59,61 @@ const InspectorNode = ({ data, id }) => {
}).then(function(json) {
console.log(json);
if (json.responses && json.responses.length > 0) {
const responses = json.responses;
// Find all vars in response
let found_vars = new Set();
responses.forEach(res_obj => {
Object.keys(res_obj.vars).forEach(v => {
found_vars.add(v);
});
});
// Set the variables accessible in the MultiSelect for 'group by'
setMultiSelectVars(Array.from(found_vars).map(name => (
// We add a $ prefix to mark this as a prompt parameter, and so
// in the future we can add special types of variables without name collisions
{value: `${name}`, label: name}
)).concat({value: 'LLM', label: 'LLM'}));
// If this is an initial run or the multi select value is empty, set to group by 'LLM' by default:
let selected_vars = multiSelectValue;
if (multiSelectValue.length === 0) {
setMultiSelectValue(['LLM']);
selected_vars = ['LLM'];
}
// Now we need to perform groupings by each var in the selected vars list,
// nesting the groupings (preferrably with custom divs) and sorting within
// each group by value of that group's var (so all same values are clumped together).
/**
const groupBy = (resps, varnames) => {
if (varnames.length === 0) return [];
const groupName = varnames[0];
const groupedResponses = groupResponsesByVar(resps, groupName);
const groupedResponseDivs = groupedResponses.map(g => groupBy(g, varnames.slice(1)));
return (
<div key={groupName} className="response-group">
<span>{groupName}</span>
{groupedResponseDivs}
</div>
);
};
// Group by LLM
if (selected_vars.includes('LLM')) {
// ...
// Group without LLM
} else {
// ..
}
*/
// Bucket responses by LLM:
const responses_by_llm = bucketResponsesByLLM(json.responses);
// // Get the var names across all responses, as a set
// let tempvarnames = new Set();
// json.responses.forEach(r => {
// if (!r.vars) return;
// Object.keys(r.vars).forEach(tempvarnames.add);
// });
// // Create a dict version
// let tempvars = {};
const colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1'];
setResponses(Object.keys(responses_by_llm).map((llm, llm_idx) => {
@ -136,11 +178,15 @@ const InspectorNode = ({ data, id }) => {
<NodeLabel title={data.title || 'Inspect Node'}
nodeId={id}
icon={'🔍'} />
{/* <div className="var-select-toolbar">
{varSelects}
</div> */}
<MultiSelect onChange={setMultiSelectValue}
className='nodrag nowheel'
data={multiSelectVars}
placeholder="Pick vars to group responses, in order of importance"
size="xs"
value={multiSelectValue}
searchable />
<div className="inspect-response-container nowheel nodrag">
{responses}
{responses}
</div>
<Handle
type="target"

View File

@ -29,13 +29,16 @@ const TextFieldsNode = ({ data, id }) => {
const [templateVars, setTemplateVars] = useState(data.vars || []);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const delButtonId = 'del-';
const [idCounter, setIDCounter] = useState(0);
// const [resizeObserver, setResizeObserver] = useState(null);
const get_id = () => {
setIDCounter(idCounter + 1);
return 'f' + idCounter.toString();
}
const getUID = useCallback(() => {
if (data.fields) {
return 'f' + (1 + Object.keys(data.fields).reduce((acc, key) => (
Math.max(acc, parseInt(key.slice(1)))
), 0).toString());
} else {
return 'f0';
}
}, [data.fields]);
// Handle a change in a text fields' input.
const handleInputChange = useCallback((event) => {
@ -74,7 +77,7 @@ const TextFieldsNode = ({ data, id }) => {
delete new_data.fields[item_id];
// if the new_data is empty, initialize it with one empty field
if (Object.keys(new_data.fields).length === 0) {
new_data.fields[get_id()] = '';
new_data.fields[getUID()] = '';
}
setDataPropsForNode(id, new_data);
}, [data, id, setDataPropsForNode]);
@ -83,14 +86,14 @@ const TextFieldsNode = ({ data, id }) => {
const [fields, setFields] = useState([]);
useEffect(() => {
if (!data.fields)
setDataPropsForNode(id, { fields: {[get_id()]: ''}} );
setDataPropsForNode(id, { fields: {[getUID()]: ''}} );
}, []);
// Whenever 'data' changes, update the input fields to reflect the current state.
useEffect(() => {
const f = data.fields ? Object.keys(data.fields) : [];
const num_fields = f.length;
setFields(f.map((i, idx) => {
setFields(f.map((i) => {
const val = data.fields ? data.fields[i] : '';
return (
<div className="input-field" key={i}>
@ -104,7 +107,7 @@ const TextFieldsNode = ({ data, id }) => {
const handleAddField = useCallback(() => {
// Update the data for this text fields' id.
let new_data = { 'fields': {...data.fields} };
new_data.fields[get_id()] = "";
new_data.fields[getUID()] = "";
setDataPropsForNode(id, new_data);
}, [data, id, setDataPropsForNode]);
@ -129,7 +132,6 @@ const TextFieldsNode = ({ data, id }) => {
});
observer.observe(elem);
// setResizeObserver(observer);
}
ref.current = elem;
}, [ref, hooksY]);

View File

@ -49,7 +49,7 @@ const VisNode = ({ data, id }) => {
// Re-plot responses when anything changes
useEffect(() => {
if (!responses || responses.length === 0 || !multiSelectValue || multiSelectValue.length === 0) return;
if (!responses || responses.length === 0 || !multiSelectValue) return;
// Bucket responses by LLM:
let responses_by_llm = {};
@ -69,7 +69,7 @@ const VisNode = ({ data, id }) => {
width: 420, height: 300, title: '', margin: {
l: 105, r: 0, b: 20, t: 20, pad: 0
}
}
};
const plot_grouped_boxplot = (resp_to_x) => {
llm_names.forEach((llm, idx) => {
@ -83,7 +83,7 @@ const VisNode = ({ data, id }) => {
let text_items = [];
for (const name of names) {
rs.forEach(r => {
if (r.vars[varnames[0]].trim() !== name) return;
if (resp_to_x(r) !== name) return;
x_items = x_items.concat(r.eval_res.items).flat();
text_items = text_items.concat(createHoverTexts(r.responses)).flat();
y_items = y_items.concat(Array(r.eval_res.items.length).fill(truncStr(name, 12))).flat();