mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 16:26:45 +00:00
Select legend items on parcoords. Move Plotly component to main render(), and only update spec and layout.
This commit is contained in:
parent
d129d6b8d4
commit
a73b3d5ceb
chain-forge/src
@ -7,11 +7,11 @@ const truncStr = (s, maxLen) => {
|
||||
return s;
|
||||
};
|
||||
|
||||
const PlotLegend = ({ labels }) => {
|
||||
const PlotLegend = ({ labels, onClickLabel }) => {
|
||||
return (
|
||||
<div className="plot-legend">
|
||||
{Object.entries(labels).map(([label, color]) => (
|
||||
<div key={label}>
|
||||
<div key={label} className="plot-legend-item nodrag" onClick={() => onClickLabel(label)}>
|
||||
<span style={{ backgroundColor: color, width: '10px', height: '10px', display: 'inline-block' }}></span>
|
||||
<span style={{ marginLeft: '5px' }}>{truncStr(label, 56)}</span>
|
||||
</div>
|
||||
|
@ -49,20 +49,56 @@ const extractEvalResultsForMetric = (metric, responses) => {
|
||||
const VisNode = ({ data, id }) => {
|
||||
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
const [plotlyObj, setPlotlyObj] = useState([]);
|
||||
const [plotlySpec, setPlotlySpec] = useState([]);
|
||||
const [plotlyLayout, setPlotlyLayout] = useState({});
|
||||
const [pastInputs, setPastInputs] = useState([]);
|
||||
const [responses, setResponses] = useState([]);
|
||||
const [status, setStatus] = useState('none');
|
||||
|
||||
const [plotLegend, setPlotLegend] = useState(null);
|
||||
const [selectedLegendItems, setSelectedLegendItems] = useState(null);
|
||||
|
||||
// The MultiSelect so people can dynamically set what vars they care about
|
||||
const [multiSelectVars, setMultiSelectVars] = useState(data.vars || []);
|
||||
const [multiSelectValue, setMultiSelectValue] = useState(data.selected_vars || []);
|
||||
|
||||
// When the user clicks an item in the drop-down,
|
||||
// we want to autoclose the multiselect drop-down:
|
||||
const multiSelectRef = useRef(null);
|
||||
const handleMultiSelectValueChange = (new_val) => {
|
||||
if (multiSelectRef) {
|
||||
multiSelectRef.current.blur();
|
||||
}
|
||||
setStatus('loading');
|
||||
setMultiSelectValue(new_val);
|
||||
};
|
||||
|
||||
// Re-plot responses when anything changes
|
||||
useEffect(() => {
|
||||
if (!responses || responses.length === 0 || !multiSelectValue) return;
|
||||
|
||||
setStatus('none');
|
||||
|
||||
// If there are variables but no variables are selected...
|
||||
if (multiSelectVars && multiSelectVars.length > 0 && multiSelectValue.length === 0) {
|
||||
console.warn('No variables selected to plot.');
|
||||
setSelectedLegendItems(null);
|
||||
setPlotLegend(null);
|
||||
setPlotlySpec([]);
|
||||
setPlotlyLayout({});
|
||||
return;
|
||||
}
|
||||
|
||||
// Create Plotly spec here
|
||||
const varnames = multiSelectValue;
|
||||
const colors = ['#44d044', '#f1b933', '#e46161', '#8888f9', '#33bef0', '#bb55f9', '#cadefc', '#f8f398'];
|
||||
let spec = [];
|
||||
let layout = {
|
||||
width: 420, height: 300, title: '', margin: {
|
||||
l: 105, r: 0, b: 36, t: 20, pad: 0
|
||||
}
|
||||
};
|
||||
|
||||
// Bucket responses by LLM:
|
||||
let responses_by_llm = {};
|
||||
responses.forEach(item => {
|
||||
@ -91,23 +127,12 @@ const VisNode = ({ data, id }) => {
|
||||
// For multiple LLMs, add a control drop-down selector to switch the LLM visualized in the plot.
|
||||
}
|
||||
|
||||
|
||||
const get_items = (eval_res_obj) => {
|
||||
if (typeof_eval_res.includes('KeyValue'))
|
||||
return eval_res_obj.items.map(item => item[metric_axes_labels[0]]);
|
||||
return eval_res_obj.items;
|
||||
};
|
||||
|
||||
// Create Plotly spec here
|
||||
const varnames = multiSelectValue;
|
||||
const colors = ['#baf078', '#f1b963', '#e46161', '#8888f9', '#33bef0', '#defcf9', '#cadefc', '#f8f398'];
|
||||
let spec = [];
|
||||
let layout = {
|
||||
width: 420, height: 300, title: '', margin: {
|
||||
l: 105, r: 0, b: 36, t: 20, pad: 0
|
||||
}
|
||||
};
|
||||
|
||||
const plot_grouped_boxplot = (resp_to_x) => {
|
||||
// Get all possible values of the single variable response ('name' vals)
|
||||
const names = new Set(responses.map(resp_to_x));
|
||||
@ -154,11 +179,21 @@ const VisNode = ({ data, id }) => {
|
||||
if (varnames.length === 1) {
|
||||
console.log("Plotting parallel coordinates...");
|
||||
let unique_vals = getUniqueKeysInResponses(responses, (resp_obj) => resp_obj.vars[varnames[0]]);
|
||||
// const response_txts = responses.map(res_obj => res_obj.responses).flat();
|
||||
|
||||
let group_colors = colors;
|
||||
const unselected_line_color = '#ddd';
|
||||
const spec_colors = responses.map(resp_obj => {
|
||||
const idx = unique_vals.indexOf(resp_obj.vars[varnames[0]]);
|
||||
return Array(resp_obj.eval_res.items.length).fill(idx);
|
||||
}).flat();
|
||||
|
||||
let colorscale = [];
|
||||
for (let i = 0; i < unique_vals.length; i++) {
|
||||
colorscale.push([i / (unique_vals.length-1), group_colors[i % group_colors.length]]);
|
||||
if (!selectedLegendItems || selectedLegendItems.indexOf(unique_vals[i]) > -1)
|
||||
colorscale.push([i / (unique_vals.length-1), group_colors[i % group_colors.length]]);
|
||||
else
|
||||
colorscale.push([i / (unique_vals.length-1), unselected_line_color]);
|
||||
}
|
||||
|
||||
let dimensions = [];
|
||||
@ -175,10 +210,7 @@ const VisNode = ({ data, id }) => {
|
||||
type: 'parcoords',
|
||||
pad: [10, 10, 10, 10],
|
||||
line: {
|
||||
color: responses.map(resp_obj => {
|
||||
const idx = unique_vals.indexOf(resp_obj.vars[varnames[0]]);
|
||||
return Array(resp_obj.eval_res.items.length).fill(idx);
|
||||
}).flat(),
|
||||
color: spec_colors,
|
||||
colorscale: colorscale,
|
||||
},
|
||||
dimensions: dimensions,
|
||||
@ -186,16 +218,37 @@ const VisNode = ({ data, id }) => {
|
||||
layout.margin = { l: 40, r: 40, b: 40, t: 50, pad: 0 };
|
||||
layout.paper_bgcolor = "white";
|
||||
layout.font = {color: "black"};
|
||||
layout.selectedpoints = [];
|
||||
|
||||
// There's no built-in legend for parallel coords, unfortunately, so we need to construct our own:
|
||||
let legend_labels = {};
|
||||
unique_vals.forEach((v, idx) =>
|
||||
{legend_labels[v] = group_colors[idx];}
|
||||
);
|
||||
plot_legend = (<PlotLegend labels={legend_labels} />);
|
||||
unique_vals.forEach((v, idx) => {
|
||||
if (!selectedLegendItems || selectedLegendItems.indexOf(v) > -1)
|
||||
legend_labels[v] = group_colors[idx];
|
||||
else
|
||||
legend_labels[v] = unselected_line_color;
|
||||
});
|
||||
const onClickLegendItem = (label) => {
|
||||
if (selectedLegendItems && selectedLegendItems.length === 1 && selectedLegendItems[0] === label)
|
||||
setSelectedLegendItems(null); // Clicking twice on a legend item deselects it and displays all
|
||||
else
|
||||
setSelectedLegendItems([label]);
|
||||
};
|
||||
plot_legend = (<PlotLegend labels={legend_labels} onClickLabel={onClickLegendItem} />);
|
||||
|
||||
// Tried to support Plotly hover events here, but looks like
|
||||
// currently there are unsupported for parcoords: https://github.com/plotly/plotly.js/issues/3012
|
||||
// onHover = (e) => {
|
||||
// console.log(e.curveNumber);
|
||||
// // const curveIdx = e.curveNumber;
|
||||
// // if (curveIdx < response_txts.length) {
|
||||
// // if (!selectedLegendItems || selectedLegendItems.indexOf(unique_vals[spec_colors[curveIdx]]) > -1)
|
||||
// // console.log(response_txts[curveIdx]);
|
||||
// // }
|
||||
// };
|
||||
|
||||
console.log(spec);
|
||||
} else {
|
||||
setSelectedLegendItems(null);
|
||||
console.error("Plotting evaluations with more than one metric and more than one prompt parameter is currently unsupported.");
|
||||
}
|
||||
}
|
||||
@ -252,14 +305,10 @@ const VisNode = ({ data, id }) => {
|
||||
spec = [spec];
|
||||
|
||||
setPlotLegend(plot_legend);
|
||||
setPlotlyObj((
|
||||
<Plot
|
||||
data={spec}
|
||||
layout={layout}
|
||||
/>
|
||||
));
|
||||
setPlotlySpec(spec);
|
||||
setPlotlyLayout(layout);
|
||||
|
||||
}, [multiSelectVars, multiSelectValue, responses]);
|
||||
}, [multiSelectVars, multiSelectValue, responses, selectedLegendItems]);
|
||||
|
||||
const handleOnConnect = useCallback(() => {
|
||||
// Grab the input node ids
|
||||
@ -295,7 +344,7 @@ const VisNode = ({ data, id }) => {
|
||||
// :: For 1 var and 1 eval_res that's a number, plot {x: var, y: eval_res}
|
||||
// :: For 2 vars and 1 eval_res that's a number, plot {x: var1, y: var2, z: eval_res}
|
||||
// :: For all else, don't plot anything (at the moment)
|
||||
}, [data, setPlotlyObj]);
|
||||
}, [data]);
|
||||
|
||||
// console.log('from visnode', data);
|
||||
if (data.input) {
|
||||
@ -318,8 +367,10 @@ const VisNode = ({ data, id }) => {
|
||||
<div className="vis-node cfnode">
|
||||
<NodeLabel title={data.title || 'Vis Node'}
|
||||
nodeId={id}
|
||||
status={status}
|
||||
icon={'📊'} />
|
||||
<MultiSelect onChange={setMultiSelectValue}
|
||||
<MultiSelect ref={multiSelectRef}
|
||||
onChange={handleMultiSelectValueChange}
|
||||
className='nodrag nowheel'
|
||||
data={multiSelectVars}
|
||||
placeholder="Pick all vars you wish to plot"
|
||||
@ -327,7 +378,11 @@ const VisNode = ({ data, id }) => {
|
||||
value={multiSelectValue}
|
||||
searchable />
|
||||
<div className="nodrag">
|
||||
{plotlyObj}
|
||||
{plotlySpec && plotlySpec.length > 0 ? (
|
||||
<Plot
|
||||
data={plotlySpec}
|
||||
layout={plotlyLayout}
|
||||
/>) : <></>}
|
||||
{plotLegend ? plotLegend : <></>}
|
||||
</div>
|
||||
<Handle
|
||||
|
@ -161,6 +161,15 @@
|
||||
font-size: 11px;
|
||||
font-family: monospace;
|
||||
}
|
||||
.plot-legend-item {
|
||||
cursor: pointer;
|
||||
}
|
||||
.plot-legend-item:hover {
|
||||
opacity: 0.6;
|
||||
}
|
||||
.plot-legend-item:active {
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.inspector-node {
|
||||
background-color: #fff;
|
||||
@ -171,6 +180,8 @@
|
||||
.inspect-response-container {
|
||||
overflow-y: scroll;
|
||||
min-width: 150px;
|
||||
width: 350px;
|
||||
height: 300px;
|
||||
max-width: 650px;
|
||||
max-height: 650px;
|
||||
resize: both;
|
||||
|
Loading…
x
Reference in New Issue
Block a user