Trying CF on actual prompt engineering task.

This commit is contained in:
Ian Arawjo 2023-05-16 22:05:04 -04:00
parent 126eacb9c8
commit df1e89662d
11 changed files with 256 additions and 82 deletions

View File

@@ -14,6 +14,7 @@
"@reactflow/background": "^11.2.0",
"@reactflow/controls": "^11.1.11",
"@reactflow/core": "^11.7.0",
"@reactflow/node-resizer": "^2.1.0",
"@tabler/icons-react": "^2.17.0",
"@testing-library/jest-dom": "^5.16.5",
"@testing-library/react": "^13.4.0",
@@ -3814,6 +3815,22 @@
"react-dom": ">=17"
}
},
"node_modules/@reactflow/node-resizer": {
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/@reactflow/node-resizer/-/node-resizer-2.1.0.tgz",
"integrity": "sha512-DVL8nnWsltP8/iANadAcTaDB4wsEkx2mOLlBEPNE3yc5loSm3u9l5m4enXRcBym61MiMuTtDPzZMyYYQUjuYIg==",
"dependencies": {
"@reactflow/core": "^11.6.0",
"classcat": "^5.0.4",
"d3-drag": "^3.0.0",
"d3-selection": "^3.0.0",
"zustand": "^4.3.1"
},
"peerDependencies": {
"react": ">=17",
"react-dom": ">=17"
}
},
"node_modules/@rollup/plugin-babel": {
"version": "5.3.1",
"resolved": "https://registry.npmjs.org/@rollup/plugin-babel/-/plugin-babel-5.3.1.tgz",

View File

@@ -9,6 +9,7 @@
"@reactflow/background": "^11.2.0",
"@reactflow/controls": "^11.1.11",
"@reactflow/core": "^11.7.0",
"@reactflow/node-resizer": "^2.1.0",
"@tabler/icons-react": "^2.17.0",
"@testing-library/jest-dom": "^5.16.5",
"@testing-library/react": "^13.4.0",

View File

@@ -0,0 +1,22 @@
import React, { useEffect, useRef, useState } from 'react';
/* Modified from https://stackoverflow.com/a/68928267 */
const ControlledTextArea = (props) => {
const { value, onChange, ...rest } = props;
const [cursor, setCursor] = useState(null);
const ref = useRef(null);
useEffect(() => {
const input = ref.current;
if (input) input.setSelectionRange(cursor, cursor);
}, [ref, cursor, value]);
const handleChange = (e) => {
setCursor(e.target.selectionStart);
onChange && onChange(e);
};
return <textarea ref={ref} value={value} onChange={handleChange} {...rest} />;
};
export default ControlledTextArea;
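For illustration, a minimal usage sketch of ControlledTextArea, assuming it replaces a raw <textarea> in a node component; ExampleField, fieldId, and onFieldChange are hypothetical names, not part of this commit:

import React from 'react';
import ControlledTextArea from './ControlledTextArea';

// Hypothetical wrapper: the caret position survives re-renders of the controlled value.
const ExampleField = ({ fieldId, value, onFieldChange }) => (
  <ControlledTextArea
    id={fieldId}
    className="text-field-fixed nodrag"
    rows="2"
    cols="40"
    value={value}
    onChange={(e) => onFieldChange(fieldId, e.target.value)} />
);

export default ExampleField;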

View File

@@ -39,6 +39,13 @@ const groupResponsesBy = (responses, keyFunc) => {
});
return [responses_by_key, unspecified_group];
};
const getUniqueKeysInResponses = (responses, keyFunc) => {
let ukeys = new Set();
responses.forEach(res_obj =>
ukeys.add(keyFunc(res_obj)));
return Array.from(ukeys);
};
const getLLMsInResponses = (responses) => getUniqueKeysInResponses(responses, (resp_obj) => resp_obj.llm);
const InspectorNode = ({ data, id }) => {
@@ -61,17 +68,14 @@ const InspectorNode = ({ data, id }) => {
const selected_vars = multiSelectValue;
// Find all LLMs in responses and store as array
let found_llms = new Set();
responses.forEach(res_obj =>
found_llms.add(res_obj.llm));
found_llms = Array.from(found_llms);
let found_llms = getLLMsInResponses(responses);
// Assign a color to each LLM in responses
const llm_colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1'];
const llm_badge_colors = ['green', 'orange', 'red', 'yellow', 'cyan', 'indigo', 'grape'];
const color_for_llm = (llm) => llm_colors[found_llms.indexOf(llm) % llm_colors.length];
const badge_color_for_llm = (llm) => llm_badge_colors[found_llms.indexOf(llm) % llm_badge_colors.length];
const response_box_colors = ['#ddd', '#eee', '#ddd', '#eee'];
const response_box_colors = ['#eee', '#fff', '#eee', '#ddd', '#eee', '#ddd', '#eee'];
const rgroup_color = (depth) => response_box_colors[depth % response_box_colors.length];
const getHeaderBadge = (key, val) => {

View File

@@ -0,0 +1,23 @@
import React from 'react';
const truncStr = (s, maxLen) => {
if (s.length > maxLen) // Cut the name short if it's long
return s.substring(0, maxLen) + '...'
else
return s;
};
const PlotLegend = ({ labels }) => {
return (
<div className="plot-legend">
{Object.entries(labels).map(([label, color]) => (
<div key={label}>
<span style={{ backgroundColor: color, width: '10px', height: '10px', display: 'inline-block' }}></span>
<span style={{ marginLeft: '5px' }}>{truncStr(label, 56)}</span>
</div>
))}
</div>
);
};
export default PlotLegend;
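As a hedged usage sketch (the labels and colors below are invented, not taken from this commit), PlotLegend expects a labels prop that maps each label string to its swatch color:

// Hypothetical example: map each group value to the color used for it in the plot.
const legendLabels = {
  'sci-fi story': '#baf078',
  'mystery story': '#f1b963',
};
// ...then render inside a component's JSX:
// <PlotLegend labels={legendLabels} />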

View File

@@ -48,8 +48,8 @@ const ScriptNode = ({ data, id }) => {
setScriptFiles(f.map((i) => {
const val = data.scriptFiles ? data.scriptFiles[i] : '';
return (
<div className="input-field" key={i}>
<input className="script-node-input" type="text" id={i} onChange={handleInputChange} value={val}/>
<div className="input-field nodrag" key={i}>
<input className="script-node-input" type="text" id={i} onChange={handleInputChange} value={val}></input>
<button className="remove-text-field-btn nodrag" id={delButtonId + i} onClick={handleDelete}>X</button>
<br/>
</div>

View File

@@ -5,6 +5,22 @@ import useStore from './store';
import NodeLabel from './NodeLabelComponent'
import TemplateHooks from './TemplateHooksComponent';
/**
* The way React handles text areas is annoying: it resets the cursor position upon every edit
* (See https://stackoverflow.com/questions/46000544/react-controlled-input-cursor-jumps).
* We can try to fix this (see commented out code below), but if we do, we
* still run into race conditions around rendering. The simplest solution that
* already works is to not use "value", but rather store the text as the <textarea>'s children.
* This, however, spits out a console error (even though it works just fine). We suppress this error
* with the following:
*
* TODO: Make this more proper in the future!
*/
const originalWarn = console.error.bind(console.error);
console.error = (msg) =>
!msg.toString().includes('Use the `defaultValue` or `value` props instead of setting children on <textarea>') && originalWarn(msg);
// Helper funcs
const union = (setA, setB) => {
const _union = new Set(setA);
for (const elem of setB) {
@@ -34,22 +50,40 @@ const TextFieldsNode = ({ data, id }) => {
if (data.fields) {
return 'f' + (1 + Object.keys(data.fields).reduce((acc, key) => (
Math.max(acc, parseInt(key.slice(1)))
), 0).toString());
), 0)).toString();
} else {
return 'f0';
}
}, [data.fields]);
// const [cursor, setCursor] = useState(null);
// const handleFocusField = useCallback((target) => {
// if (!cursor || !target) return;
// const [last_focused_id, last_cursor_pos] = cursor;
// console.log(last_focused_id, target.id);
// if (target.id === last_focused_id) {
// target.setSelectionRange(last_cursor_pos, last_cursor_pos);
// setCursor(null);
// console.log('reset cursor pos');
// }
// }, [cursor]);
// Handle a change in a text field's input.
const handleInputChange = useCallback((event) => {
// Update the data for this text fields' id.
let new_data = { 'fields': {...data.fields} };
new_data.fields[event.target.id] = event.target.value;
// Save the cursor pos, since React won't keep track of this
// setCursor([event.target.id, event.target.selectionStart]);
// TODO: Optimize this check.
let all_found_vars = new Set();
const braces_regex = /(?<!\\){(.*?)(?<!\\)}/g; // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this}
Object.keys(new_data['fields']).forEach((fieldId) => {
const new_field_ids = Object.keys(new_data.fields);
new_field_ids.forEach((fieldId) => {
let found_vars = new_data['fields'][fieldId].match(braces_regex);
if (found_vars && found_vars.length > 0) {
found_vars = found_vars.map(name => name.substring(1, name.length-1)); // remove brackets {}
@@ -67,7 +101,7 @@ const TextFieldsNode = ({ data, id }) => {
}
setDataPropsForNode(id, new_data);
}, [data, id, setDataPropsForNode, templateVars]);
}, [data.fields, id, templateVars]);
// Handle delete text field.
const handleDelete = useCallback((event) => {
@@ -83,26 +117,11 @@ const TextFieldsNode = ({ data, id }) => {
}, [data, id, setDataPropsForNode]);
// Initialize fields (run once at init)
const [fields, setFields] = useState([]);
useEffect(() => {
if (!data.fields)
setDataPropsForNode(id, { fields: {[getUID()]: ''}} );
}, []);
// Whenever 'data' changes, update the input fields to reflect the current state.
useEffect(() => {
const f = data.fields ? Object.keys(data.fields) : [];
const num_fields = f.length;
setFields(f.map((i) => {
const val = data.fields ? data.fields[i] : '';
return (
<div className="input-field" key={i}>
<textarea id={i} name={i} className="text-field-fixed nodrag" rows="2" cols="40" value={val} onChange={handleInputChange} />
{num_fields > 1 ? (<button id={delButtonId + i} className="remove-text-field-btn nodrag" onClick={handleDelete}>X</button>) : <></>}
</div>
)}));
}, [data.fields, handleInputChange, handleDelete]);
// Add a field
const handleAddField = useCallback(() => {
// Update the data for this text fields' id.
@@ -111,13 +130,26 @@ const TextFieldsNode = ({ data, id }) => {
setDataPropsForNode(id, new_data);
}, [data, id, setDataPropsForNode]);
const [textFields, setTextFields] = useState([]);
// Dynamically update the y-position of the template hook <Handle>s
const ref = useRef(null);
const [hooksY, setHooksY] = useState(120);
useEffect(() => {
const node_height = ref.current.clientHeight;
setHooksY(node_height + 75);
}, [fields]);
if (data.fields) {
setTextFields(
Object.keys(data.fields).map(i => (
<div className="input-field" key={i}>
<textarea id={i} name={i} className="text-field-fixed nodrag" rows="2" cols="40" onChange={handleInputChange}>{data.fields[i]}</textarea>
{Object.keys(data.fields).length > 1 ? (<button id={delButtonId + i} className="remove-text-field-btn nodrag" onClick={handleDelete}>X</button>) : <></>}
</div>
)));
}
}, [data.fields, handleInputChange]);
const setRef = useCallback((elem) => {
// To listen for resize events of the textarea, we need to use a ResizeObserver.
@@ -140,7 +172,7 @@ const TextFieldsNode = ({ data, id }) => {
<div className="text-fields-node cfnode">
<NodeLabel title={data.title || 'TextFields Node'} nodeId={id} icon={<IconTextPlus size="16px" />} />
<div ref={setRef}>
{fields}
{textFields}
</div>
<Handle
type="source"
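The body of setRef falls outside this diff's context; as a rough sketch only (assuming the observer watches the wrapping <div> and re-derives hooksY from its height, mirroring the useEffect above), it might look like:

// Illustrative sketch, not the actual implementation in this commit.
const setRef = useCallback((elem) => {
  if (!elem) return;
  const observer = new ResizeObserver(() => {
    // Recompute the y-offset of the template-hook <Handle>s whenever the node resizes.
    setHooksY(elem.clientHeight + 75);
  });
  observer.observe(elem);
}, []);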

View File

@@ -4,6 +4,7 @@ import { MultiSelect } from '@mantine/core';
import useStore from './store';
import Plot from 'react-plotly.js';
import NodeLabel from './NodeLabelComponent';
import PlotLegend from './PlotLegend';
import {BASE_URL} from './store';
// Helper funcs
@@ -35,6 +36,15 @@ const createHoverTexts = (responses) => {
return [splitAndAddBreaks(truncStr(s, max_len), 60)];
}).flat();
}
const getUniqueKeysInResponses = (responses, keyFunc) => {
let ukeys = new Set();
responses.forEach(res_obj =>
ukeys.add(keyFunc(res_obj)));
return Array.from(ukeys);
};
const extractEvalResultsForMetric = (metric, responses) => {
return responses.map(resp_obj => resp_obj.eval_res.items.map(item => item[metric])).flat();
};
const VisNode = ({ data, id }) => {
@@ -43,6 +53,8 @@ const VisNode = ({ data, id }) => {
const [pastInputs, setPastInputs] = useState([]);
const [responses, setResponses] = useState([]);
const [plotLegend, setPlotLegend] = useState(null);
// The MultiSelect so people can dynamically set what vars they care about
const [multiSelectVars, setMultiSelectVars] = useState(data.vars || []);
const [multiSelectValue, setMultiSelectValue] = useState(data.selected_vars || []);
@@ -65,14 +77,15 @@ const VisNode = ({ data, id }) => {
// (This is assumed to be consistent across response batches)
const typeof_eval_res = 'dtype' in responses[0].eval_res ? responses[0].eval_res['dtype'] : 'Numeric';
let metric_ax_label = null;
let plot_legend = null;
let metric_axes_labels = [];
let num_metrics = 1;
if (typeof_eval_res.includes('KeyValue')) {
// Check if it's a single-item dict; in which case we can extract the key to name the axis:
const keys = Object.keys(responses[0].eval_res.items[0]);
if (keys.length === 1)
metric_ax_label = keys[0];
else
throw Error('Dict metrics with more than one key are currently unsupported.')
metric_axes_labels = Object.keys(responses[0].eval_res.items[0]);
num_metrics = metric_axes_labels.length;
// if (metric_axes_labels.length > 1)
// throw Error('Dict metrics with more than one key are currently unsupported.')
// TODO: When multiple metrics are present, and 1 var is selected (can be multiple LLMs as well),
// default to Parallel Coordinates plot, with the 1 var values on the y-axis as colored groups, and metrics on x-axis.
// For multiple LLMs, add a control drop-down selector to switch the LLM visualized in the plot.
@@ -81,13 +94,13 @@ const VisNode = ({ data, id }) => {
const get_items = (eval_res_obj) => {
if (typeof_eval_res.includes('KeyValue'))
return eval_res_obj.items.map(item => item[metric_ax_label]);
return eval_res_obj.items.map(item => item[metric_axes_labels[0]]);
return eval_res_obj.items;
};
// Create Plotly spec here
const varnames = multiSelectValue;
const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
const colors = ['#baf078', '#f1b963', '#e46161', '#8888f9', '#33bef0', '#defcf9', '#cadefc', '#f8f398'];
let spec = [];
let layout = {
width: 420, height: 300, title: '', margin: {
@@ -129,68 +142,123 @@ const VisNode = ({ data, id }) => {
});
layout.boxmode = 'group';
if (metric_ax_label)
if (metric_axes_labels.length > 0)
layout.xaxis = {
title: { font: {size: 12}, text: metric_ax_label },
title: { font: {size: 12}, text: metric_axes_labels[0] },
};
};
if (varnames.length === 0) {
// No variables means they used a single prompt (no template) to generate responses
// (Users are likely evaluating differences in responses between LLMs)
plot_grouped_boxplot((r) => truncStr(r.prompt.trim(), 12));
}
else if (varnames.length === 1) {
// 1 var; numeric eval
if (llm_names.length === 1) {
// Simple box plot, as there is only a single LLM in the response
// Get all possible values of the single variable response ('name' vals)
const names = new Set(responses.map(r => r.vars[varnames[0]].trim()));
for (const name of names) {
let x_items = [];
let text_items = [];
responses.forEach(r => {
if (r.vars[varnames[0]].trim() !== name) return;
x_items = x_items.concat(get_items(r.eval_res));
text_items = text_items.concat(createHoverTexts(r.responses));
});
spec.push(
{type: 'box', x: x_items, name: truncStr(name, 12), boxpoints: 'all', text: text_items, hovertemplate: '%{text}', orientation: 'h'}
);
if (num_metrics > 1) {
// For 2 or more metrics, display a parallel coordinates plot.
// For instance, if an evaluator produces { height: 32, weight: 120 }, this plots responses across the 2 metrics 'height' and 'weight'.
if (varnames.length === 1) {
console.log("Plotting parallel coordinates...");
let unique_vals = getUniqueKeysInResponses(responses, (resp_obj) => resp_obj.vars[varnames[0]]);
let group_colors = colors;
let colorscale = [];
for (let i = 0; i < unique_vals.length; i++) {
colorscale.push([i / (unique_vals.length-1), group_colors[i % group_colors.length]]);
}
layout.hovermode = 'closest';
if (metric_ax_label)
layout.xaxis = {
title: { font: {size: 12}, text: metric_ax_label },
};
let dimensions = [];
metric_axes_labels.forEach(metric => {
const evals = extractEvalResultsForMetric(metric, responses);
dimensions.push({
range: [Math.min(...evals), Math.max(...evals)],
label: metric,
values: evals,
});
});
spec.push({
type: 'parcoords',
pad: [10, 10, 10, 10],
line: {
color: responses.map(resp_obj => {
const idx = unique_vals.indexOf(resp_obj.vars[varnames[0]]);
return Array(resp_obj.eval_res.items.length).fill(idx);
}).flat(),
colorscale: colorscale,
},
dimensions: dimensions,
});
layout.margin = { l: 40, r: 40, b: 40, t: 50, pad: 0 };
layout.paper_bgcolor = "white";
layout.font = {color: "black"};
// There's no built-in legend for parallel coords, unfortunately, so we need to construct our own:
let legend_labels = {};
unique_vals.forEach((v, idx) =>
{legend_labels[v] = group_colors[idx];}
);
plot_legend = (<PlotLegend labels={legend_labels} />);
console.log(spec);
} else {
// There are multiple LLMs in the response; do a grouped box plot by LLM.
// Note that 'name' is now the LLM, and 'x' stores the value of the var:
plot_grouped_boxplot((r) => r.vars[varnames[0]].trim());
console.error("Plotting evaluations with more than one metric and more than one prompt parameter is currently unsupported.");
}
}
else if (varnames.length === 2) {
// Input is 2 vars; numeric eval
// Display a 3D scatterplot with 2 dimensions:
spec = {
type: 'scatter3d',
x: responses.map(r => r.vars[varnames[0]]).map(s => truncStr(s, 12)),
y: responses.map(r => r.vars[varnames[1]]).map(s => truncStr(s, 12)),
z: responses.map(r => get_items(r.eval_res).reduce((acc, val) => (acc + val), 0) / r.eval_res.items.length), // calculates mean
mode: 'markers',
else { // A single metric: use plots like grouped box-and-whiskers or a 3D scatterplot
if (varnames.length === 0) {
// No variables means they used a single prompt (no template) to generate responses
// (Users are likely evaluating differences in responses between LLMs)
plot_grouped_boxplot((r) => truncStr(r.prompt.trim(), 12));
}
else if (varnames.length === 1) {
// 1 var; numeric eval
if (llm_names.length === 1) {
// Simple box plot, as there is only a single LLM in the response
// Get all possible values of the single variable response ('name' vals)
const names = new Set(responses.map(r => r.vars[varnames[0]].trim()));
for (const name of names) {
let x_items = [];
let text_items = [];
responses.forEach(r => {
if (r.vars[varnames[0]].trim() !== name) return;
x_items = x_items.concat(get_items(r.eval_res));
text_items = text_items.concat(createHoverTexts(r.responses));
});
spec.push(
{type: 'box', x: x_items, name: truncStr(name, 12), boxpoints: 'all', text: text_items, hovertemplate: '%{text}', orientation: 'h'}
);
}
layout.hovermode = 'closest';
if (metric_axes_labels.length > 0)
layout.xaxis = {
title: { font: {size: 12}, text: metric_axes_labels[0] },
};
} else {
// There are multiple LLMs in the response; do a grouped box plot by LLM.
// Note that 'name' is now the LLM, and 'x' stores the value of the var:
plot_grouped_boxplot((r) => r.vars[varnames[0]].trim());
}
}
else if (varnames.length === 2) {
// Input is 2 vars; numeric eval
// Display a 3D scatterplot with 2 dimensions:
spec = {
type: 'scatter3d',
x: responses.map(r => r.vars[varnames[0]]).map(s => truncStr(s, 12)),
y: responses.map(r => r.vars[varnames[1]]).map(s => truncStr(s, 12)),
z: responses.map(r => get_items(r.eval_res).reduce((acc, val) => (acc + val), 0) / r.eval_res.items.length), // calculates mean
mode: 'markers',
}
}
}
if (!Array.isArray(spec))
spec = [spec];
setPlotLegend(plot_legend);
setPlotlyObj((
<Plot
data={spec}
layout={layout}
/>
))
));
}, [multiSelectVars, multiSelectValue, responses]);
const handleOnConnect = useCallback(() => {
@@ -258,7 +326,10 @@ const VisNode = ({ data, id }) => {
size="sm"
value={multiSelectValue}
searchable />
<div className="nodrag">{plotlyObj}</div>
<div className="nodrag">
{plotlyObj}
{plotLegend ? plotLegend : <></>}
</div>
<Handle
type="target"
position="left"
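For orientation, a sketch of the response object shape this plotting code appears to assume, with field names taken from the code above and every concrete value invented:

// Inferred shape only; do not treat as a spec.
const exampleResponse = {
  prompt: 'Tell me about {topic}.',
  vars: { topic: 'black holes' },          // filled-in template variables
  llm: 'gpt-3.5-turbo',
  responses: ['Black holes are ...'],      // raw LLM outputs (used for hover text)
  eval_res: {
    dtype: 'KeyValue',                     // or a numeric dtype for plain numbers
    items: [{ height: 32, weight: 120 }],  // one entry per response; dicts map metric -> value
  },
};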

View File

@@ -157,6 +157,10 @@
border: 1px solid #999;
border-radius: 5px;
}
.plot-legend {
font-size: 11px;
font-family: monospace;
}
.inspector-node {
background-color: #fff;

View File

@@ -31,7 +31,7 @@ class ResponseInfo:
"""Stores info about a single response. Passed to evaluator functions."""
text: str
prompt: str
var: str
var: list
llm: str
def __str__(self):

View File

@@ -10,8 +10,8 @@ from promptengine.template import PromptTemplate, PromptPermutationGenerator
# A 'cheap' version of controlling for rate limits is to wait a few seconds between batches of requests being sent off.
# The following is only a guideline, and a bit on the conservative side.
MAX_SIMULTANEOUS_REQUESTS = {
LLM.ChatGPT: (50, 10), # max 50 requests a batch; wait 10 seconds between
LLM.GPT4: (20, 10), # max 20 requests a batch; wait 10 seconds between
LLM.ChatGPT: (30, 10), # max 30 requests a batch; wait 10 seconds between
LLM.GPT4: (5, 10), # max 5 requests a batch; wait 10 seconds between
LLM.Alpaca7B: (1, 0), # 1 indicates synchronous
}
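As an illustration of the batch-and-wait pattern described in the comment above (written in JavaScript for consistency with the rest of this commit; the function names are hypothetical):

// Send requests in batches of `batchSize`, pausing `waitSecs` between batches.
const sleep = (secs) => new Promise((resolve) => setTimeout(resolve, secs * 1000));

async function sendInBatches(prompts, sendOne, batchSize, waitSecs) {
  let results = [];
  for (let i = 0; i < prompts.length; i += batchSize) {
    const batch = prompts.slice(i, i + batchSize);
    results = results.concat(await Promise.all(batch.map(sendOne)));
    if (i + batchSize < prompts.length)
      await sleep(waitSecs);  // crude rate limiting between batches
  }
  return results;
}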