mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 16:26:45 +00:00
Fixed bug in TFNode. WIP better InspectorNode
This commit is contained in:
parent
8c3b2c58aa
commit
f7fb238a6d
@ -1,9 +1,7 @@
|
||||
import React, { useState, useEffect, useRef } from 'react';
|
||||
import React, { useState, useRef } from 'react';
|
||||
import { Handle } from 'react-flow-renderer';
|
||||
import useStore from './store';
|
||||
import StatusIndicator from './StatusIndicatorComponent'
|
||||
import NodeLabel from './NodeLabelComponent'
|
||||
import AlertModal from './AlertModal'
|
||||
import { IconTerminal } from '@tabler/icons-react'
|
||||
import {BASE_URL} from './store';
|
||||
|
||||
@ -27,9 +25,7 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
|
||||
const [codeText, setCodeText] = useState(data.code);
|
||||
const [codeTextOnLastRun, setCodeTextOnLastRun] = useState(false);
|
||||
const [reduceMethod, setReduceMethod] = useState('none');
|
||||
const [mapScope, setMapScope] = useState('response');
|
||||
const [reduceVars, setReduceVars] = useState([]);
|
||||
|
||||
const handleCodeChange = (code) => {
|
||||
if (codeTextOnLastRun !== false) {
|
||||
@ -82,7 +78,7 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
code: codeTextOnRun,
|
||||
scope: mapScope,
|
||||
responses: input_node_ids,
|
||||
reduce_vars: reduceMethod === 'avg' ? reduceVars : [],
|
||||
reduce_vars: [], // reduceMethod === 'avg' ? reduceVars : [],
|
||||
script_paths: script_paths,
|
||||
// write an extra part here that takes in reduce func
|
||||
}),
|
||||
@ -112,37 +108,9 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
}, rejected);
|
||||
};
|
||||
|
||||
const handleOnReduceMethodSelect = (event) => {
|
||||
const method = event.target.value;
|
||||
if (method === 'none') {
|
||||
setReduceVars([]);
|
||||
}
|
||||
setReduceMethod(method);
|
||||
};
|
||||
|
||||
const handleOnMapScopeSelect = (event) => {
|
||||
setMapScope(event.target.value);
|
||||
};
|
||||
|
||||
const handleReduceVarsChange = (event) => {
|
||||
// Split on commas, ignoring commas wrapped in double-quotes
|
||||
const regex_csv = /,(?!(?<=(?:^|,)\s*\x22(?:[^\x22]|\x22\x22|\\\x22)*,)(?:[^\x22]|\x22\x22|\\\x22)*\x22\s*(?:,|$))/g;
|
||||
setReduceVars(event.target.value.split(regex_csv).map(s => s.trim()));
|
||||
};
|
||||
|
||||
// To get CM editor state every render, use this and add ref={cmRef} to CodeMirror component
|
||||
// const cmRef = React.useRef({});
|
||||
// useEffect(() => {
|
||||
// if (cmRef.current?.view) console.log('EditorView:', cmRef.current?.view);
|
||||
// if (cmRef.current?.state) console.log('EditorState:', cmRef.current?.state);
|
||||
// if (cmRef.current?.editor) {
|
||||
// console.log('HTMLDivElement:', cmRef.current?.editor);
|
||||
// }
|
||||
// }, [cmRef.current]);
|
||||
|
||||
// const initEditor = (view, state) => {
|
||||
// console.log(view, state);
|
||||
// }
|
||||
|
||||
const hideStatusIndicator = () => {
|
||||
if (status !== 'none') { setStatus('none'); }
|
||||
@ -200,34 +168,17 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{/* <CodeMirror
|
||||
// onCreateEditor={initEditor}
|
||||
value={data.code}
|
||||
height="200px"
|
||||
width="400px"
|
||||
theme={materialLight}
|
||||
style={{cursor: 'text'}}
|
||||
onChange={handleCodeChange}
|
||||
extensions={[python(), indentUnit.of(" ")]}
|
||||
/> */}
|
||||
</div>
|
||||
<hr/>
|
||||
{/* <hr/>
|
||||
<div>
|
||||
<div className="code-mirror-field-header">Method to reduce across <span className="code-style">responses</span>:</div>
|
||||
<select name="method" id="method" onChange={handleOnReduceMethodSelect} className="nodrag">
|
||||
<option value="none">None</option>
|
||||
<option value="avg">Average across</option>
|
||||
{/* <option value="custom">Custom reducer</option> */}
|
||||
</select>
|
||||
<span> </span>
|
||||
<input type="text" id="method-vars" name="method-vars" onChange={handleReduceVarsChange} disabled={reduceMethod === 'none'} className="nodrag" />
|
||||
{/* <label for="avg">Average over: </label>
|
||||
<select name="avg" id="avg">
|
||||
<option value="mod">mod</option>
|
||||
<option value="paragraph">paragraph</option>
|
||||
<option value="_none">N/A</option>
|
||||
</select> */}
|
||||
</div>
|
||||
</div> */}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
@ -1,6 +1,6 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Handle } from 'react-flow-renderer';
|
||||
import { Badge } from '@mantine/core';
|
||||
import { Badge, MultiSelect } from '@mantine/core';
|
||||
import useStore from './store';
|
||||
import NodeLabel from './NodeLabelComponent'
|
||||
import {BASE_URL} from './store';
|
||||
@ -33,13 +33,13 @@ const bucketResponsesByLLM = (responses) => {
|
||||
const InspectorNode = ({ data, id }) => {
|
||||
|
||||
const [responses, setResponses] = useState([]);
|
||||
const [varSelects, setVarSelects] = useState([]);
|
||||
const [pastInputs, setPastInputs] = useState([]);
|
||||
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
|
||||
const handleVarValueSelect = () => {
|
||||
}
|
||||
// The MultiSelect so people can dynamically set what vars they care about
|
||||
const [multiSelectVars, setMultiSelectVars] = useState(data.vars || []);
|
||||
const [multiSelectValue, setMultiSelectValue] = useState(data.selected_vars || []);
|
||||
|
||||
const handleOnConnect = () => {
|
||||
// Get the ids from the connected input nodes:
|
||||
@ -59,19 +59,61 @@ const InspectorNode = ({ data, id }) => {
|
||||
}).then(function(json) {
|
||||
console.log(json);
|
||||
if (json.responses && json.responses.length > 0) {
|
||||
const responses = json.responses;
|
||||
|
||||
// Find all vars in response
|
||||
let found_vars = new Set();
|
||||
responses.forEach(res_obj => {
|
||||
Object.keys(res_obj.vars).forEach(v => {
|
||||
found_vars.add(v);
|
||||
});
|
||||
});
|
||||
|
||||
// Set the variables accessible in the MultiSelect for 'group by'
|
||||
setMultiSelectVars(Array.from(found_vars).map(name => (
|
||||
// We add a $ prefix to mark this as a prompt parameter, and so
|
||||
// in the future we can add special types of variables without name collisions
|
||||
{value: `${name}`, label: name}
|
||||
)).concat({value: 'LLM', label: 'LLM'}));
|
||||
|
||||
// If this is an initial run or the multi select value is empty, set to group by 'LLM' by default:
|
||||
let selected_vars = multiSelectValue;
|
||||
if (multiSelectValue.length === 0) {
|
||||
setMultiSelectValue(['LLM']);
|
||||
selected_vars = ['LLM'];
|
||||
}
|
||||
|
||||
// Now we need to perform groupings by each var in the selected vars list,
|
||||
// nesting the groupings (preferrably with custom divs) and sorting within
|
||||
// each group by value of that group's var (so all same values are clumped together).
|
||||
/**
|
||||
const groupBy = (resps, varnames) => {
|
||||
if (varnames.length === 0) return [];
|
||||
|
||||
const groupName = varnames[0];
|
||||
const groupedResponses = groupResponsesByVar(resps, groupName);
|
||||
const groupedResponseDivs = groupedResponses.map(g => groupBy(g, varnames.slice(1)));
|
||||
|
||||
return (
|
||||
<div key={groupName} className="response-group">
|
||||
<span>{groupName}</span>
|
||||
{groupedResponseDivs}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// Group by LLM
|
||||
if (selected_vars.includes('LLM')) {
|
||||
// ...
|
||||
|
||||
// Group without LLM
|
||||
} else {
|
||||
// ..
|
||||
}
|
||||
*/
|
||||
|
||||
// Bucket responses by LLM:
|
||||
const responses_by_llm = bucketResponsesByLLM(json.responses);
|
||||
|
||||
// // Get the var names across all responses, as a set
|
||||
// let tempvarnames = new Set();
|
||||
// json.responses.forEach(r => {
|
||||
// if (!r.vars) return;
|
||||
// Object.keys(r.vars).forEach(tempvarnames.add);
|
||||
// });
|
||||
|
||||
// // Create a dict version
|
||||
// let tempvars = {};
|
||||
|
||||
const colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1'];
|
||||
setResponses(Object.keys(responses_by_llm).map((llm, llm_idx) => {
|
||||
@ -136,11 +178,15 @@ const InspectorNode = ({ data, id }) => {
|
||||
<NodeLabel title={data.title || 'Inspect Node'}
|
||||
nodeId={id}
|
||||
icon={'🔍'} />
|
||||
{/* <div className="var-select-toolbar">
|
||||
{varSelects}
|
||||
</div> */}
|
||||
<MultiSelect onChange={setMultiSelectValue}
|
||||
className='nodrag nowheel'
|
||||
data={multiSelectVars}
|
||||
placeholder="Pick vars to group responses, in order of importance"
|
||||
size="xs"
|
||||
value={multiSelectValue}
|
||||
searchable />
|
||||
<div className="inspect-response-container nowheel nodrag">
|
||||
{responses}
|
||||
{responses}
|
||||
</div>
|
||||
<Handle
|
||||
type="target"
|
||||
|
@ -29,13 +29,16 @@ const TextFieldsNode = ({ data, id }) => {
|
||||
const [templateVars, setTemplateVars] = useState(data.vars || []);
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
const delButtonId = 'del-';
|
||||
const [idCounter, setIDCounter] = useState(0);
|
||||
// const [resizeObserver, setResizeObserver] = useState(null);
|
||||
|
||||
const get_id = () => {
|
||||
setIDCounter(idCounter + 1);
|
||||
return 'f' + idCounter.toString();
|
||||
}
|
||||
const getUID = useCallback(() => {
|
||||
if (data.fields) {
|
||||
return 'f' + (1 + Object.keys(data.fields).reduce((acc, key) => (
|
||||
Math.max(acc, parseInt(key.slice(1)))
|
||||
), 0).toString());
|
||||
} else {
|
||||
return 'f0';
|
||||
}
|
||||
}, [data.fields]);
|
||||
|
||||
// Handle a change in a text fields' input.
|
||||
const handleInputChange = useCallback((event) => {
|
||||
@ -74,7 +77,7 @@ const TextFieldsNode = ({ data, id }) => {
|
||||
delete new_data.fields[item_id];
|
||||
// if the new_data is empty, initialize it with one empty field
|
||||
if (Object.keys(new_data.fields).length === 0) {
|
||||
new_data.fields[get_id()] = '';
|
||||
new_data.fields[getUID()] = '';
|
||||
}
|
||||
setDataPropsForNode(id, new_data);
|
||||
}, [data, id, setDataPropsForNode]);
|
||||
@ -83,14 +86,14 @@ const TextFieldsNode = ({ data, id }) => {
|
||||
const [fields, setFields] = useState([]);
|
||||
useEffect(() => {
|
||||
if (!data.fields)
|
||||
setDataPropsForNode(id, { fields: {[get_id()]: ''}} );
|
||||
setDataPropsForNode(id, { fields: {[getUID()]: ''}} );
|
||||
}, []);
|
||||
|
||||
// Whenever 'data' changes, update the input fields to reflect the current state.
|
||||
useEffect(() => {
|
||||
const f = data.fields ? Object.keys(data.fields) : [];
|
||||
const num_fields = f.length;
|
||||
setFields(f.map((i, idx) => {
|
||||
setFields(f.map((i) => {
|
||||
const val = data.fields ? data.fields[i] : '';
|
||||
return (
|
||||
<div className="input-field" key={i}>
|
||||
@ -104,7 +107,7 @@ const TextFieldsNode = ({ data, id }) => {
|
||||
const handleAddField = useCallback(() => {
|
||||
// Update the data for this text fields' id.
|
||||
let new_data = { 'fields': {...data.fields} };
|
||||
new_data.fields[get_id()] = "";
|
||||
new_data.fields[getUID()] = "";
|
||||
setDataPropsForNode(id, new_data);
|
||||
}, [data, id, setDataPropsForNode]);
|
||||
|
||||
@ -129,7 +132,6 @@ const TextFieldsNode = ({ data, id }) => {
|
||||
});
|
||||
|
||||
observer.observe(elem);
|
||||
// setResizeObserver(observer);
|
||||
}
|
||||
ref.current = elem;
|
||||
}, [ref, hooksY]);
|
||||
|
@ -49,7 +49,7 @@ const VisNode = ({ data, id }) => {
|
||||
|
||||
// Re-plot responses when anything changes
|
||||
useEffect(() => {
|
||||
if (!responses || responses.length === 0 || !multiSelectValue || multiSelectValue.length === 0) return;
|
||||
if (!responses || responses.length === 0 || !multiSelectValue) return;
|
||||
|
||||
// Bucket responses by LLM:
|
||||
let responses_by_llm = {};
|
||||
@ -69,7 +69,7 @@ const VisNode = ({ data, id }) => {
|
||||
width: 420, height: 300, title: '', margin: {
|
||||
l: 105, r: 0, b: 20, t: 20, pad: 0
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const plot_grouped_boxplot = (resp_to_x) => {
|
||||
llm_names.forEach((llm, idx) => {
|
||||
@ -83,7 +83,7 @@ const VisNode = ({ data, id }) => {
|
||||
let text_items = [];
|
||||
for (const name of names) {
|
||||
rs.forEach(r => {
|
||||
if (r.vars[varnames[0]].trim() !== name) return;
|
||||
if (resp_to_x(r) !== name) return;
|
||||
x_items = x_items.concat(r.eval_res.items).flat();
|
||||
text_items = text_items.concat(createHoverTexts(r.responses)).flat();
|
||||
y_items = y_items.concat(Array(r.eval_res.items.length).fill(truncStr(name, 12))).flat();
|
||||
|
Loading…
x
Reference in New Issue
Block a user