mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 16:26:45 +00:00
Trying CF on actual prompt engineering task.
This commit is contained in:
parent
126eacb9c8
commit
df1e89662d
17
chain-forge/package-lock.json
generated
17
chain-forge/package-lock.json
generated
@ -14,6 +14,7 @@
|
||||
"@reactflow/background": "^11.2.0",
|
||||
"@reactflow/controls": "^11.1.11",
|
||||
"@reactflow/core": "^11.7.0",
|
||||
"@reactflow/node-resizer": "^2.1.0",
|
||||
"@tabler/icons-react": "^2.17.0",
|
||||
"@testing-library/jest-dom": "^5.16.5",
|
||||
"@testing-library/react": "^13.4.0",
|
||||
@ -3814,6 +3815,22 @@
|
||||
"react-dom": ">=17"
|
||||
}
|
||||
},
|
||||
"node_modules/@reactflow/node-resizer": {
|
||||
"version": "2.1.0",
|
||||
"resolved": "https://registry.npmjs.org/@reactflow/node-resizer/-/node-resizer-2.1.0.tgz",
|
||||
"integrity": "sha512-DVL8nnWsltP8/iANadAcTaDB4wsEkx2mOLlBEPNE3yc5loSm3u9l5m4enXRcBym61MiMuTtDPzZMyYYQUjuYIg==",
|
||||
"dependencies": {
|
||||
"@reactflow/core": "^11.6.0",
|
||||
"classcat": "^5.0.4",
|
||||
"d3-drag": "^3.0.0",
|
||||
"d3-selection": "^3.0.0",
|
||||
"zustand": "^4.3.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=17",
|
||||
"react-dom": ">=17"
|
||||
}
|
||||
},
|
||||
"node_modules/@rollup/plugin-babel": {
|
||||
"version": "5.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/plugin-babel/-/plugin-babel-5.3.1.tgz",
|
||||
|
@ -9,6 +9,7 @@
|
||||
"@reactflow/background": "^11.2.0",
|
||||
"@reactflow/controls": "^11.1.11",
|
||||
"@reactflow/core": "^11.7.0",
|
||||
"@reactflow/node-resizer": "^2.1.0",
|
||||
"@tabler/icons-react": "^2.17.0",
|
||||
"@testing-library/jest-dom": "^5.16.5",
|
||||
"@testing-library/react": "^13.4.0",
|
||||
|
22
chain-forge/src/ControlledTextArea.js
Normal file
22
chain-forge/src/ControlledTextArea.js
Normal file
@ -0,0 +1,22 @@
|
||||
import React, { useEffect, useRef, useState } from 'react';
|
||||
|
||||
/* Modified from https://stackoverflow.com/a/68928267 */
|
||||
const ControlledTextArea = (props) => {
|
||||
const { value, onChange, ...rest } = props;
|
||||
const [cursor, setCursor] = useState(null);
|
||||
const ref = useRef(null);
|
||||
|
||||
useEffect(() => {
|
||||
const input = ref.current;
|
||||
if (input) input.setSelectionRange(cursor, cursor);
|
||||
}, [ref, cursor, value]);
|
||||
|
||||
const handleChange = (e) => {
|
||||
setCursor(e.target.selectionStart);
|
||||
onChange && onChange(e);
|
||||
};
|
||||
|
||||
return <textarea ref={ref} value={value} onChange={handleChange} {...rest} />;
|
||||
};
|
||||
|
||||
export default ControlledTextArea;
|
@ -39,6 +39,13 @@ const groupResponsesBy = (responses, keyFunc) => {
|
||||
});
|
||||
return [responses_by_key, unspecified_group];
|
||||
};
|
||||
const getUniqueKeysInResponses = (responses, keyFunc) => {
|
||||
let ukeys = new Set();
|
||||
responses.forEach(res_obj =>
|
||||
ukeys.add(keyFunc(res_obj)));
|
||||
return Array.from(ukeys);
|
||||
};
|
||||
const getLLMsInResponses = (responses) => getUniqueKeysInResponses(responses, (resp_obj) => resp_obj.llm);
|
||||
|
||||
const InspectorNode = ({ data, id }) => {
|
||||
|
||||
@ -61,17 +68,14 @@ const InspectorNode = ({ data, id }) => {
|
||||
const selected_vars = multiSelectValue;
|
||||
|
||||
// Find all LLMs in responses and store as array
|
||||
let found_llms = new Set();
|
||||
responses.forEach(res_obj =>
|
||||
found_llms.add(res_obj.llm));
|
||||
found_llms = Array.from(found_llms);
|
||||
let found_llms = getLLMsInResponses(responses);
|
||||
|
||||
// Assign a color to each LLM in responses
|
||||
const llm_colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1'];
|
||||
const llm_badge_colors = ['green', 'orange', 'red', 'yellow', 'cyan', 'indigo', 'grape'];
|
||||
const color_for_llm = (llm) => llm_colors[found_llms.indexOf(llm) % llm_colors.length];
|
||||
const badge_color_for_llm = (llm) => llm_badge_colors[found_llms.indexOf(llm) % llm_badge_colors.length];
|
||||
const response_box_colors = ['#ddd', '#eee', '#ddd', '#eee'];
|
||||
const response_box_colors = ['#eee', '#fff', '#eee', '#ddd', '#eee', '#ddd', '#eee'];
|
||||
const rgroup_color = (depth) => response_box_colors[depth % response_box_colors.length];
|
||||
|
||||
const getHeaderBadge = (key, val) => {
|
||||
|
23
chain-forge/src/PlotLegend.js
Normal file
23
chain-forge/src/PlotLegend.js
Normal file
@ -0,0 +1,23 @@
|
||||
import React from 'react';
|
||||
|
||||
const truncStr = (s, maxLen) => {
|
||||
if (s.length > maxLen) // Cut the name short if it's long
|
||||
return s.substring(0, maxLen) + '...'
|
||||
else
|
||||
return s;
|
||||
};
|
||||
|
||||
const PlotLegend = ({ labels }) => {
|
||||
return (
|
||||
<div className="plot-legend">
|
||||
{Object.entries(labels).map(([label, color]) => (
|
||||
<div key={label}>
|
||||
<span style={{ backgroundColor: color, width: '10px', height: '10px', display: 'inline-block' }}></span>
|
||||
<span style={{ marginLeft: '5px' }}>{truncStr(label, 56)}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default PlotLegend;
|
@ -48,8 +48,8 @@ const ScriptNode = ({ data, id }) => {
|
||||
setScriptFiles(f.map((i) => {
|
||||
const val = data.scriptFiles ? data.scriptFiles[i] : '';
|
||||
return (
|
||||
<div className="input-field" key={i}>
|
||||
<input className="script-node-input" type="text" id={i} onChange={handleInputChange} value={val}/>
|
||||
<div className="input-field nodrag" key={i}>
|
||||
<input className="script-node-input" type="text" id={i} onChange={handleInputChange} value={val}></input>
|
||||
<button className="remove-text-field-btn nodrag" id={delButtonId + i} onClick={handleDelete}>X</button>
|
||||
<br/>
|
||||
</div>
|
||||
|
@ -5,6 +5,22 @@ import useStore from './store';
|
||||
import NodeLabel from './NodeLabelComponent'
|
||||
import TemplateHooks from './TemplateHooksComponent';
|
||||
|
||||
/**
|
||||
* The way React handles text areas is annoying: it resets the cursor position upon every edit
|
||||
* (See https://stackoverflow.com/questions/46000544/react-controlled-input-cursor-jumps).
|
||||
* We can try to fix this (see commented out code below), but if we do, we
|
||||
* still run into race conditions around rendering. The simplest solution that
|
||||
* already works is to not use "value", but rather store the value of a textarea within the <></>.
|
||||
* This, however, spits out an error (even though it works just fine). We surpress this error
|
||||
* with the following:
|
||||
*
|
||||
* TODO: Make this more proper in the future!
|
||||
*/
|
||||
const originalWarn = console.error.bind(console.error);
|
||||
console.error = (msg) =>
|
||||
!msg.toString().includes('Use the `defaultValue` or `value` props instead of setting children on <textarea>') && originalWarn(msg);
|
||||
|
||||
// Helper funcs
|
||||
const union = (setA, setB) => {
|
||||
const _union = new Set(setA);
|
||||
for (const elem of setB) {
|
||||
@ -34,22 +50,40 @@ const TextFieldsNode = ({ data, id }) => {
|
||||
if (data.fields) {
|
||||
return 'f' + (1 + Object.keys(data.fields).reduce((acc, key) => (
|
||||
Math.max(acc, parseInt(key.slice(1)))
|
||||
), 0).toString());
|
||||
), 0)).toString();
|
||||
} else {
|
||||
return 'f0';
|
||||
}
|
||||
}, [data.fields]);
|
||||
|
||||
// const [cursor, setCursor] = useState(null);
|
||||
// const handleFocusField = useCallback((target) => {
|
||||
// if (!cursor || !target) return;
|
||||
|
||||
// const [last_focused_id, last_cursor_pos] = cursor;
|
||||
// console.log(last_focused_id, target.id);
|
||||
// if (target.id === last_focused_id) {
|
||||
// target.setSelectionRange(last_cursor_pos, last_cursor_pos);
|
||||
// setCursor(null);
|
||||
// console.log('reset cursor pos');
|
||||
// }
|
||||
// }, [cursor]);
|
||||
|
||||
// Handle a change in a text fields' input.
|
||||
const handleInputChange = useCallback((event) => {
|
||||
|
||||
// Update the data for this text fields' id.
|
||||
let new_data = { 'fields': {...data.fields} };
|
||||
new_data.fields[event.target.id] = event.target.value;
|
||||
|
||||
// Save the cursor pos, since React won't keep track of this
|
||||
// setCursor([event.target.id, event.target.selectionStart]);
|
||||
|
||||
// TODO: Optimize this check.
|
||||
let all_found_vars = new Set();
|
||||
const braces_regex = /(?<!\\){(.*?)(?<!\\)}/g; // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this}
|
||||
Object.keys(new_data['fields']).forEach((fieldId) => {
|
||||
const new_field_ids = Object.keys(new_data.fields);
|
||||
new_field_ids.forEach((fieldId) => {
|
||||
let found_vars = new_data['fields'][fieldId].match(braces_regex);
|
||||
if (found_vars && found_vars.length > 0) {
|
||||
found_vars = found_vars.map(name => name.substring(1, name.length-1)); // remove brackets {}
|
||||
@ -67,7 +101,7 @@ const TextFieldsNode = ({ data, id }) => {
|
||||
}
|
||||
|
||||
setDataPropsForNode(id, new_data);
|
||||
}, [data, id, setDataPropsForNode, templateVars]);
|
||||
}, [data.fields, id, templateVars]);
|
||||
|
||||
// Handle delete text field.
|
||||
const handleDelete = useCallback((event) => {
|
||||
@ -83,26 +117,11 @@ const TextFieldsNode = ({ data, id }) => {
|
||||
}, [data, id, setDataPropsForNode]);
|
||||
|
||||
// Initialize fields (run once at init)
|
||||
const [fields, setFields] = useState([]);
|
||||
useEffect(() => {
|
||||
if (!data.fields)
|
||||
setDataPropsForNode(id, { fields: {[getUID()]: ''}} );
|
||||
}, []);
|
||||
|
||||
// Whenever 'data' changes, update the input fields to reflect the current state.
|
||||
useEffect(() => {
|
||||
const f = data.fields ? Object.keys(data.fields) : [];
|
||||
const num_fields = f.length;
|
||||
setFields(f.map((i) => {
|
||||
const val = data.fields ? data.fields[i] : '';
|
||||
return (
|
||||
<div className="input-field" key={i}>
|
||||
<textarea id={i} name={i} className="text-field-fixed nodrag" rows="2" cols="40" value={val} onChange={handleInputChange} />
|
||||
{num_fields > 1 ? (<button id={delButtonId + i} className="remove-text-field-btn nodrag" onClick={handleDelete}>X</button>) : <></>}
|
||||
</div>
|
||||
)}));
|
||||
}, [data.fields, handleInputChange, handleDelete]);
|
||||
|
||||
// Add a field
|
||||
const handleAddField = useCallback(() => {
|
||||
// Update the data for this text fields' id.
|
||||
@ -111,13 +130,26 @@ const TextFieldsNode = ({ data, id }) => {
|
||||
setDataPropsForNode(id, new_data);
|
||||
}, [data, id, setDataPropsForNode]);
|
||||
|
||||
const [textFields, setTextFields] = useState([]);
|
||||
|
||||
// Dynamically update the y-position of the template hook <Handle>s
|
||||
const ref = useRef(null);
|
||||
const [hooksY, setHooksY] = useState(120);
|
||||
useEffect(() => {
|
||||
const node_height = ref.current.clientHeight;
|
||||
setHooksY(node_height + 75);
|
||||
}, [fields]);
|
||||
|
||||
if (data.fields) {
|
||||
setTextFields(
|
||||
Object.keys(data.fields).map(i => (
|
||||
<div className="input-field" key={i}>
|
||||
<textarea id={i} name={i} className="text-field-fixed nodrag" rows="2" cols="40" onChange={handleInputChange}>{data.fields[i]}</textarea>
|
||||
{Object.keys(data.fields).length > 1 ? (<button id={delButtonId + i} className="remove-text-field-btn nodrag" onClick={handleDelete}>X</button>) : <></>}
|
||||
</div>
|
||||
)));
|
||||
}
|
||||
|
||||
}, [data.fields, handleInputChange]);
|
||||
|
||||
const setRef = useCallback((elem) => {
|
||||
// To listen for resize events of the textarea, we need to use a ResizeObserver.
|
||||
@ -140,7 +172,7 @@ const TextFieldsNode = ({ data, id }) => {
|
||||
<div className="text-fields-node cfnode">
|
||||
<NodeLabel title={data.title || 'TextFields Node'} nodeId={id} icon={<IconTextPlus size="16px" />} />
|
||||
<div ref={setRef}>
|
||||
{fields}
|
||||
{textFields}
|
||||
</div>
|
||||
<Handle
|
||||
type="source"
|
||||
|
@ -4,6 +4,7 @@ import { MultiSelect } from '@mantine/core';
|
||||
import useStore from './store';
|
||||
import Plot from 'react-plotly.js';
|
||||
import NodeLabel from './NodeLabelComponent';
|
||||
import PlotLegend from './PlotLegend';
|
||||
import {BASE_URL} from './store';
|
||||
|
||||
// Helper funcs
|
||||
@ -35,6 +36,15 @@ const createHoverTexts = (responses) => {
|
||||
return [splitAndAddBreaks(truncStr(s, max_len), 60)];
|
||||
}).flat();
|
||||
}
|
||||
const getUniqueKeysInResponses = (responses, keyFunc) => {
|
||||
let ukeys = new Set();
|
||||
responses.forEach(res_obj =>
|
||||
ukeys.add(keyFunc(res_obj)));
|
||||
return Array.from(ukeys);
|
||||
};
|
||||
const extractEvalResultsForMetric = (metric, responses) => {
|
||||
return responses.map(resp_obj => resp_obj.eval_res.items.map(item => item[metric])).flat();
|
||||
};
|
||||
|
||||
const VisNode = ({ data, id }) => {
|
||||
|
||||
@ -43,6 +53,8 @@ const VisNode = ({ data, id }) => {
|
||||
const [pastInputs, setPastInputs] = useState([]);
|
||||
const [responses, setResponses] = useState([]);
|
||||
|
||||
const [plotLegend, setPlotLegend] = useState(null);
|
||||
|
||||
// The MultiSelect so people can dynamically set what vars they care about
|
||||
const [multiSelectVars, setMultiSelectVars] = useState(data.vars || []);
|
||||
const [multiSelectValue, setMultiSelectValue] = useState(data.selected_vars || []);
|
||||
@ -65,14 +77,15 @@ const VisNode = ({ data, id }) => {
|
||||
// (This is assumed to be consistent across response batches)
|
||||
const typeof_eval_res = 'dtype' in responses[0].eval_res ? responses[0].eval_res['dtype'] : 'Numeric';
|
||||
|
||||
let metric_ax_label = null;
|
||||
let plot_legend = null;
|
||||
let metric_axes_labels = [];
|
||||
let num_metrics = 1;
|
||||
if (typeof_eval_res.includes('KeyValue')) {
|
||||
// Check if it's a single-item dict; in which case we can extract the key to name the axis:
|
||||
const keys = Object.keys(responses[0].eval_res.items[0]);
|
||||
if (keys.length === 1)
|
||||
metric_ax_label = keys[0];
|
||||
else
|
||||
throw Error('Dict metrics with more than one key are currently unsupported.')
|
||||
metric_axes_labels = Object.keys(responses[0].eval_res.items[0]);
|
||||
num_metrics = metric_axes_labels.length;
|
||||
|
||||
// if (metric_axes_labels.length > 1)
|
||||
// throw Error('Dict metrics with more than one key are currently unsupported.')
|
||||
// TODO: When multiple metrics are present, and 1 var is selected (can be multiple LLMs as well),
|
||||
// default to Parallel Coordinates plot, with the 1 var values on the y-axis as colored groups, and metrics on x-axis.
|
||||
// For multiple LLMs, add a control drop-down selector to switch the LLM visualized in the plot.
|
||||
@ -81,13 +94,13 @@ const VisNode = ({ data, id }) => {
|
||||
|
||||
const get_items = (eval_res_obj) => {
|
||||
if (typeof_eval_res.includes('KeyValue'))
|
||||
return eval_res_obj.items.map(item => item[metric_ax_label]);
|
||||
return eval_res_obj.items.map(item => item[metric_axes_labels[0]]);
|
||||
return eval_res_obj.items;
|
||||
};
|
||||
|
||||
// Create Plotly spec here
|
||||
const varnames = multiSelectValue;
|
||||
const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
|
||||
const colors = ['#baf078', '#f1b963', '#e46161', '#8888f9', '#33bef0', '#defcf9', '#cadefc', '#f8f398'];
|
||||
let spec = [];
|
||||
let layout = {
|
||||
width: 420, height: 300, title: '', margin: {
|
||||
@ -129,68 +142,123 @@ const VisNode = ({ data, id }) => {
|
||||
});
|
||||
layout.boxmode = 'group';
|
||||
|
||||
if (metric_ax_label)
|
||||
if (metric_axes_labels.length > 0)
|
||||
layout.xaxis = {
|
||||
title: { font: {size: 12}, text: metric_ax_label },
|
||||
title: { font: {size: 12}, text: metric_axes_labels[0] },
|
||||
};
|
||||
};
|
||||
|
||||
if (varnames.length === 0) {
|
||||
// No variables means they used a single prompt (no template) to generate responses
|
||||
// (Users are likely evaluating differences in responses between LLMs)
|
||||
plot_grouped_boxplot((r) => truncStr(r.prompt.trim(), 12));
|
||||
}
|
||||
else if (varnames.length === 1) {
|
||||
// 1 var; numeric eval
|
||||
if (llm_names.length === 1) {
|
||||
// Simple box plot, as there is only a single LLM in the response
|
||||
// Get all possible values of the single variable response ('name' vals)
|
||||
const names = new Set(responses.map(r => r.vars[varnames[0]].trim()));
|
||||
for (const name of names) {
|
||||
let x_items = [];
|
||||
let text_items = [];
|
||||
responses.forEach(r => {
|
||||
if (r.vars[varnames[0]].trim() !== name) return;
|
||||
x_items = x_items.concat(get_items(r.eval_res));
|
||||
text_items = text_items.concat(createHoverTexts(r.responses));
|
||||
});
|
||||
spec.push(
|
||||
{type: 'box', x: x_items, name: truncStr(name, 12), boxpoints: 'all', text: text_items, hovertemplate: '%{text}', orientation: 'h'}
|
||||
);
|
||||
if (num_metrics > 1) {
|
||||
// For 2 or more metrics, display a parallel coordinates plot.
|
||||
// :: For instance, if evaluator produces { height: 32, weight: 120 } plot responses with 2 metrics, 'height' and 'weight'
|
||||
if (varnames.length === 1) {
|
||||
console.log("Plotting parallel coordinates...");
|
||||
let unique_vals = getUniqueKeysInResponses(responses, (resp_obj) => resp_obj.vars[varnames[0]]);
|
||||
let group_colors = colors;
|
||||
|
||||
let colorscale = [];
|
||||
for (let i = 0; i < unique_vals.length; i++) {
|
||||
colorscale.push([i / (unique_vals.length-1), group_colors[i % group_colors.length]]);
|
||||
}
|
||||
layout.hovermode = 'closest';
|
||||
|
||||
if (metric_ax_label)
|
||||
layout.xaxis = {
|
||||
title: { font: {size: 12}, text: metric_ax_label },
|
||||
};
|
||||
let dimensions = [];
|
||||
metric_axes_labels.forEach(metric => {
|
||||
const evals = extractEvalResultsForMetric(metric, responses);
|
||||
dimensions.push({
|
||||
range: [Math.min(...evals), Math.max(...evals)],
|
||||
label: metric,
|
||||
values: evals,
|
||||
});
|
||||
});
|
||||
|
||||
spec.push({
|
||||
type: 'parcoords',
|
||||
pad: [10, 10, 10, 10],
|
||||
line: {
|
||||
color: responses.map(resp_obj => {
|
||||
const idx = unique_vals.indexOf(resp_obj.vars[varnames[0]]);
|
||||
return Array(resp_obj.eval_res.items.length).fill(idx);
|
||||
}).flat(),
|
||||
colorscale: colorscale,
|
||||
},
|
||||
dimensions: dimensions,
|
||||
});
|
||||
layout.margin = { l: 40, r: 40, b: 40, t: 50, pad: 0 };
|
||||
layout.paper_bgcolor = "white";
|
||||
layout.font = {color: "black"};
|
||||
|
||||
// There's no built-in legend for parallel coords, unfortunately, so we need to construct our own:
|
||||
let legend_labels = {};
|
||||
unique_vals.forEach((v, idx) =>
|
||||
{legend_labels[v] = group_colors[idx];}
|
||||
);
|
||||
plot_legend = (<PlotLegend labels={legend_labels} />);
|
||||
|
||||
console.log(spec);
|
||||
} else {
|
||||
// There are multiple LLMs in the response; do a grouped box plot by LLM.
|
||||
// Note that 'name' is now the LLM, and 'x' stores the value of the var:
|
||||
plot_grouped_boxplot((r) => r.vars[varnames[0]].trim());
|
||||
console.error("Plotting evaluations with more than one metric and more than one prompt parameter is currently unsupported.");
|
||||
}
|
||||
}
|
||||
else if (varnames.length === 2) {
|
||||
// Input is 2 vars; numeric eval
|
||||
// Display a 3D scatterplot with 2 dimensions:
|
||||
spec = {
|
||||
type: 'scatter3d',
|
||||
x: responses.map(r => r.vars[varnames[0]]).map(s => truncStr(s, 12)),
|
||||
y: responses.map(r => r.vars[varnames[1]]).map(s => truncStr(s, 12)),
|
||||
z: responses.map(r => get_items(r.eval_res).reduce((acc, val) => (acc + val), 0) / r.eval_res.items.length), // calculates mean
|
||||
mode: 'markers',
|
||||
else { // A single metric --use plots like grouped box-and-whiskers, 3d scatterplot
|
||||
if (varnames.length === 0) {
|
||||
// No variables means they used a single prompt (no template) to generate responses
|
||||
// (Users are likely evaluating differences in responses between LLMs)
|
||||
plot_grouped_boxplot((r) => truncStr(r.prompt.trim(), 12));
|
||||
}
|
||||
else if (varnames.length === 1) {
|
||||
// 1 var; numeric eval
|
||||
if (llm_names.length === 1) {
|
||||
// Simple box plot, as there is only a single LLM in the response
|
||||
// Get all possible values of the single variable response ('name' vals)
|
||||
const names = new Set(responses.map(r => r.vars[varnames[0]].trim()));
|
||||
for (const name of names) {
|
||||
let x_items = [];
|
||||
let text_items = [];
|
||||
responses.forEach(r => {
|
||||
if (r.vars[varnames[0]].trim() !== name) return;
|
||||
x_items = x_items.concat(get_items(r.eval_res));
|
||||
text_items = text_items.concat(createHoverTexts(r.responses));
|
||||
});
|
||||
spec.push(
|
||||
{type: 'box', x: x_items, name: truncStr(name, 12), boxpoints: 'all', text: text_items, hovertemplate: '%{text}', orientation: 'h'}
|
||||
);
|
||||
}
|
||||
layout.hovermode = 'closest';
|
||||
|
||||
if (metric_axes_labels.length > 0)
|
||||
layout.xaxis = {
|
||||
title: { font: {size: 12}, text: metric_axes_labels[0] },
|
||||
};
|
||||
} else {
|
||||
// There are multiple LLMs in the response; do a grouped box plot by LLM.
|
||||
// Note that 'name' is now the LLM, and 'x' stores the value of the var:
|
||||
plot_grouped_boxplot((r) => r.vars[varnames[0]].trim());
|
||||
}
|
||||
}
|
||||
else if (varnames.length === 2) {
|
||||
// Input is 2 vars; numeric eval
|
||||
// Display a 3D scatterplot with 2 dimensions:
|
||||
spec = {
|
||||
type: 'scatter3d',
|
||||
x: responses.map(r => r.vars[varnames[0]]).map(s => truncStr(s, 12)),
|
||||
y: responses.map(r => r.vars[varnames[1]]).map(s => truncStr(s, 12)),
|
||||
z: responses.map(r => get_items(r.eval_res).reduce((acc, val) => (acc + val), 0) / r.eval_res.items.length), // calculates mean
|
||||
mode: 'markers',
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!Array.isArray(spec))
|
||||
spec = [spec];
|
||||
|
||||
setPlotLegend(plot_legend);
|
||||
setPlotlyObj((
|
||||
<Plot
|
||||
data={spec}
|
||||
layout={layout}
|
||||
/>
|
||||
))
|
||||
));
|
||||
|
||||
}, [multiSelectVars, multiSelectValue, responses]);
|
||||
|
||||
const handleOnConnect = useCallback(() => {
|
||||
@ -258,7 +326,10 @@ const VisNode = ({ data, id }) => {
|
||||
size="sm"
|
||||
value={multiSelectValue}
|
||||
searchable />
|
||||
<div className="nodrag">{plotlyObj}</div>
|
||||
<div className="nodrag">
|
||||
{plotlyObj}
|
||||
{plotLegend ? plotLegend : <></>}
|
||||
</div>
|
||||
<Handle
|
||||
type="target"
|
||||
position="left"
|
||||
|
@ -157,6 +157,10 @@
|
||||
border: 1px solid #999;
|
||||
border-radius: 5px;
|
||||
}
|
||||
.plot-legend {
|
||||
font-size: 11px;
|
||||
font-family: monospace;
|
||||
}
|
||||
|
||||
.inspector-node {
|
||||
background-color: #fff;
|
||||
|
@ -31,7 +31,7 @@ class ResponseInfo:
|
||||
"""Stores info about a single response. Passed to evaluator functions."""
|
||||
text: str
|
||||
prompt: str
|
||||
var: str
|
||||
var: list
|
||||
llm: str
|
||||
|
||||
def __str__(self):
|
||||
|
@ -10,8 +10,8 @@ from promptengine.template import PromptTemplate, PromptPermutationGenerator
|
||||
# A 'cheap' version of controlling for rate limits is to wait a few seconds between batches of requests being sent off.
|
||||
# The following is only a guideline, and a bit on the conservative side.
|
||||
MAX_SIMULTANEOUS_REQUESTS = {
|
||||
LLM.ChatGPT: (50, 10), # max 50 requests a batch; wait 10 seconds between
|
||||
LLM.GPT4: (20, 10), # max 10 requests a batch; wait 10 seconds between
|
||||
LLM.ChatGPT: (30, 10), # max 30 requests a batch; wait 10 seconds between
|
||||
LLM.GPT4: (5, 10), # max 5 requests a batch; wait 10 seconds between
|
||||
LLM.Alpaca7B: (1, 0), # 1 indicates synchronous
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user