diff --git a/chain-forge/package-lock.json b/chain-forge/package-lock.json index 8cc28a7..b362d3a 100644 --- a/chain-forge/package-lock.json +++ b/chain-forge/package-lock.json @@ -14,6 +14,7 @@ "@reactflow/background": "^11.2.0", "@reactflow/controls": "^11.1.11", "@reactflow/core": "^11.7.0", + "@reactflow/node-resizer": "^2.1.0", "@tabler/icons-react": "^2.17.0", "@testing-library/jest-dom": "^5.16.5", "@testing-library/react": "^13.4.0", @@ -3814,6 +3815,22 @@ "react-dom": ">=17" } }, + "node_modules/@reactflow/node-resizer": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/@reactflow/node-resizer/-/node-resizer-2.1.0.tgz", + "integrity": "sha512-DVL8nnWsltP8/iANadAcTaDB4wsEkx2mOLlBEPNE3yc5loSm3u9l5m4enXRcBym61MiMuTtDPzZMyYYQUjuYIg==", + "dependencies": { + "@reactflow/core": "^11.6.0", + "classcat": "^5.0.4", + "d3-drag": "^3.0.0", + "d3-selection": "^3.0.0", + "zustand": "^4.3.1" + }, + "peerDependencies": { + "react": ">=17", + "react-dom": ">=17" + } + }, "node_modules/@rollup/plugin-babel": { "version": "5.3.1", "resolved": "https://registry.npmjs.org/@rollup/plugin-babel/-/plugin-babel-5.3.1.tgz", diff --git a/chain-forge/package.json b/chain-forge/package.json index 907e9f7..dc25c0e 100644 --- a/chain-forge/package.json +++ b/chain-forge/package.json @@ -9,6 +9,7 @@ "@reactflow/background": "^11.2.0", "@reactflow/controls": "^11.1.11", "@reactflow/core": "^11.7.0", + "@reactflow/node-resizer": "^2.1.0", "@tabler/icons-react": "^2.17.0", "@testing-library/jest-dom": "^5.16.5", "@testing-library/react": "^13.4.0", diff --git a/chain-forge/src/ControlledTextArea.js b/chain-forge/src/ControlledTextArea.js new file mode 100644 index 0000000..1d7a0ef --- /dev/null +++ b/chain-forge/src/ControlledTextArea.js @@ -0,0 +1,22 @@ +import React, { useEffect, useRef, useState } from 'react'; + +/* Modified from https://stackoverflow.com/a/68928267 */ +const ControlledTextArea = (props) => { + const { value, onChange, ...rest } = props; + const [cursor, setCursor] = useState(null); + const ref = useRef(null); + + useEffect(() => { + const input = ref.current; + if (input) input.setSelectionRange(cursor, cursor); + }, [ref, cursor, value]); + + const handleChange = (e) => { + setCursor(e.target.selectionStart); + onChange && onChange(e); + }; + + return + {Object.keys(data.fields).length > 1 ? () : <>} + + ))); + } + + }, [data.fields, handleInputChange]); const setRef = useCallback((elem) => { // To listen for resize events of the textarea, we need to use a ResizeObserver. @@ -140,7 +172,7 @@ const TextFieldsNode = ({ data, id }) => {
} />
- {fields} + {textFields}
{ return [splitAndAddBreaks(truncStr(s, max_len), 60)]; }).flat(); } +const getUniqueKeysInResponses = (responses, keyFunc) => { + let ukeys = new Set(); + responses.forEach(res_obj => + ukeys.add(keyFunc(res_obj))); + return Array.from(ukeys); +}; +const extractEvalResultsForMetric = (metric, responses) => { + return responses.map(resp_obj => resp_obj.eval_res.items.map(item => item[metric])).flat(); +}; const VisNode = ({ data, id }) => { @@ -43,6 +53,8 @@ const VisNode = ({ data, id }) => { const [pastInputs, setPastInputs] = useState([]); const [responses, setResponses] = useState([]); + const [plotLegend, setPlotLegend] = useState(null); + // The MultiSelect so people can dynamically set what vars they care about const [multiSelectVars, setMultiSelectVars] = useState(data.vars || []); const [multiSelectValue, setMultiSelectValue] = useState(data.selected_vars || []); @@ -65,14 +77,15 @@ const VisNode = ({ data, id }) => { // (This is assumed to be consistent across response batches) const typeof_eval_res = 'dtype' in responses[0].eval_res ? responses[0].eval_res['dtype'] : 'Numeric'; - let metric_ax_label = null; + let plot_legend = null; + let metric_axes_labels = []; + let num_metrics = 1; if (typeof_eval_res.includes('KeyValue')) { - // Check if it's a single-item dict; in which case we can extract the key to name the axis: - const keys = Object.keys(responses[0].eval_res.items[0]); - if (keys.length === 1) - metric_ax_label = keys[0]; - else - throw Error('Dict metrics with more than one key are currently unsupported.') + metric_axes_labels = Object.keys(responses[0].eval_res.items[0]); + num_metrics = metric_axes_labels.length; + + // if (metric_axes_labels.length > 1) + // throw Error('Dict metrics with more than one key are currently unsupported.') // TODO: When multiple metrics are present, and 1 var is selected (can be multiple LLMs as well), // default to Parallel Coordinates plot, with the 1 var values on the y-axis as colored groups, and metrics on x-axis. // For multiple LLMs, add a control drop-down selector to switch the LLM visualized in the plot. @@ -81,13 +94,13 @@ const VisNode = ({ data, id }) => { const get_items = (eval_res_obj) => { if (typeof_eval_res.includes('KeyValue')) - return eval_res_obj.items.map(item => item[metric_ax_label]); + return eval_res_obj.items.map(item => item[metric_axes_labels[0]]); return eval_res_obj.items; }; // Create Plotly spec here const varnames = multiSelectValue; - const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9']; + const colors = ['#baf078', '#f1b963', '#e46161', '#8888f9', '#33bef0', '#defcf9', '#cadefc', '#f8f398']; let spec = []; let layout = { width: 420, height: 300, title: '', margin: { @@ -129,68 +142,123 @@ const VisNode = ({ data, id }) => { }); layout.boxmode = 'group'; - if (metric_ax_label) + if (metric_axes_labels.length > 0) layout.xaxis = { - title: { font: {size: 12}, text: metric_ax_label }, + title: { font: {size: 12}, text: metric_axes_labels[0] }, }; }; - if (varnames.length === 0) { - // No variables means they used a single prompt (no template) to generate responses - // (Users are likely evaluating differences in responses between LLMs) - plot_grouped_boxplot((r) => truncStr(r.prompt.trim(), 12)); - } - else if (varnames.length === 1) { - // 1 var; numeric eval - if (llm_names.length === 1) { - // Simple box plot, as there is only a single LLM in the response - // Get all possible values of the single variable response ('name' vals) - const names = new Set(responses.map(r => r.vars[varnames[0]].trim())); - for (const name of names) { - let x_items = []; - let text_items = []; - responses.forEach(r => { - if (r.vars[varnames[0]].trim() !== name) return; - x_items = x_items.concat(get_items(r.eval_res)); - text_items = text_items.concat(createHoverTexts(r.responses)); - }); - spec.push( - {type: 'box', x: x_items, name: truncStr(name, 12), boxpoints: 'all', text: text_items, hovertemplate: '%{text}', orientation: 'h'} - ); + if (num_metrics > 1) { + // For 2 or more metrics, display a parallel coordinates plot. + // :: For instance, if evaluator produces { height: 32, weight: 120 } plot responses with 2 metrics, 'height' and 'weight' + if (varnames.length === 1) { + console.log("Plotting parallel coordinates..."); + let unique_vals = getUniqueKeysInResponses(responses, (resp_obj) => resp_obj.vars[varnames[0]]); + let group_colors = colors; + + let colorscale = []; + for (let i = 0; i < unique_vals.length; i++) { + colorscale.push([i / (unique_vals.length-1), group_colors[i % group_colors.length]]); } - layout.hovermode = 'closest'; - if (metric_ax_label) - layout.xaxis = { - title: { font: {size: 12}, text: metric_ax_label }, - }; + let dimensions = []; + metric_axes_labels.forEach(metric => { + const evals = extractEvalResultsForMetric(metric, responses); + dimensions.push({ + range: [Math.min(...evals), Math.max(...evals)], + label: metric, + values: evals, + }); + }); + + spec.push({ + type: 'parcoords', + pad: [10, 10, 10, 10], + line: { + color: responses.map(resp_obj => { + const idx = unique_vals.indexOf(resp_obj.vars[varnames[0]]); + return Array(resp_obj.eval_res.items.length).fill(idx); + }).flat(), + colorscale: colorscale, + }, + dimensions: dimensions, + }); + layout.margin = { l: 40, r: 40, b: 40, t: 50, pad: 0 }; + layout.paper_bgcolor = "white"; + layout.font = {color: "black"}; + + // There's no built-in legend for parallel coords, unfortunately, so we need to construct our own: + let legend_labels = {}; + unique_vals.forEach((v, idx) => + {legend_labels[v] = group_colors[idx];} + ); + plot_legend = (); + + console.log(spec); } else { - // There are multiple LLMs in the response; do a grouped box plot by LLM. - // Note that 'name' is now the LLM, and 'x' stores the value of the var: - plot_grouped_boxplot((r) => r.vars[varnames[0]].trim()); + console.error("Plotting evaluations with more than one metric and more than one prompt parameter is currently unsupported."); } } - else if (varnames.length === 2) { - // Input is 2 vars; numeric eval - // Display a 3D scatterplot with 2 dimensions: - spec = { - type: 'scatter3d', - x: responses.map(r => r.vars[varnames[0]]).map(s => truncStr(s, 12)), - y: responses.map(r => r.vars[varnames[1]]).map(s => truncStr(s, 12)), - z: responses.map(r => get_items(r.eval_res).reduce((acc, val) => (acc + val), 0) / r.eval_res.items.length), // calculates mean - mode: 'markers', + else { // A single metric --use plots like grouped box-and-whiskers, 3d scatterplot + if (varnames.length === 0) { + // No variables means they used a single prompt (no template) to generate responses + // (Users are likely evaluating differences in responses between LLMs) + plot_grouped_boxplot((r) => truncStr(r.prompt.trim(), 12)); + } + else if (varnames.length === 1) { + // 1 var; numeric eval + if (llm_names.length === 1) { + // Simple box plot, as there is only a single LLM in the response + // Get all possible values of the single variable response ('name' vals) + const names = new Set(responses.map(r => r.vars[varnames[0]].trim())); + for (const name of names) { + let x_items = []; + let text_items = []; + responses.forEach(r => { + if (r.vars[varnames[0]].trim() !== name) return; + x_items = x_items.concat(get_items(r.eval_res)); + text_items = text_items.concat(createHoverTexts(r.responses)); + }); + spec.push( + {type: 'box', x: x_items, name: truncStr(name, 12), boxpoints: 'all', text: text_items, hovertemplate: '%{text}', orientation: 'h'} + ); + } + layout.hovermode = 'closest'; + + if (metric_axes_labels.length > 0) + layout.xaxis = { + title: { font: {size: 12}, text: metric_axes_labels[0] }, + }; + } else { + // There are multiple LLMs in the response; do a grouped box plot by LLM. + // Note that 'name' is now the LLM, and 'x' stores the value of the var: + plot_grouped_boxplot((r) => r.vars[varnames[0]].trim()); + } + } + else if (varnames.length === 2) { + // Input is 2 vars; numeric eval + // Display a 3D scatterplot with 2 dimensions: + spec = { + type: 'scatter3d', + x: responses.map(r => r.vars[varnames[0]]).map(s => truncStr(s, 12)), + y: responses.map(r => r.vars[varnames[1]]).map(s => truncStr(s, 12)), + z: responses.map(r => get_items(r.eval_res).reduce((acc, val) => (acc + val), 0) / r.eval_res.items.length), // calculates mean + mode: 'markers', + } } } if (!Array.isArray(spec)) spec = [spec]; + setPlotLegend(plot_legend); setPlotlyObj(( - )) + )); + }, [multiSelectVars, multiSelectValue, responses]); const handleOnConnect = useCallback(() => { @@ -258,7 +326,10 @@ const VisNode = ({ data, id }) => { size="sm" value={multiSelectValue} searchable /> -
{plotlyObj}
+
+ {plotlyObj} + {plotLegend ? plotLegend : <>} +