From f0c506d242368ea2e10f88ef0767d2faed303f62 Mon Sep 17 00:00:00 2001
From: Ian Arawjo
Date: Sun, 14 May 2023 10:07:16 -0400
Subject: [PATCH] Basic support for dict evaluation results
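
Evaluator functions may now return a dict of metrics rather than a single
number. The backend infers a dtype for the results (Numeric, Categorical,
or one of the KeyValue_* types) and stores it alongside the items in
'eval_res'; the Vis node uses the dict's (single) key to label the metric
axis. As a rough sketch, an evaluator like the following should now work
(the metric name 'f1' and the scoring logic are purely illustrative):

    def evaluate(response):
        # Return a single-key dict; the key becomes the metric axis label in the Vis node.
        score = float(len(response.text) > 0)  # stand-in for a real metric
        return {'f1': score}

Dicts with more than one key are detected but not yet visualized; the Vis
node currently raises an error for them (see the TODO in VisNode.js).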
---
 chain-forge/src/InspectorNode.js |  87 ++++++++++++++------
 chain-forge/src/VisNode.js       |  62 +++++++++-----
 python-backend/flask_app.py      | 133 +++++++++++++++++++++++++------
 3 files changed, 212 insertions(+), 70 deletions(-)

diff --git a/chain-forge/src/InspectorNode.js b/chain-forge/src/InspectorNode.js
index 32ff9a5..910d100 100644
--- a/chain-forge/src/InspectorNode.js
+++ b/chain-forge/src/InspectorNode.js
@@ -19,15 +19,18 @@ const vars_to_str = (vars) => {
   });
   return pairs;
 };
-const bucketResponsesByLLM = (responses) => {
-  let responses_by_llm = {};
+// Groups responses by the key returned from 'keyFunc'. Responses for which
+// keyFunc returns null are collected separately, in 'unspecified_group':
+const groupResponsesBy = (responses, keyFunc) => {
+  let responses_by_key = {};
+  let unspecified_group = [];
   responses.forEach(item => {
-    if (item.llm in responses_by_llm)
-      responses_by_llm[item.llm].push(item);
+    const key = keyFunc(item);
+    if (key === null) {
+      // No key for this response; push it onto the 'unspecified' group directly:
+      unspecified_group.push(item);
+      return;
+    }
+    if (key in responses_by_key)
+      responses_by_key[key].push(item);
     else
-      responses_by_llm[item.llm] = [item];
+      responses_by_key[key] = [item];
   });
-  return responses_by_llm;
+  return [responses_by_key, unspecified_group];
 };
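+// (Usage sketch: groupResponsesBy(responses, (r) => r.llm)[0] gives responses bucketed
+// by LLM name. A keyFunc may return null for responses that lack the requested key;
+// those are returned separately, as the second element of the pair.)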
 
 const InspectorNode = ({ data, id }) => {
@@ -86,34 +89,66 @@ const InspectorNode = ({ data, id }) => {
     // Now we need to perform groupings by each var in the selected vars list,
     // nesting the groupings (preferably with custom divs) and sorting within
     // each group by value of that group's var (so all same values are clumped together).
+    // :: For instance, for varnames = ['LLM', '$var1', '$var2'] we should get back
+    // :: nested divs first grouped by LLM (first level), then by var1, then var2 (deepest level).
     /**
-    const groupBy = (resps, varnames) => {
-      if (varnames.length === 0) return [];
+    const groupByVars = (resps, varnames, eatenvars) => {
+      if (resps.length === 0) return [];
+      if (varnames.length === 0) {
+        // Base case. Display n response(s) to each single prompt, back-to-back:
+        return resps.map((res_obj, res_idx) => {
+          // Spans for actual individual response texts:
+          const ps = res_obj.responses.map((r, idx) =>
+            (<pre key={idx}>{r}</pre>)
+          );
-      const groupName = varnames[0];
-      const groupedResponses = groupResponsesByVar(resps, groupName);
-      const groupedResponseDivs = groupedResponses.map(g => groupBy(g, varnames.slice(1)));
+          // At the deepest level, there may still be some vars left over. We want to display these
+          // as tags, too, so we display only the ones that weren't 'eaten' during the recursive calls
+          // (i.e., the vars that weren't part of the initial 'varnames' list that forms the groupings):
+          const vars = vars_to_str(res_obj.vars.filter(v => !eatenvars.includes(v)));
+          const var_tags = vars.map((v) =>
+            (<span key={v}>{v}</span>)
+          );
+          return (
+            <div key={res_idx}>
+              {var_tags}
+              {ps}
+            </div>
+          );
+        });
+      }
 
-      return (
-        <div>
-          {groupName}
-          {groupedResponseDivs}
-        </div>
-      );
+      // Bucket responses by the first var in the list, where we also bucket any 'leftover'
+      // responses that didn't have the requested variable (a kind of 'soft fail'):
+      const group_name = varnames[0];
+      const [grouped_resps, leftover_resps] = (group_name === 'LLM')
+        ? groupResponsesBy(resps, (r => r.llm))
+        : groupResponsesBy(resps, (r => ((group_name in r.vars) ? r.vars[group_name] : null)));
+
+      // Now produce nested divs corresponding to the groups:
+      const remaining_vars = varnames.slice(1);
+      const updated_eatenvars = eatenvars.concat([group_name]);
+      const grouped_resps_divs = Object.values(grouped_resps).map(g => groupByVars(g, remaining_vars, updated_eatenvars));
+      const leftover_resps_divs = leftover_resps.length > 0 ? groupByVars(leftover_resps, remaining_vars, updated_eatenvars) : [];
+
+      return (<>
+        <div>
+          <h2>{group_name}</h2>
+          {grouped_resps_divs}
+        </div>
+        {leftover_resps_divs.length === 0 ? (<></>) : (
+          <div>
+            {leftover_resps_divs}
+          </div>
+        )}
+      </>);
     };
-    // Group by LLM
-    if (selected_vars.includes('LLM')) {
-      // ...
-
-    // Group without LLM
-    } else {
-      // ..
-    }
-    */
+    // Produce DIV elements grouped by selected vars:
+    groupByVars(responses, selected_vars, []);
+    **/
 
     // Bucket responses by LLM:
-    const responses_by_llm = bucketResponsesByLLM(json.responses);
+    const [responses_by_llm, _] = groupResponsesBy(responses, (r => r.llm));
     const colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1'];
     setResponses(Object.keys(responses_by_llm).map((llm, llm_idx) => {

diff --git a/chain-forge/src/VisNode.js b/chain-forge/src/VisNode.js
index 2d320cd..40dcfbb 100644
--- a/chain-forge/src/VisNode.js
+++ b/chain-forge/src/VisNode.js
@@ -61,32 +61,57 @@ const VisNode = ({ data, id }) => {
     });
     const llm_names = Object.keys(responses_by_llm);
 
+    // Get the type of evaluation results, if present
+    // (This is assumed to be consistent across response batches)
+    const typeof_eval_res = 'dtype' in responses[0].eval_res ? responses[0].eval_res['dtype'] : 'Numeric';
+
+    let metric_ax_label = null;
+    if (typeof_eval_res.includes('KeyValue')) {
+      // Check if it's a single-item dict, in which case we can extract the key to name the metric axis:
+      const keys = Object.keys(responses[0].eval_res.items[0]);
+      if (keys.length === 1)
+        metric_ax_label = keys[0];
+      else
+        throw new Error('Dict metrics with more than one key are currently unsupported.');
+      // TODO: When multiple metrics are present and one var is selected (there can be multiple LLMs as well),
+      // default to a parallel coordinates plot, with the var's values on the y-axis as colored groups and the metrics on the x-axis.
+      // For multiple LLMs, add a drop-down selector to switch which LLM is visualized in the plot.
+    }
+
+    // Extract the list of values to plot from an eval_res object.
+    // For KeyValue dtypes, pull out the values stored under the (single) metric key:
+    const get_items = (eval_res_obj) => {
+      if (typeof_eval_res.includes('KeyValue'))
+        return eval_res_obj.items.map(item => item[metric_ax_label]);
+      return eval_res_obj.items;
+    };
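+    // (For example, if eval_res were { items: [{f1: 0.7}, {f1: 0.9}], dtype: 'KeyValue_Numeric' },
+    // metric_ax_label would be 'f1' and get_items would return [0.7, 0.9]. The metric name 'f1'
+    // and the values are purely illustrative.)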
+
     // Create Plotly spec here
     const varnames = multiSelectValue;
     const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
     let spec = [];
     let layout = {
       width: 420, height: 300, title: '', margin: {
-        l: 105, r: 0, b: 20, t: 20, pad: 0
+        l: 105, r: 0, b: 36, t: 20, pad: 0
       }
     }
 
     const plot_grouped_boxplot = (resp_to_x) => {
+      // Get all possible values of the single variable response ('name' vals)
+      const names = new Set(responses.map(resp_to_x));
+
       llm_names.forEach((llm, idx) => {
         // Create HTML for hovering over a single datapoint. We must use <br> to specify line breaks.
         const rs = responses_by_llm[llm];
-        // Get all possible values of the single variable response ('name' vals)
-        const names = new Set(responses.map(resp_to_x));
 
         let x_items = [];
         let y_items = [];
         let text_items = [];
         for (const name of names) {
           rs.forEach(r => {
             if (resp_to_x(r) !== name) return;
-            x_items = x_items.concat(r.eval_res.items).flat();
+            x_items = x_items.concat(get_items(r.eval_res)).flat();
             text_items = text_items.concat(createHoverTexts(r.responses)).flat();
-            y_items = y_items.concat(Array(r.eval_res.items.length).fill(truncStr(name, 12))).flat();
+            y_items = y_items.concat(Array(get_items(r.eval_res).length).fill(truncStr(name, 12))).flat();
           });
         }
@@ -101,21 +126,13 @@ const VisNode = ({ data, id }) => {
           hovertemplate: '%{text} (%{x})',
           orientation: 'h',
         });
-
-        // const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
-        // spec.push({
-        //   type: 'box',
-        //   name: llm,
-        //   marker: {color: colors[idx % colors.length]},
-        //   x: rs.map(r => r.eval_res.items).flat(),
-        //   y: rs.map(r => Array(r.eval_res.items.length).fill(resp_to_x(r))).flat(),
-        //   boxpoints: 'all',
-        //   text: hover_texts,
-        //   hovertemplate: '%{text} (%{x})',
-        //   orientation: 'h',
-        // });
       });
       layout.boxmode = 'group';
+
+      if (metric_ax_label)
+        layout.xaxis = {
+          title: { font: {size: 12}, text: metric_ax_label },
+        };
     };
 
     if (varnames.length === 0) {
@@ -134,7 +151,7 @@ const VisNode = ({ data, id }) => {
         let text_items = [];
         responses.forEach(r => {
           if (r.vars[varnames[0]].trim() !== name) return;
-          x_items = x_items.concat(r.eval_res.items);
+          x_items = x_items.concat(get_items(r.eval_res));
           text_items = text_items.concat(createHoverTexts(r.responses));
         });
         spec.push(
         );
       }
       layout.hovermode = 'closest';
+
+      if (metric_ax_label)
+        layout.xaxis = {
+          title: { font: {size: 12}, text: metric_ax_label },
+        };
     } else {
       // There are multiple LLMs in the response; do a grouped box plot by LLM.
       // Note that 'name' is now the LLM, and 'x' stores the value of the var:
@@ -155,7 +177,7 @@ const VisNode = ({ data, id }) => {
         type: 'scatter3d',
         x: responses.map(r => r.vars[varnames[0]]).map(s => truncStr(s, 12)),
         y: responses.map(r => r.vars[varnames[1]]).map(s => truncStr(s, 12)),
-        z: responses.map(r => r.eval_res.mean),
+        z: responses.map(r => get_items(r.eval_res).reduce((acc, val) => (acc + val), 0) / r.eval_res.items.length), // calculates the mean
         mode: 'markers',
       }
     }

diff --git a/python-backend/flask_app.py b/python-backend/flask_app.py
index 346ac8b..bec8001 100644
--- a/python-backend/flask_app.py
+++ b/python-backend/flask_app.py
@@ -1,5 +1,6 @@
 import json, os, asyncio, sys, argparse, threading, traceback
 from dataclasses import dataclass
+from enum import Enum
 from statistics import mean, median, stdev
 from flask import Flask, request, jsonify, render_template, send_from_directory
 from flask_cors import CORS
@@ -66,35 +67,119 @@ def load_cache_json(filepath: str) -> dict:
         responses = json.load(f)
     return responses
 
+class MetricType(Enum):
+    KeyValue = 0
+    KeyValue_Numeric = 1
+    KeyValue_Categorical = 2
+    KeyValue_Mixed = 3
+    Numeric = 4
+    Categorical = 5
+    Mixed = 6
+    Unknown = 7
+    Empty = 8
+
+def check_typeof_vals(arr: list) -> MetricType:
+    if len(arr) == 0: return MetricType.Empty
+
+    def typeof_set(types: set) -> MetricType:
+        if len(types) == 0: return MetricType.Empty
+        if len(types) == 1 and next(iter(types)) == dict:
+            return MetricType.KeyValue
+        elif all((t in (int, float) for t in types)):
+            # Numeric metrics only
+            return MetricType.Numeric
+        elif all((t in (str, bool) for t in types)):
+            # Categorical metrics only ('bool' is True/False, which counts as categorical)
+            return MetricType.Categorical
+        elif all((t in (int, float, bool, str) for t in types)):
+            # Mix of numeric and categorical types
+            return MetricType.Mixed
+        else:
+            # Mix of types beyond the basic ones
+            return MetricType.Unknown
+
+    def typeof_dict_vals(d):
+        dict_val_type = typeof_set(set((type(v) for v in d.values())))
+        if dict_val_type == MetricType.Numeric:
+            return MetricType.KeyValue_Numeric
+        elif dict_val_type == MetricType.Categorical:
+            return MetricType.KeyValue_Categorical
+        else:
+            return MetricType.KeyValue_Mixed
+
+    # Check the type of all values in 'arr' and return the most specific MetricType:
+    val_type = typeof_set(set((type(v) for v in arr)))
+    if val_type == MetricType.KeyValue:
+        # This is a 'KeyValue' pair type. We need to find the more specific type of the values in the dicts.
+        # First, we check that all dicts have the exact same keys:
+        for i in range(len(arr)-1):
+            d, e = arr[i], arr[i+1]
+            if set(d.keys()) != set(e.keys()):
+                raise Exception('The keys and size of dicts for evaluation results must be consistent across evaluations.')
+
+        # Then, we check the consistency of the type of dict values:
+        first_dict_val_type = typeof_dict_vals(arr[0])
+        for d in arr[1:]:
+            if first_dict_val_type != typeof_dict_vals(d):
+                raise Exception('Types of values in dicts for evaluation results must be consistent across responses.')
+
+        # If we're here, all checks passed, and we return the more specific KeyValue type:
+        return first_dict_val_type
+    else:
+        return val_type
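+
+# For illustration, a few inputs and the MetricType the above infers (values made up):
+#   check_typeof_vals([0.5, 1, 0.25])              -> MetricType.Numeric
+#   check_typeof_vals([True, 'low quality'])       -> MetricType.Categorical
+#   check_typeof_vals([{'f1': 0.7}, {'f1': 0.9}])  -> MetricType.KeyValue_Numeric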
 
 def run_over_responses(eval_func, responses: list, scope: str) -> list:
     for resp_obj in responses:
         res = resp_obj['responses']
         if scope == 'response':
-            evals = [  # Run evaluator func over every individual response text
-                eval_func(
-                    ResponseInfo(
-                        text=r,
-                        prompt=resp_obj['prompt'],
-                        var=resp_obj['vars'],
-                        llm=resp_obj['llm'])
-                ) for r in res
-            ]
-            resp_obj['eval_res'] = {  # NOTE: assumes this is numeric data
-                'mean': mean(evals),
-                'median': median(evals),
-                'stdev': stdev(evals) if len(evals) > 1 else 0,
-                'range': (min(evals), max(evals)),
-                'items': evals,
-            }
-        else:  # operate over the entire response batch
+            # Run evaluator func over every individual response text
+            evals = [eval_func(
+                ResponseInfo(
+                    text=r,
+                    prompt=resp_obj['prompt'],
+                    var=resp_obj['vars'],
+                    llm=resp_obj['llm'])
+            ) for r in res]
+
+            # Check the type of the evaluation results
+            # NOTE: We assume this is consistent across all evaluations, but it may not be.
+            eval_res_type = check_typeof_vals(evals)
+
+            if eval_res_type == MetricType.Numeric:
+                # Store items with a summary of mean, median, etc:
+                resp_obj['eval_res'] = {
+                    'mean': mean(evals),
+                    'median': median(evals),
+                    'stdev': stdev(evals) if len(evals) > 1 else 0,
+                    'range': (min(evals), max(evals)),
+                    'items': evals,
+                    'dtype': eval_res_type.name,
+                }
+            elif eval_res_type in (MetricType.Unknown, MetricType.Empty):
+                raise Exception('Unsupported types found in evaluation results. The only supported metric types are: int, float, bool, str, and dicts thereof.')
+            else:
+                # For Categorical, KeyValue, etc. types, we just store the items:
+                resp_obj['eval_res'] = {
+                    'items': evals,
+                    'dtype': eval_res_type.name,
+                }
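+            # (e.g., for a dict-returning evaluator, eval_res might look like
+            #  {'items': [{'f1': 0.7}, {'f1': 0.9}], 'dtype': 'KeyValue_Numeric'};
+            #  the metric name and values here are made up.)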
+        else:
+            # Run evaluator func over the entire response batch
             ev = eval_func(res)
-            resp_obj['eval_res'] = {  # NOTE: assumes this is numeric data
-                'mean': ev,
-                'median': ev,
-                'stdev': 0,
-                'range': (ev, ev),
-                'items': [ev],
-            }
+            ev_type = check_typeof_vals([ev])
+            if ev_type == MetricType.Numeric:
+                resp_obj['eval_res'] = {
+                    'mean': ev,
+                    'median': ev,
+                    'stdev': 0,
+                    'range': (ev, ev),
+                    'items': [ev],
+                    'dtype': ev_type.name,
+                }
+            else:
+                resp_obj['eval_res'] = {
+                    'items': [ev],
+                    'dtype': ev_type.name,
+                }
     return responses
 
 def reduce_responses(responses: list, vars: list) -> list: