From 13298c999870144578a560756509963083781a2e Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Tue, 18 Feb 2025 22:47:24 -0500 Subject: [PATCH] Add bootstrap 95% CIs in plots --- chainforge/react-server/package-lock.json | 15 + chainforge/react-server/package.json | 2 + chainforge/react-server/src/VisNode.tsx | 1367 ++++++++++++--------- chainforge/react-server/tsconfig.json | 1 + chainforge/react-server/types/jstat.d.ts | 3 + 5 files changed, 787 insertions(+), 601 deletions(-) create mode 100644 chainforge/react-server/types/jstat.d.ts diff --git a/chainforge/react-server/package-lock.json b/chainforge/react-server/package-lock.json index 4a49fa2..12c05a4 100644 --- a/chainforge/react-server/package-lock.json +++ b/chainforge/react-server/package-lock.json @@ -61,6 +61,7 @@ "emoji-picker-react": "^4.4.9", "google-auth-library": "^8.8.0", "https-browserify": "^1.0.0", + "jstat": "^1.9.6", "lazysizes": "^5.3.2", "lodash": "^4.17.21", "lz-string": "^1.5.0", @@ -91,6 +92,7 @@ "react-scripts": "5.0.1", "reactflow": "^11.0", "request": "^2.88.2", + "simple-statistics": "^7.8.7", "socket.io-client": "^4.6.1", "stream-browserify": "^3.0.0", "stream-http": "^3.2.0", @@ -17370,6 +17372,11 @@ "jss-plugin-vendor-prefixer": "10.10.0" } }, + "node_modules/jstat": { + "version": "1.9.6", + "resolved": "https://registry.npmjs.org/jstat/-/jstat-1.9.6.tgz", + "integrity": "sha512-rPBkJbK2TnA8pzs93QcDDPlKcrtZWuuCo2dVR0TFLOJSxhqfWOVCSp8aV3/oSbn+4uY4yw1URtLpHQedtmXfug==" + }, "node_modules/jsx-ast-utils": { "version": "3.3.5", "resolved": "https://registry.npmjs.org/jsx-ast-utils/-/jsx-ast-utils-3.3.5.tgz", @@ -22771,6 +22778,14 @@ "resolved": "https://registry.npmjs.org/signum/-/signum-1.0.0.tgz", "integrity": "sha512-yodFGwcyt59XRh7w5W3jPcIQb3Bwi21suEfT7MAWnBX3iCdklJpgDgvGT9o04UonglZN5SNMfJFkHIR/jO8GHw==" }, + "node_modules/simple-statistics": { + "version": "7.8.7", + "resolved": "https://registry.npmjs.org/simple-statistics/-/simple-statistics-7.8.7.tgz", + "integrity": "sha512-ed5FwTNYvkMTfbCai1U+r3symP+lIPKWCqKdudpN4NFNMn9RtDlFtSyAQhCp4oPH0YBjWu/qnW+5q5ZkPB3uHQ==", + "engines": { + "node": "*" + } + }, "node_modules/sisteransi": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/sisteransi/-/sisteransi-1.0.5.tgz", diff --git a/chainforge/react-server/package.json b/chainforge/react-server/package.json index a7a06f4..2bc112a 100644 --- a/chainforge/react-server/package.json +++ b/chainforge/react-server/package.json @@ -59,6 +59,7 @@ "emoji-picker-react": "^4.4.9", "google-auth-library": "^8.8.0", "https-browserify": "^1.0.0", + "jstat": "^1.9.6", "lazysizes": "^5.3.2", "lodash": "^4.17.21", "lz-string": "^1.5.0", @@ -89,6 +90,7 @@ "react-scripts": "5.0.1", "reactflow": "^11.0", "request": "^2.88.2", + "simple-statistics": "^7.8.7", "socket.io-client": "^4.6.1", "stream-browserify": "^3.0.0", "stream-http": "^3.2.0", diff --git a/chainforge/react-server/src/VisNode.tsx b/chainforge/react-server/src/VisNode.tsx index 51cdb6c..f8109e3 100644 --- a/chainforge/react-server/src/VisNode.tsx +++ b/chainforge/react-server/src/VisNode.tsx @@ -5,6 +5,7 @@ import React, { useRef, forwardRef, useImperativeHandle, + useTransition, } from "react"; import { Handle, Position } from "reactflow"; import { Button, Menu, NativeSelect } from "@mantine/core"; @@ -12,7 +13,7 @@ import useStore, { colorPalettes } from "./store"; import Plot from "react-plotly.js"; import BaseNode from "./BaseNode"; import NodeLabel from "./NodeLabelComponent"; -import PlotLegend, { PlotLegendProps } from "./PlotLegend"; +import 
PlotLegend from "./PlotLegend"; import { cleanMetavarsFilterFunc, truncStr } from "./backend/utils"; import { Dict, @@ -25,6 +26,94 @@ import { Status } from "./StatusIndicatorComponent"; import { grabResponses } from "./backend/backend"; import { IconChartBar, IconChartHistogram } from "@tabler/icons-react"; +/** + * STATS + */ +import { + sampleWithReplacement, + mean, + quantile, + standardDeviation, + sum, +} from "simple-statistics"; +import * as jStat from "jstat"; // jStat is a pure JS library without types + +const bootstrapCI = ( + values: number[], + numSamples = 1000, + alpha = 0.05, + overFunc?: (ns: number[]) => number, +) => { + const means = []; + const f = overFunc ?? mean; + for (let i = 0; i < numSamples; i++) { + // Resample with replacement + const resampled = sampleWithReplacement(values, values.length, Math.random); + means.push(f(resampled)); + } + + // Compute percentiles for the confidence interval + const lowerBound = quantile(means, alpha / 2); // 2.5th percentile + const upperBound = quantile(means, 1 - alpha / 2); // 97.5th percentile + + return { + ciMean: mean(means), // Bootstrap mean (could be slightly different from original mean) + lowerBound: lowerBound, + upperBound: upperBound, + }; +}; + +/** + * Computes the lower and upper bound for an error bar to display in a Plotly plot. + * @param samples The samples to compute the error bar for + * @param scaleBy A value to scale the outputs by + * @param overFunc The default function is mean. However, we might want to calculate CI over other values, such as sum or SD. In this case, we must use bootstrapping, since the t-stat standard error method doesn't work in these cases. + * @returns The [lowerBound, upperBound] values as a 2-item array, normalized to 100%. + */ +const computeErrorBar = ( + samples: number[], + scaleBy?: number, + overFunc?: (ns: number[]) => number, +) => { + if (samples.length < 2) return [0, 0]; // Not enough information + const scalar = scaleBy ?? 1.0; + + // Choose method depending on # of samples + // NOTE: The cutoff here is informed by research of Zhu & Kolassa (https://doi.org/10.1080/03610918.2017.1348516) + // which shows that below sample size 50, t-test provides a more reliable predictor of actual CI than bootstrapping methods. + if (samples.length < 50 && !overFunc) { + // Fallback to standard error-based confidence interval using t-stat + const se = standardDeviation(samples) / Math.sqrt(samples.length); // Compute standard Error + const t_value = (jStat as any).studentt.inv( + 1 - (1 - 0.95) / 2, + samples.length - 1, + ); // 95% CI (assuming normality) + const m = mean(samples); + console.warn("Error bar t-stat:", m, se, t_value); + return [(m - t_value * se) * scalar, (m + t_value * se) * scalar]; + } else { + // Compute a bootstrap 95% CI (confidence interval). + // NOTE: We use bootstrapping because with prompts, we can *never* assume + // the sample of LLM outputs is representative of the population for the user's hypothesis. + // LLM outputs also don't have to follow a normal distribution. + // This is resource-intensive but a much more reliable approx. than standard error/dev. 
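+    // Percentile bootstrap: approximate the sampling distribution of the
+    // statistic by recomputing it over many resamples (with replacement),
+    // then take the empirical alpha/2 and 1 - alpha/2 quantiles as bounds.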
+ const { ciMean, lowerBound, upperBound } = bootstrapCI( + samples, + 1000, + 0.05, + overFunc, + ); + console.warn("Error bar 95% CI:", ciMean, lowerBound, upperBound); + return [(ciMean - lowerBound) * scalar, (upperBound - ciMean) * scalar]; + } +}; + +const castEvalScoreToNum = (score: EvaluationScore): number => { + if (typeof score === "number") return score; + else if (typeof score === "boolean") return score === true ? 1 : 0; + else return 0; // unknown, soft fail +}; + /** * UTIL FUNCTIONS FOR VIS PLOTS */ @@ -161,6 +250,7 @@ export interface VisViewProps { responses: LLMResponse[]; id?: string; data?: VisNodeData; + whenReplotting?: (isReplotting: boolean) => void; } export interface VisViewRef { resetControls: (responses: LLMResponse[]) => void; @@ -170,7 +260,7 @@ export interface VisViewRef { * Inner component for code evaluators/processors, storing the body of the UI (outside of the header and footers). */ export const VisView = forwardRef( - function CodeEvaluatorComponent({ responses, id, data }, ref) { + function CodeEvaluatorComponent({ responses, id, data, whenReplotting }, ref) { const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); const getColorForLLMAndSetIfNotFound = useStore( (state) => state.getColorForLLMAndSetIfNotFound, @@ -181,6 +271,9 @@ export const VisView = forwardRef( const [plotlySpec, setPlotlySpec] = useState([]); const [plotlyLayout, setPlotlyLayout] = useState({}); + // So updating the plot doesn't block the UI + const [isPlotRerenderPending, startTransition] = useTransition(); + // For some data types, there are multiple graph options available... const graphOptions = [ { key: "bar", label: "Bar Chart", icon: }, @@ -314,6 +407,12 @@ export const VisView = forwardRef( resetControls, })); + // Pending transitions display loading spinner + useEffect(() => { + if (!whenReplotting) return; + whenReplotting(isPlotRerenderPending); + }, [isPlotRerenderPending]); + // Re-plot responses when any responses or settings change useEffect(() => { if (!responses || responses.length === 0 || !multiSelectValue) return; @@ -336,635 +435,676 @@ export const VisView = forwardRef( return; } - // setStatus(Status.NONE); + startTransition(() => { + const get_llm = (resp_obj: LLMResponse) => { + if (selectedLLMGroup === "LLM") + return typeof resp_obj.llm === "string" + ? resp_obj.llm + : resp_obj.llm?.name; + else return resp_obj.metavars[selectedLLMGroup] as string; + }; + const getLLMsInResponses = (responses: LLMResponse[]) => + getUniqueKeysInResponses(responses, get_llm); - const get_llm = (resp_obj: LLMResponse) => { - if (selectedLLMGroup === "LLM") - return typeof resp_obj.llm === "string" - ? resp_obj.llm - : resp_obj.llm?.name; - else return resp_obj.metavars[selectedLLMGroup] as string; - }; - const getLLMsInResponses = (responses: LLMResponse[]) => - getUniqueKeysInResponses(responses, get_llm); + // Get all LLMs in responses, by selected LLM group + const llm_names = getLLMsInResponses(responses); - // Get all LLMs in responses, by selected LLM group - const llm_names = getLLMsInResponses(responses); - - // Create Plotly spec here - const varnames = - multiSelectValue !== "LLM (default)" && multiSelectValue !== undefined - ? 
[multiSelectValue] - : []; - const varcolors = colorPalettes.var; // ['#44d044', '#f1b933', '#e46161', '#8888f9', '#33bef0', '#bb55f9', '#cadefc', '#f8f398']; - let spec: Dict[] | Dict = []; - const layout: Dict = { - autosize: true, - dragmode: "pan", - title: "", - margin: { - l: 125, - r: 0, - b: 36, - t: 20, - pad: 6, - }, - yaxis: { showgrid: true }, - }; - - // Bucket responses by LLM: - const responses_by_llm: Dict = {}; - responses.forEach((item) => { - const llm = get_llm(item); - if (llm in responses_by_llm) responses_by_llm[llm].push(item); - else responses_by_llm[llm] = [item]; - }); - - // Get the type of evaluation results, if present - // (This is assumed to be consistent across response batches) - let typeof_eval_res = - responses[0].eval_res && "dtype" in responses[0].eval_res - ? responses[0].eval_res.dtype - : "Numeric"; - - // If categorical type, check if all binary: - if (typeof_eval_res === "Categorical") { - const is_all_bools = responses.reduce( - (acc0: boolean, res_obj: LLMResponse) => - acc0 && - res_obj.eval_res !== undefined && - res_obj.eval_res.items?.reduce( - (acc: boolean, cur: EvaluationScore) => - acc && typeof cur === "boolean", - true, - ), - true, - ); - if (is_all_bools) { - typeof_eval_res = "Boolean"; - setDisableGraphTypeOption(true); - } - } else { - setDisableGraphTypeOption(false); - } - - // Check the max length of eval results, as if it's only 1 score per item (num of generations per prompt n=1), - // we might want to plot the result differently: - let max_num_results_per_prompt = 1; - responses.forEach((res_obj) => { - if ( - res_obj.eval_res !== undefined && - res_obj.eval_res?.items?.length > max_num_results_per_prompt - ) - max_num_results_per_prompt = res_obj.eval_res.items.length; - }); - - let plot_legend: React.ReactNode | null = null; - let metric_axes_labels: string[] = []; - let num_metrics = 1; - if ( - typeof_eval_res.includes("KeyValue") && - responses[0].eval_res !== undefined - ) { - metric_axes_labels = Object.keys(responses[0].eval_res.items[0]); - num_metrics = metric_axes_labels.length; - } - - const get_var = ( - resp_obj: LLMResponse, - varname: string, - empty_str_if_undefined = false, - ) => { - const v = varname.startsWith("__meta_") - ? resp_obj.metavars[varname.slice("__meta_".length)] - : resp_obj.vars[varname]; - if (v === undefined && empty_str_if_undefined) return ""; - return v; - }; - - const get_var_and_trim = ( - resp_obj: LLMResponse, - varname: string, - empty_str_if_undefined = false, - ) => { - const v = get_var(resp_obj, varname, empty_str_if_undefined); - if (v !== undefined) return v.trim(); - else return v; - }; - - const get_items = (eval_res_obj?: EvaluationResults) => { - if (eval_res_obj === undefined) return []; - if (typeof_eval_res.includes("KeyValue")) - return eval_res_obj.items.map( - (item) => - (item as Dict)[metric_axes_labels[0]], - ); - return eval_res_obj.items; - }; - - // Only for Boolean data - const plot_accuracy = ( - resp_to_x: (r: LLMResponse) => string, - group_type: "var" | "llm", - ) => { - // Plots the percentage of 'true' evaluations out of the total number of evaluations, - // per category of 'resp_to_x', as a horizontal bar chart, with different colors per category. 
- const names = new Set(responses.map(resp_to_x)); - const shortnames = genUniqueShortnames(names); - const x_items: number[] = []; - const y_items: string[] = []; - const marker_colors: string[] = []; - for (const name of names) { - // Add a shortened version of the name as the y-tick - y_items.push(shortnames[name]); - - // Calculate how much percentage a single 'true' value counts for: - const num_eval_scores = responses.reduce((acc, r) => { - if (resp_to_x(r) !== name) return acc; - else return acc + get_items(r.eval_res).length; - }, 0); - const perc_scalar = 100 / num_eval_scores; - - // Calculate the length of the bar - x_items.push( - responses.reduce((acc: number, r: LLMResponse) => { - if (resp_to_x(r) !== name) return acc; - else - return ( - acc + - get_items(r.eval_res).filter((res) => res === true).length * - perc_scalar - ); - }, 0), - ); - - // Lookup the color per LLM when displaying LLM differences, - // otherwise use the palette for displaying variables. - const color = - group_type === "llm" - ? getColorForLLMAndSetIfNotFound(name) - : getColorForLLMAndSetIfNotFound(get_llm(responses[0])); - marker_colors.push(color); - } - - // Set the left margin to fit the yticks labels - layout.margin.l = calcLeftPaddingForYLabels(Object.values(shortnames)); - - spec = [ - { - type: "bar", - y: y_items, - x: x_items, - marker: { - color: marker_colors, - }, - // error_x: { // TODO: Error bars - // type: 'data', - // array: [0.5, 1, 2], - // visible: true - // }, - hovertemplate: "%{x:.2f}%%{y}", - showtrace: false, - orientation: "h", + // Create Plotly spec here + const varnames = + multiSelectValue !== "LLM (default)" && multiSelectValue !== undefined + ? [multiSelectValue] + : []; + const varcolors = colorPalettes.var; // ['#44d044', '#f1b933', '#e46161', '#8888f9', '#33bef0', '#bb55f9', '#cadefc', '#f8f398']; + let spec: Dict[] | Dict = []; + const layout: Dict = { + autosize: true, + dragmode: "pan", + title: "", + margin: { + l: 125, + r: 0, + b: 36, + t: 20, + pad: 6, }, - ]; - layout.xaxis = { - range: [0, 100], - tickmode: "linear", - tick0: 0, - dtick: 10, + yaxis: { showgrid: true }, }; - setForcedGraphType("bar"); // bar chart + // Bucket responses by LLM: + const responses_by_llm: Dict = {}; + responses.forEach((item) => { + const llm = get_llm(item); + if (llm in responses_by_llm) responses_by_llm[llm].push(item); + else responses_by_llm[llm] = [item]; + }); - if (metric_axes_labels.length > 0) - layout.xaxis = { - title: { font: { size: 12 }, text: metric_axes_labels[0] }, - ...layout.xaxis, - }; - else - layout.xaxis = { - title: { font: { size: 12 }, text: "% percent true" }, - ...layout.xaxis, - }; - }; + // Get the type of evaluation results, if present + // (This is assumed to be consistent across response batches) + let typeof_eval_res = + responses[0].eval_res && "dtype" in responses[0].eval_res + ? responses[0].eval_res.dtype + : "Numeric"; - const plot_simple_boxplot = ( - resp_to_x: (r: LLMResponse) => string, - group_type: "var" | "llm", - ) => { - let names = new Set(); - const plotting_categorical_vars = - group_type === "var" && typeof_eval_res === "Categorical"; - - // When we're plotting vars, we want the stacked bar colors to be the *categories*, - // and the x_items to be the names of vars, so that the left axis is a vertical list of varnames. 
- if (plotting_categorical_vars) { - // Get all categories present in the evaluation results - responses.forEach((r) => - get_items(r.eval_res).forEach((i) => names.add(i.toString())), + // If categorical type, check if all binary: + if (typeof_eval_res === "Categorical") { + const is_all_bools = responses.reduce( + (acc0: boolean, res_obj: LLMResponse) => + acc0 && + res_obj.eval_res !== undefined && + res_obj.eval_res.items?.reduce( + (acc: boolean, cur: EvaluationScore) => + acc && typeof cur === "boolean", + true, + ), + true, ); + if (is_all_bools) { + typeof_eval_res = "Boolean"; + setDisableGraphTypeOption(true); + } } else { - // Get all possible values of the single variable response ('name' vals) - names = new Set(responses.map(resp_to_x)); + setDisableGraphTypeOption(false); } - const shortnames = genUniqueShortnames(names); - for (const name of names) { - let x_items: EvaluationScore[] = []; - let text_items: string[] = []; + // Check the max length of eval results, as if it's only 1 score per item (num of generations per prompt n=1), + // we might want to plot the result differently: + let max_num_results_per_prompt = 1; + responses.forEach((res_obj) => { + if ( + res_obj.eval_res !== undefined && + res_obj.eval_res?.items?.length > max_num_results_per_prompt + ) + max_num_results_per_prompt = res_obj.eval_res.items.length; + }); - if (plotting_categorical_vars) { - responses.forEach((r) => { - // Get all evaluation results for this response which match the category 'name': - const eval_res = get_items(r.eval_res).filter((i) => i === name); - x_items = x_items.concat( - new Array(eval_res.length).fill(resp_to_x(r)), - ); - }); - } else { - responses.forEach((r) => { - if (resp_to_x(r) !== name) return; - x_items = x_items.concat(get_items(r.eval_res)); - text_items = text_items.concat( - createHoverTexts( - r.responses.map((v) => (typeof v === "string" ? v : v.d)), - ), - ); - }); + let plot_legend: React.ReactNode | null = null; + let metric_axes_labels: string[] = []; + let num_metrics = 1; + if ( + typeof_eval_res.includes("KeyValue") && + responses[0].eval_res !== undefined + ) { + metric_axes_labels = Object.keys(responses[0].eval_res.items[0]); + num_metrics = metric_axes_labels.length; + } + + const get_var = ( + resp_obj: LLMResponse, + varname: string, + empty_str_if_undefined = false, + ) => { + const v = varname.startsWith("__meta_") + ? resp_obj.metavars[varname.slice("__meta_".length)] + : resp_obj.vars[varname]; + if (v === undefined && empty_str_if_undefined) return ""; + return v; + }; + + const get_var_and_trim = ( + resp_obj: LLMResponse, + varname: string, + empty_str_if_undefined = false, + ) => { + const v = get_var(resp_obj, varname, empty_str_if_undefined); + if (v !== undefined) return v.trim(); + else return v; + }; + + const get_items = (eval_res_obj?: EvaluationResults) => { + if (eval_res_obj === undefined) return []; + if (typeof_eval_res.includes("KeyValue")) + return eval_res_obj.items.map( + (item) => + (item as Dict)[ + metric_axes_labels[0] + ], + ); + return eval_res_obj.items; + }; + + // Only for Boolean data + const plot_accuracy = ( + resp_to_x: (r: LLMResponse) => string, + group_type: "var" | "llm", + ) => { + // Plots the percentage of 'true' evaluations out of the total number of evaluations, + // per category of 'resp_to_x', as a horizontal bar chart, with different colors per category. 
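+        // For each category: the bar length is the % of `true` scores, and
+        // computeErrorBar supplies an asymmetric 95% CI, scaled to percent.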
+ const names = new Set(responses.map(resp_to_x)); + const shortnames = genUniqueShortnames(names); + const x_items: number[] = []; + const y_items: string[] = []; + const marker_colors: string[] = []; + const error_values: number[][] = []; + for (const name of names) { + // Add a shortened version of the name as the y-tick + y_items.push(shortnames[name]); + + // Calculate the number of true values over the total possible number + let num_true_vals = 0; + let num_eval_scores = 0; + const all_samples: number[] = []; + for (const r of responses) { + if (resp_to_x(r) !== name) continue; + const items = get_items(r.eval_res); + Array.prototype.push.apply( + all_samples, + items.map((i) => (i === true ? 1 : 0)), + ); // extend the `all_samples` array + num_eval_scores += items.length; + num_true_vals += items.filter((res) => res === true).length; + } + if (num_eval_scores > 0) + x_items.push(num_true_vals * (100 / num_eval_scores)); + + // Compute error bar info + error_values.push(computeErrorBar(all_samples, 100)); + + // Compute standard error for the error bar (SE = sqrt(p(1-p)/n) * 100) + // const p = num_true_vals / num_eval_scores; + // const standard_error = + // Math.sqrt((p * (1 - p)) / num_eval_scores) * 100; + // error_values.push(standard_error); + + // Lookup the color per LLM when displaying LLM differences, + // otherwise use the palette for displaying variables. + const color = + group_type === "llm" + ? getColorForLLMAndSetIfNotFound(name) + : getColorForLLMAndSetIfNotFound(get_llm(responses[0])); + marker_colors.push(color); } - // Lookup the color per LLM when displaying LLM differences, - // otherwise use the palette for displaying variables. - const color = - group_type === "llm" - ? getColorForLLMAndSetIfNotFound(name) - : // : varcolors[name_idx % varcolors.length]; - getColorForLLMAndSetIfNotFound(get_llm(responses[0])); + // Set the left margin to fit the yticks labels + layout.margin.l = calcLeftPaddingForYLabels( + Object.values(shortnames), + ); - if ( - typeof_eval_res === "Boolean" || - typeof_eval_res === "Categorical" - ) { - // Plot a histogram for categorical or boolean data. - spec.push({ - type: "histogram", - histfunc: "sum", - name: shortnames[name], - marker: { color }, - y: x_items, + spec = [ + { + type: "bar", + y: y_items, + x: x_items, + marker: { + color: marker_colors, + }, + error_x: { + type: "data", + // Asymmetric errors bars, since we're using bootstrapping to determine the 95% CI + array: error_values.map((e) => e[1]), // Upper bound + arrayminus: error_values.map((e) => e[0]), // Lower bound + visible: true, + }, + hovertemplate: "%{x:.2f}%%{y}", + showtrace: false, orientation: "h", - }); - layout.barmode = "stack"; - layout.yaxis = { - showticklabels: true, - dtick: 1, - type: "category", - showgrid: true, - }; + }, + ]; + layout.xaxis = { + range: [0, 100], + tickmode: "linear", + tick0: 0, + dtick: 10, + }; + + setForcedGraphType("bar"); // bar chart + + if (metric_axes_labels.length > 0) layout.xaxis = { - title: { font: { size: 12 }, text: "Number of 'true' values" }, + title: { font: { size: 12 }, text: metric_axes_labels[0] }, ...layout.xaxis, }; - } else { - // Plot bar or boxplots for all other cases. 
- // x_items = [x_items.reduce((val, acc) => val + acc, 0)]; - const d: Dict = { - name: shortnames[name], - x: x_items, - text: text_items, - hovertemplate: "%{text}", - orientation: "h", - marker: { color }, + else + layout.xaxis = { + title: { font: { size: 12 }, text: "% percent true" }, + ...layout.xaxis, }; + }; - // If only one result, plot a bar chart: - if (x_items.length === 1) { - d.type = "bar"; - d.textposition = "none"; // hide the text which appears within each bar - d.y = new Array(x_items.length).fill(shortnames[name]); - setForcedGraphType("bar"); - } else { - // If multiple eval results per response object (num generations per prompt n > 1), - // let user decide: - if (graphType.key === "bar") { - d.type = "histogram"; - d.histfunc = "sum"; - d.y = new Array(x_items.length).fill(shortnames[name]); - d.textposition = "none"; // hide the text which appears within each bar - layout.xaxis = { - title: { font: { size: 12 }, text: "Sum of scores" }, - ...layout.xaxis, - }; - } else { - // Box-and-whiskers plot - d.type = "box"; - d.boxpoints = "all"; - } - } - spec.push(d); + const plot_simple_boxplot = ( + resp_to_x: (r: LLMResponse) => string, + group_type: "var" | "llm", + ) => { + let names = new Set(); + const plotting_categorical_vars = + group_type === "var" && typeof_eval_res === "Categorical"; + + // When we're plotting vars, we want the stacked bar colors to be the *categories*, + // and the x_items to be the names of vars, so that the left axis is a vertical list of varnames. + if (plotting_categorical_vars) { + // Get all categories present in the evaluation results + responses.forEach((r) => + get_items(r.eval_res).forEach((i) => names.add(i.toString())), + ); + } else { + // Get all possible values of the single variable response ('name' vals) + names = new Set(responses.map(resp_to_x)); } - } - layout.hovermode = "closest"; - layout.showlegend = false; - // Set the left margin to fit the yticks labels - layout.margin.l = calcLeftPaddingForYLabels(Object.values(shortnames)); - - if (metric_axes_labels.length > 0) - layout.xaxis = { - title: { font: { size: 12 }, text: metric_axes_labels[0] }, - ...layout.xaxis, - }; - }; - - const plot_grouped_boxplot = (resp_to_x: (r: LLMResponse) => string) => { - // Get all possible values of the single variable response ('name' vals) - const names = new Set(responses.map(resp_to_x)); - const shortnames = genUniqueShortnames(names); - - llm_names.forEach((llm) => { - // Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks. 
- const rs = responses_by_llm[llm]; - - let x_items: EvaluationScore[] = []; - let y_items: EvaluationScore[] = []; - let text_items: string[] = []; + const shortnames = genUniqueShortnames(names); for (const name of names) { - rs.forEach((r) => { - if (resp_to_x(r) !== name) return; - x_items = x_items.concat(get_items(r.eval_res)).flat(); - text_items = text_items - .concat( + let x_items: EvaluationScore[] = []; + let text_items: string[] = []; + + if (plotting_categorical_vars) { + responses.forEach((r) => { + // Get all evaluation results for this response which match the category 'name': + const eval_res = get_items(r.eval_res).filter( + (i) => i === name, + ); + x_items = x_items.concat( + new Array(eval_res.length).fill(resp_to_x(r)), + ); + }); + } else { + responses.forEach((r) => { + if (resp_to_x(r) !== name) return; + x_items = x_items.concat(get_items(r.eval_res)); + text_items = text_items.concat( createHoverTexts( r.responses.map((v) => (typeof v === "string" ? v : v.d)), ), - ) - .flat(); - y_items = y_items - .concat( - Array(get_items(r.eval_res).length).fill(shortnames[name]), - ) - .flat(); - }); - } + ); + }); + } - if (typeof_eval_res === "Boolean") { - // Plot a histogram for boolean (true/false) categorical data. - spec.push({ - type: "histogram", - histfunc: "sum", - name: llm, - marker: { color: getColorForLLMAndSetIfNotFound(llm) }, - x: x_items.map((i) => (i === true ? "1" : "0")), - y: y_items, - orientation: "h", - }); - layout.barmode = "stack"; + // Lookup the color per LLM when displaying LLM differences, + // otherwise use the palette for displaying variables. + const color = + group_type === "llm" + ? getColorForLLMAndSetIfNotFound(name) + : // : varcolors[name_idx % varcolors.length]; + getColorForLLMAndSetIfNotFound(get_llm(responses[0])); + + if ( + typeof_eval_res === "Boolean" || + typeof_eval_res === "Categorical" + ) { + // Plot a histogram for categorical or boolean data. + spec.push({ + type: "histogram", + histfunc: "sum", + name: shortnames[name], + marker: { color }, + y: x_items, + orientation: "h", + }); + layout.barmode = "stack"; + layout.yaxis = { + showticklabels: true, + dtick: 1, + type: "category", + showgrid: true, + }; + layout.xaxis = { + title: { font: { size: 12 }, text: "Number of 'true' values" }, + ...layout.xaxis, + }; + } else { + // Plot bar or boxplots for all other cases. 
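+            // Bar mode sums the scores and attaches an asymmetric 95% CI
+            // (bootstrapped over the raw items, since a sum has no simple
+            // t-interval); box mode just plots the raw score distribution.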
+ const d: Dict = { + name: shortnames[name], + x: x_items, + text: text_items, + hovertemplate: "%{text}", + orientation: "h", + marker: { color }, + }; + + // If only one result, plot a bar chart: + if (x_items.length === 1) { + d.type = "bar"; + d.textposition = "none"; // hide the text which appears within each bar + d.y = new Array(x_items.length).fill(shortnames[name]); + setForcedGraphType("bar"); + } else { + // If multiple eval results per response object (num generations per prompt n > 1), + // let user decide: + if (graphType.key === "bar") { + d.type = "histogram"; + d.histfunc = "sum"; + d.y = new Array(x_items.length).fill(shortnames[name]); + d.textposition = "none"; // hide the text which appears within each bar + layout.xaxis = { + title: { font: { size: 12 }, text: "Sum of scores" }, + ...layout.xaxis, + }; + + // Compute error bars if present + const error_values = [ + computeErrorBar(x_items.map(castEvalScoreToNum), 1.0, sum), + ]; + if (error_values.length > 0) + d.error_x = { + type: "data", + // Asymmetric errors bars, since we're using bootstrapping to determine the 95% CI + array: error_values.map((e) => e[1]), // Upper bound + arrayminus: error_values.map((e) => e[0]), // Lower bound + visible: true, + }; + } else { + // Box-and-whiskers plot + d.type = "box"; + d.boxpoints = "all"; + } + } + + spec.push(d); + } + } + layout.hovermode = "closest"; + layout.showlegend = false; + + // Set the left margin to fit the yticks labels + layout.margin.l = calcLeftPaddingForYLabels( + Object.values(shortnames), + ); + + if (metric_axes_labels.length > 0) layout.xaxis = { - title: { font: { size: 12 }, text: "Number of 'true' values" }, + title: { font: { size: 12 }, text: metric_axes_labels[0] }, ...layout.xaxis, }; - setForcedGraphType("bar"); - } else { - // Plot a boxplot or bar chart for other cases. - const d = { - name: llm, - marker: { color: getColorForLLMAndSetIfNotFound(llm) }, - x: x_items, - y: y_items, - boxpoints: "all", - text: text_items, - hovertemplate: "%{text} (%{x})", - orientation: "h", - } as Dict; + }; - // If only one result, plot a bar chart: - // if (max_num_results_per_prompt === 1) { - let xaxis_title = "score"; - if (graphType.key === "bar") { - d.type = "bar"; - d.textposition = "none"; // hide the text which appears within each bar - xaxis_title = "Sum of scores"; - } else { - // Box-and-whiskers plot - d.type = "box"; + const plot_grouped_boxplot = ( + resp_to_x: (r: LLMResponse) => string, + ) => { + // Get all possible values of the single variable response ('name' vals) + const names = new Set(responses.map(resp_to_x)); + const shortnames = genUniqueShortnames(names); + + llm_names.forEach((llm) => { + // Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks. + const rs = responses_by_llm[llm]; + + let x_items: EvaluationScore[] = []; + let y_items: EvaluationScore[] = []; + // let x_items_by_shortname: { [key: string]: [] } = {}; + let text_items: string[] = []; + for (const name of names) { + rs.forEach((r) => { + if (resp_to_x(r) !== name) return; + const items = get_items(r.eval_res); + x_items = x_items.concat(items).flat(); + text_items = text_items + .concat( + createHoverTexts( + r.responses.map((v) => (typeof v === "string" ? 
v : v.d)), + ), + ) + .flat(); + y_items = y_items + .concat(Array(items.length).fill(shortnames[name])) + .flat(); + }); } - // } else { - // // If multiple eval results per response object (num generations per prompt n > 1), - // // plot box-and-whiskers to illustrate the variability: - // d.type = "box"; - // } - spec.push(d); - layout.xaxis = { - title: { font: { size: 12 }, text: xaxis_title }, - ...layout.axis, - }; - } - }); - layout.boxmode = "group"; - layout.bargap = 0.5; - - // Set the left margin to fit the yticks labels - layout.margin.l = calcLeftPaddingForYLabels(Object.values(shortnames)); - - if (metric_axes_labels.length > 0) - layout.xaxis = { - title: { font: { size: 12 }, text: metric_axes_labels[0] }, - }; - }; - - if (num_metrics > 1) { - // For 2 or more metrics, display a parallel coordinates plot. - // :: For instance, if evaluator produces { height: 32, weight: 120 } plot responses with 2 metrics, 'height' and 'weight' - if (varnames.length === 1) { - const unique_vals = getUniqueKeysInResponses(responses, (resp_obj) => - get_var(resp_obj, varnames[0]), - ); - // const response_txts = responses.map(res_obj => res_obj.responses).flat(); - - const group_colors = varcolors; - const unselected_line_color = "#ddd"; - const spec_colors = responses - .map((resp_obj) => { - const idx = unique_vals.indexOf(get_var(resp_obj, varnames[0])); - return resp_obj.eval_res - ? Array(resp_obj.eval_res.items.length).fill(idx) - : []; - }) - .flat(); - - const colorscale: [number, string][] = []; - for (let i = 0; i < unique_vals.length; i++) { - if ( - !selectedLegendItems || - selectedLegendItems.indexOf(unique_vals[i]) > -1 - ) - colorscale.push([ - i / (unique_vals.length - 1), - group_colors[i % group_colors.length], - ]); - else - colorscale.push([ - i / (unique_vals.length - 1), - unselected_line_color, - ]); - } - - const dimensions: Dict = []; - metric_axes_labels.forEach((metric) => { - const evals = extractEvalResultsForMetric(metric, responses); - dimensions.push({ - range: evals.every((e) => typeof e === "number") - ? 
[ - Math.min(...(evals as number[])), - Math.max(...(evals as number[])), - ] - : undefined, - label: metric, - values: evals, - }); - }); - - spec.push({ - type: "parcoords", - pad: [10, 10, 10, 10], - line: { - color: spec_colors, - colorscale, - }, - dimensions, - }); - layout.margin = { l: 40, r: 40, b: 40, t: 50, pad: 0 }; - layout.paper_bgcolor = "white"; - layout.font = { color: "black" }; - layout.selectedpoints = []; - - // There's no built-in legend for parallel coords, unfortunately, so we need to construct our own: - const legend_labels: Dict = {}; - unique_vals.forEach((v, idx) => { - if (!selectedLegendItems || selectedLegendItems.indexOf(v) > -1) - legend_labels[v] = group_colors[idx % group_colors.length]; - else legend_labels[v] = unselected_line_color; - }); - const onClickLegendItem = (label: string) => { - if ( - selectedLegendItems && - selectedLegendItems.length === 1 && - selectedLegendItems[0] === label - ) - setSelectedLegendItems(null); // Clicking twice on a legend item deselects it and displays all - else setSelectedLegendItems([label]); - }; - plot_legend = ( - - ); - - // Tried to support Plotly hover events here, but looks like - // currently there are unsupported for parcoords: https://github.com/plotly/plotly.js/issues/3012 - // onHover = (e) => { - // console.log(e.curveNumber); - // // const curveIdx = e.curveNumber; - // // if (curveIdx < response_txts.length) { - // // if (!selectedLegendItems || selectedLegendItems.indexOf(unique_vals[spec_colors[curveIdx]]) > -1) - // // console.log(response_txts[curveIdx]); - // // } - // }; - } else { - setSelectedLegendItems(null); - const error_text = - "Plotting evaluations with more than one metric and more than one prompt parameter is currently unsupported."; - setPlaceholderText( -
-            <p>
- {error_text} -
-            </p>
, - ); - console.error(error_text); - } - } else { - // A single metric --use plots like grouped box-and-whiskers, 3d scatterplot - if (varnames.length === 0) { - // No variables means they used a single prompt (no template) to generate responses - // (Users are likely evaluating differences in responses between LLMs) - if (typeof_eval_res === "Boolean") plot_accuracy(get_llm, "llm"); - else plot_simple_boxplot(get_llm, "llm"); - } else if (varnames.length === 1) { - // 1 var; numeric eval - if (llm_names.length === 1) { - if (typeof_eval_res === "Boolean") - // Accuracy plot per value of the selected variable: - plot_accuracy((r) => get_var_and_trim(r, varnames[0]), "var"); - else { - // Simple box plot, as there is only a single LLM in the response - plot_simple_boxplot( - (r) => get_var_and_trim(r, varnames[0]), - "var", - ); - } - } else { - // There are multiple LLMs in the response; do a grouped box plot by LLM. - // Note that 'name' is now the LLM, and 'x' stores the value of the var: - plot_grouped_boxplot((r) => get_var_and_trim(r, varnames[0])); - } - } else if (varnames.length === 2) { - // Input is 2 vars; numeric eval - // Display a 3D scatterplot with 2 dimensions: - - const names_0 = new Set( - responses.map((r) => get_var_and_trim(r, varnames[0])), - ); - const shortnames_0 = genUniqueShortnames(names_0); - const names_1 = new Set( - responses.map((r) => get_var_and_trim(r, varnames[1])), - ); - const shortnames_1 = genUniqueShortnames(names_1); - - if (llm_names.length === 1) { - spec = { - type: "scatter3d", - x: responses - .map((r) => get_var(r, varnames[0], true)) - .map((s) => shortnames_0[s]), - y: responses - .map((r) => get_var(r, varnames[1], true)) - .map((s) => shortnames_1[s]), - z: responses.map( - (r) => - get_items(r.eval_res).reduce( - (acc: number, val) => - acc + (typeof val === "number" ? val : 0), - 0, - ) / (r.eval_res?.items.length ?? 1), - ), // calculates mean - mode: "markers", - marker: { - color: getColorForLLMAndSetIfNotFound(llm_names[0]), - }, - }; - } else { - spec = []; - llm_names.forEach((llm) => { - const resps = responses.filter((r) => get_llm(r) === llm); + if (typeof_eval_res === "Boolean") { + // Plot a histogram for boolean (true/false) categorical data. spec.push({ + type: "histogram", + histfunc: "sum", + name: llm, + marker: { color: getColorForLLMAndSetIfNotFound(llm) }, + x: x_items.map((i) => (i === true ? "1" : "0")), + y: y_items, + orientation: "h", + }); + layout.barmode = "stack"; + layout.xaxis = { + title: { font: { size: 12 }, text: "Number of 'true' values" }, + ...layout.xaxis, + }; + setForcedGraphType("bar"); + } else { + // Plot a boxplot or bar chart for other cases. 
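+            // Same idea per LLM group: in bar mode, scores are summed per
+            // y-category and each sum gets its own bootstrap 95% CI below.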
+ const d = { + name: llm, + marker: { color: getColorForLLMAndSetIfNotFound(llm) }, + x: x_items, + y: y_items, + boxpoints: "all", + text: text_items, + hovertemplate: "%{text} (%{x})", + orientation: "h", + } as Dict; + + // If only one result, plot a bar chart: + // if (max_num_results_per_prompt === 1) { + let xaxis_title = "score"; + if (graphType.key === "bar") { + d.type = "bar"; + d.textposition = "none"; // hide the text which appears within each bar + xaxis_title = "Sum of scores"; + + if (typeof_eval_res === "Numeric") { + // To make error bars work, we need to sum the numbers, instead of relying + // upon the stacked bar chart: + let sum_x_items: number[] = []; + // let x_items_by_y: { [key: string]: number[] } = {}; + let error_bars: number[][] = []; + const seq_y_items = []; + for (const name of Object.values(shortnames)) { + seq_y_items.push(name); + const xs_for_y = x_items + .filter((_, idx) => y_items[idx] === name) + .map(castEvalScoreToNum); + // x_items_by_y[name] = xs_for_y; + sum_x_items = sum_x_items.concat(sum(xs_for_y)); + error_bars = error_bars.concat([ + computeErrorBar(xs_for_y, 1.0, sum), + ]); + } + d.x = sum_x_items; + d.y = seq_y_items; + d.hovertemplate = llm; + delete d.text; + + // Add error bars to plot + d.error_x = { + type: "data", + // Asymmetric errors bars, since we're using bootstrapping to determine the 95% CI + array: error_bars.map((e) => e[1]), // Upper bound + arrayminus: error_bars.map((e) => e[0]), // Lower bound + visible: true, + }; + } + } else { + // Box-and-whiskers plot + d.type = "box"; + } + + spec.push(d); + layout.xaxis = { + title: { font: { size: 12 }, text: xaxis_title }, + ...layout.axis, + }; + } + }); + layout.boxmode = "group"; + layout.bargap = 0.5; + // layout.yaxis = { + // tickfont: { size: 10 }, + // ...layout.yaxis, + // }; + + // Set the left margin to fit the yticks labels + layout.margin.l = calcLeftPaddingForYLabels( + Object.values(shortnames), + ); + + if (metric_axes_labels.length > 0) + layout.xaxis = { + title: { font: { size: 12 }, text: metric_axes_labels[0] }, + }; + }; + + if (num_metrics > 1) { + // For 2 or more metrics, display a parallel coordinates plot. + // :: For instance, if evaluator produces { height: 32, weight: 120 } plot responses with 2 metrics, 'height' and 'weight' + if (varnames.length === 1) { + const unique_vals = getUniqueKeysInResponses( + responses, + (resp_obj) => get_var(resp_obj, varnames[0]), + ); + // const response_txts = responses.map(res_obj => res_obj.responses).flat(); + + const group_colors = varcolors; + const unselected_line_color = "#ddd"; + const spec_colors = responses + .map((resp_obj) => { + const idx = unique_vals.indexOf(get_var(resp_obj, varnames[0])); + return resp_obj.eval_res + ? Array(resp_obj.eval_res.items.length).fill(idx) + : []; + }) + .flat(); + + const colorscale: [number, string][] = []; + for (let i = 0; i < unique_vals.length; i++) { + if ( + !selectedLegendItems || + selectedLegendItems.indexOf(unique_vals[i]) > -1 + ) + colorscale.push([ + i / (unique_vals.length - 1), + group_colors[i % group_colors.length], + ]); + else + colorscale.push([ + i / (unique_vals.length - 1), + unselected_line_color, + ]); + } + + const dimensions: Dict = []; + metric_axes_labels.forEach((metric) => { + const evals = extractEvalResultsForMetric(metric, responses); + dimensions.push({ + range: evals.every((e) => typeof e === "number") + ? 
[ + Math.min(...(evals as number[])), + Math.max(...(evals as number[])), + ] + : undefined, + label: metric, + values: evals, + }); + }); + + spec.push({ + type: "parcoords", + pad: [10, 10, 10, 10], + line: { + color: spec_colors, + colorscale, + }, + dimensions, + }); + layout.margin = { l: 40, r: 40, b: 40, t: 50, pad: 0 }; + layout.paper_bgcolor = "white"; + layout.font = { color: "black" }; + layout.selectedpoints = []; + + // There's no built-in legend for parallel coords, unfortunately, so we need to construct our own: + const legend_labels: Dict = {}; + unique_vals.forEach((v, idx) => { + if (!selectedLegendItems || selectedLegendItems.indexOf(v) > -1) + legend_labels[v] = group_colors[idx % group_colors.length]; + else legend_labels[v] = unselected_line_color; + }); + const onClickLegendItem = (label: string) => { + if ( + selectedLegendItems && + selectedLegendItems.length === 1 && + selectedLegendItems[0] === label + ) + setSelectedLegendItems(null); // Clicking twice on a legend item deselects it and displays all + else setSelectedLegendItems([label]); + }; + plot_legend = ( + + ); + + // Tried to support Plotly hover events here, but looks like + // currently there are unsupported for parcoords: https://github.com/plotly/plotly.js/issues/3012 + // onHover = (e) => { + // console.log(e.curveNumber); + // // const curveIdx = e.curveNumber; + // // if (curveIdx < response_txts.length) { + // // if (!selectedLegendItems || selectedLegendItems.indexOf(unique_vals[spec_colors[curveIdx]]) > -1) + // // console.log(response_txts[curveIdx]); + // // } + // }; + } else { + setSelectedLegendItems(null); + const error_text = + "Plotting evaluations with more than one metric and more than one prompt parameter is currently unsupported."; + setPlaceholderText( +
+              <p>
+ {error_text} +
+              </p>
, + ); + console.error(error_text); + } + } else { + // A single metric --use plots like grouped box-and-whiskers, 3d scatterplot + if (varnames.length === 0) { + // No variables means they used a single prompt (no template) to generate responses + // (Users are likely evaluating differences in responses between LLMs) + if (typeof_eval_res === "Boolean") plot_accuracy(get_llm, "llm"); + else plot_simple_boxplot(get_llm, "llm"); + } else if (varnames.length === 1) { + // 1 var; numeric eval + if (llm_names.length === 1) { + if (typeof_eval_res === "Boolean") + // Accuracy plot per value of the selected variable: + plot_accuracy((r) => get_var_and_trim(r, varnames[0]), "var"); + else { + // Simple box plot, as there is only a single LLM in the response + plot_simple_boxplot( + (r) => get_var_and_trim(r, varnames[0]), + "var", + ); + } + } else { + // There are multiple LLMs in the response; do a grouped box plot by LLM. + // Note that 'name' is now the LLM, and 'x' stores the value of the var: + plot_grouped_boxplot((r) => get_var_and_trim(r, varnames[0])); + } + } else if (varnames.length === 2) { + // Input is 2 vars; numeric eval + // Display a 3D scatterplot with 2 dimensions: + + const names_0 = new Set( + responses.map((r) => get_var_and_trim(r, varnames[0])), + ); + const shortnames_0 = genUniqueShortnames(names_0); + const names_1 = new Set( + responses.map((r) => get_var_and_trim(r, varnames[1])), + ); + const shortnames_1 = genUniqueShortnames(names_1); + + if (llm_names.length === 1) { + spec = { type: "scatter3d", - x: resps + x: responses .map((r) => get_var(r, varnames[0], true)) .map((s) => shortnames_0[s]), - y: resps + y: responses .map((r) => get_var(r, varnames[1], true)) .map((s) => shortnames_1[s]), - z: resps.map( + z: responses.map( (r) => get_items(r.eval_res).reduce( (acc: number, val) => @@ -974,20 +1114,46 @@ export const VisView = forwardRef( ), // calculates mean mode: "markers", marker: { - color: getColorForLLMAndSetIfNotFound(llm), + color: getColorForLLMAndSetIfNotFound(llm_names[0]), }, - name: llm, + }; + } else { + spec = []; + llm_names.forEach((llm) => { + const resps = responses.filter((r) => get_llm(r) === llm); + spec.push({ + type: "scatter3d", + x: resps + .map((r) => get_var(r, varnames[0], true)) + .map((s) => shortnames_0[s]), + y: resps + .map((r) => get_var(r, varnames[1], true)) + .map((s) => shortnames_1[s]), + z: resps.map( + (r) => + get_items(r.eval_res).reduce( + (acc: number, val) => + acc + (typeof val === "number" ? val : 0), + 0, + ) / (r.eval_res?.items.length ?? 1), + ), // calculates mean + mode: "markers", + marker: { + color: getColorForLLMAndSetIfNotFound(llm), + }, + name: llm, + }); }); - }); + } } } - } - if (!Array.isArray(spec)) spec = [spec]; + if (!Array.isArray(spec)) spec = [spec]; - setPlotLegend(plot_legend); - setPlotlySpec(spec as Dict[]); - setPlotlyLayout(layout); + setPlotLegend(plot_legend); + setPlotlySpec(spec as Dict[]); + setPlotlyLayout(layout); + }); // if (plotDivRef && plotDivRef.current) { // plotDivRef.current.style.width = '300px'; @@ -1067,7 +1233,6 @@ export const VisView = forwardRef( size="xs" value={"score"} miw="80px" - disabled /> {availableLLMGroups && availableLLMGroups.length > 1 ? ( @@ -1225,7 +1390,7 @@ const VisNode: React.FC = ({ data, id }) => { status={status} icon={"📊"} /> - + setStatus(isReplotting ? Status.LOADING : Status.NONE)} />
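
A minimal standalone sketch (not part of the patch) of the percentile-bootstrap interval introduced above, handy for sanity-checking bootstrapCI outside the React tree. It assumes only that simple-statistics is installed; the file name and the synthetic pass/fail `scores` array are illustrative stand-ins, not code from the repository.

// check-bootstrap-ci.ts
import { mean, quantile, sampleWithReplacement } from "simple-statistics";

// Percentile bootstrap, mirroring the patch's bootstrapCI: resample with
// replacement, recompute the statistic, and read the CI off the quantiles.
const bootstrapCI = (
  values: number[],
  numSamples = 1000,
  alpha = 0.05,
  overFunc: (ns: number[]) => number = mean,
) => {
  const stats: number[] = [];
  for (let i = 0; i < numSamples; i++)
    stats.push(overFunc(sampleWithReplacement(values, values.length, Math.random)));
  return {
    ciMean: mean(stats),
    lowerBound: quantile(stats, alpha / 2),
    upperBound: quantile(stats, 1 - alpha / 2),
  };
};

// Illustrative data: 60 pass/fail evaluation scores at roughly 70% accuracy.
const scores = Array.from({ length: 60 }, () => (Math.random() < 0.7 ? 1 : 0));
const { ciMean, lowerBound, upperBound } = bootstrapCI(scores);
console.log(
  `accuracy ~ ${(ciMean * 100).toFixed(1)}%, ` +
    `95% CI [${(lowerBound * 100).toFixed(1)}%, ${(upperBound * 100).toFixed(1)}%]`,
);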