mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 16:26:45 +00:00
Basic support for dict evaluation results
This commit is contained in:
parent
f7fb238a6d
commit
f0c506d242
@ -19,15 +19,18 @@ const vars_to_str = (vars) => {
|
||||
});
|
||||
return pairs;
|
||||
};
|
||||
/**
 * Groups response objects by the key that `keyFunc` computes for each one.
 *
 * @param responses Array of response objects to group.
 * @param keyFunc   Function mapping a response object to its group key.
 *                  Returning null (or undefined) means "this response has
 *                  no value for the requested grouping".
 * @returns A two-element array `[responses_by_key, unspecified_group]`:
 *          - `responses_by_key` is a plain object mapping each key to the
 *            array of responses that produced it, in input order;
 *          - `unspecified_group` is the array of responses whose key was
 *            null/undefined, in input order.
 */
const groupResponsesBy = (responses, keyFunc) => {
  const responses_by_key = {};
  const unspecified_group = [];
  responses.forEach(item => {
    const key = keyFunc(item);

    // Responses without a key are collected as real array elements, so that
    // callers can test `unspecified_group.length` to detect leftovers.
    if (key === null || key === undefined) {
      unspecified_group.push(item);
      return;
    }

    // Own-property check (rather than `in`) so that keys such as 'toString'
    // or 'constructor' are not mistaken for already-existing groups.
    if (Object.prototype.hasOwnProperty.call(responses_by_key, key))
      responses_by_key[key].push(item);
    else
      responses_by_key[key] = [item];
  });
  return [responses_by_key, unspecified_group];
};

/**
 * Legacy entry point: buckets responses by their `llm` property and returns
 * only the key -> responses object. The `llm` value is stringified so every
 * response lands in some bucket, as plain property indexing would do.
 */
const bucketResponsesByLLM = (responses) =>
  groupResponsesBy(responses, (r => String(r.llm)))[0];
|
||||
|
||||
const InspectorNode = ({ data, id }) => {
|
||||
@ -86,34 +89,66 @@ const InspectorNode = ({ data, id }) => {
|
||||
// Now we need to perform groupings by each var in the selected vars list,
|
||||
// nesting the groupings (preferably with custom divs) and sorting within
|
||||
// each group by value of that group's var (so all same values are clumped together).
|
||||
// :: For instance, for varnames = ['LLM', '$var1', '$var2'] we should get back
|
||||
// :: nested divs first grouped by LLM (first level), then by var1, then var2 (deepest level).
|
||||
/**
|
||||
const groupBy = (resps, varnames) => {
|
||||
if (varnames.length === 0) return [];
|
||||
const groupByVars = (resps, varnames, eatenvars) => {
|
||||
if (resps.length === 0) return [];
|
||||
if (varnames.length === 0) {
|
||||
// Base case. Display n response(s) to each single prompt, back-to-back:
|
||||
return resps.map((res_obj, res_idx) => {
|
||||
// Spans for actual individual response texts
|
||||
const ps = res_obj.responses.map((r, idx) =>
|
||||
(<pre className="small-response" key={idx}>{r}</pre>)
|
||||
);
|
||||
|
||||
const groupName = varnames[0];
|
||||
const groupedResponses = groupResponsesByVar(resps, groupName);
|
||||
const groupedResponseDivs = groupedResponses.map(g => groupBy(g, varnames.slice(1)));
|
||||
// At the deepest level, there may still be some vars left over. We want to display these
|
||||
// as tags, too, so we need to display only the ones that weren't 'eaten' during the recursive call:
|
||||
// (e.g., the vars that weren't part of the initial 'varnames' list that form the groupings)
|
||||
const vars = vars_to_str(res_obj.vars.filter(v => !eatenvars.includes(v)));
|
||||
const var_tags = vars.map((v) =>
|
||||
(<Badge key={v} color="blue" size="xs">{v}</Badge>)
|
||||
);
|
||||
return (
|
||||
<div key={"r"+res_idx} className="response-box" style={{ backgroundColor: colorForLLM(res_obj.llm) }}>
|
||||
{var_tags}
|
||||
{ps}
|
||||
</div>
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<div key={groupName} className="response-group">
|
||||
<span>{groupName}</span>
|
||||
{groupedResponseDivs}
|
||||
// Bucket responses by the first var in the list, where
|
||||
// we also bucket any 'leftover' responses that didn't have the requested variable (a kind of 'soft fail')
|
||||
const group_name = varnames[0];
|
||||
const [grouped_resps, leftover_resps] = (group_name === 'LLM')
|
||||
? groupResponsesBy(resps, (r => r.llm))
|
||||
: groupResponsesBy(resps, (r => ((group_name in r.vars) ? r.vars[group_name] : null)));
|
||||
// Now produce nested divs corresponding to the groups
|
||||
const remaining_vars = varnames.slice(1);
|
||||
const updated_eatenvars = eatenvars.concat([group_name]);
|
||||
const grouped_resps_divs = grouped_resps.map(g => groupByVars(g, remaining_vars, updated_eatenvars));
|
||||
const leftover_resps_divs = leftover_resps.length > 0 ? groupByVars(leftover_resps, remaining_vars, updated_eatenvars) : [];
|
||||
|
||||
return (<>
|
||||
<div key={group_name} className="response-group">
|
||||
<h1>{group_name}</h1>
|
||||
{grouped_resps_divs}
|
||||
</div>
|
||||
);
|
||||
{leftover_resps_divs.length === 0 ? (<></>) : (
|
||||
<div key={'__unspecified_group'} className="response-group">
|
||||
{leftover_resps_divs}
|
||||
</div>
|
||||
)}
|
||||
</>);
|
||||
};
|
||||
|
||||
// Group by LLM
|
||||
if (selected_vars.includes('LLM')) {
|
||||
// ...
|
||||
|
||||
// Group without LLM
|
||||
} else {
|
||||
// ..
|
||||
}
|
||||
*/
|
||||
// Produce DIV elements grouped by selected vars
|
||||
groupByVars(responses, selected_vars, []);
|
||||
**/
|
||||
|
||||
// Bucket responses by LLM:
|
||||
const responses_by_llm = bucketResponsesByLLM(json.responses);
|
||||
const responses_by_llm = groupResponsesBy(responses, (r => r.llm));
|
||||
|
||||
const colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1'];
|
||||
setResponses(Object.keys(responses_by_llm).map((llm, llm_idx) => {
|
||||
|
@ -61,32 +61,57 @@ const VisNode = ({ data, id }) => {
|
||||
});
|
||||
const llm_names = Object.keys(responses_by_llm);
|
||||
|
||||
// Get the type of evaluation results, if present
|
||||
// (This is assumed to be consistent across response batches)
|
||||
const typeof_eval_res = 'dtype' in responses[0].eval_res ? responses[0].eval_res['dtype'] : 'Numeric';
|
||||
|
||||
let metric_ax_label = null;
|
||||
if (typeof_eval_res.includes('KeyValue')) {
|
||||
// Check if it's a single-item dict; in which case we can extract the key to name the axis:
|
||||
const keys = Object.keys(responses[0].eval_res.items[0]);
|
||||
if (keys.length === 1)
|
||||
metric_ax_label = keys[0];
|
||||
else
|
||||
throw Error('Dict metrics with more than one key are currently unsupported.')
|
||||
// TODO: When multiple metrics are present, and 1 var is selected (can be multiple LLMs as well),
|
||||
// default to Parallel Coordinates plot, with the 1 var values on the y-axis as colored groups, and metrics on x-axis.
|
||||
// For multiple LLMs, add a control drop-down selector to switch the LLM visualized in the plot.
|
||||
}
|
||||
|
||||
|
||||
const get_items = (eval_res_obj) => {
|
||||
if (typeof_eval_res.includes('KeyValue'))
|
||||
return eval_res_obj.items.map(item => item[metric_ax_label]);
|
||||
return eval_res_obj.items;
|
||||
};
|
||||
|
||||
// Create Plotly spec here
|
||||
const varnames = multiSelectValue;
|
||||
const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
|
||||
let spec = [];
|
||||
let layout = {
|
||||
width: 420, height: 300, title: '', margin: {
|
||||
l: 105, r: 0, b: 20, t: 20, pad: 0
|
||||
l: 105, r: 0, b: 36, t: 20, pad: 0
|
||||
}
|
||||
};
|
||||
|
||||
const plot_grouped_boxplot = (resp_to_x) => {
|
||||
// Get all possible values of the single variable response ('name' vals)
|
||||
const names = new Set(responses.map(resp_to_x));
|
||||
|
||||
llm_names.forEach((llm, idx) => {
|
||||
// Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
|
||||
const rs = responses_by_llm[llm];
|
||||
|
||||
// Get all possible values of the single variable response ('name' vals)
|
||||
const names = new Set(responses.map(resp_to_x));
|
||||
let x_items = [];
|
||||
let y_items = [];
|
||||
let text_items = [];
|
||||
for (const name of names) {
|
||||
rs.forEach(r => {
|
||||
if (resp_to_x(r) !== name) return;
|
||||
x_items = x_items.concat(r.eval_res.items).flat();
|
||||
x_items = x_items.concat(get_items(r.eval_res)).flat();
|
||||
text_items = text_items.concat(createHoverTexts(r.responses)).flat();
|
||||
y_items = y_items.concat(Array(r.eval_res.items.length).fill(truncStr(name, 12))).flat();
|
||||
y_items = y_items.concat(Array(get_items(r.eval_res).length).fill(truncStr(name, 12))).flat();
|
||||
});
|
||||
}
|
||||
|
||||
@ -101,21 +126,13 @@ const VisNode = ({ data, id }) => {
|
||||
hovertemplate: '%{text} <b><i>(%{x})</i></b>',
|
||||
orientation: 'h',
|
||||
});
|
||||
|
||||
// const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
|
||||
// spec.push({
|
||||
// type: 'box',
|
||||
// name: llm,
|
||||
// marker: {color: colors[idx % colors.length]},
|
||||
// x: rs.map(r => r.eval_res.items).flat(),
|
||||
// y: rs.map(r => Array(r.eval_res.items.length).fill(resp_to_x(r))).flat(),
|
||||
// boxpoints: 'all',
|
||||
// text: hover_texts,
|
||||
// hovertemplate: '%{text} <b><i>(%{x})</i></b>',
|
||||
// orientation: 'h',
|
||||
// });
|
||||
});
|
||||
layout.boxmode = 'group';
|
||||
|
||||
if (metric_ax_label)
|
||||
layout.xaxis = {
|
||||
title: { font: {size: 12}, text: metric_ax_label },
|
||||
};
|
||||
};
|
||||
|
||||
if (varnames.length === 0) {
|
||||
@ -134,7 +151,7 @@ const VisNode = ({ data, id }) => {
|
||||
let text_items = [];
|
||||
responses.forEach(r => {
|
||||
if (r.vars[varnames[0]].trim() !== name) return;
|
||||
x_items = x_items.concat(r.eval_res.items);
|
||||
x_items = x_items.concat(get_items(r.eval_res));
|
||||
text_items = text_items.concat(createHoverTexts(r.responses));
|
||||
});
|
||||
spec.push(
|
||||
@ -142,6 +159,11 @@ const VisNode = ({ data, id }) => {
|
||||
);
|
||||
}
|
||||
layout.hovermode = 'closest';
|
||||
|
||||
if (metric_ax_label)
|
||||
layout.xaxis = {
|
||||
title: { font: {size: 12}, text: metric_ax_label },
|
||||
};
|
||||
} else {
|
||||
// There are multiple LLMs in the response; do a grouped box plot by LLM.
|
||||
// Note that 'name' is now the LLM, and 'x' stores the value of the var:
|
||||
@ -155,7 +177,7 @@ const VisNode = ({ data, id }) => {
|
||||
type: 'scatter3d',
|
||||
x: responses.map(r => r.vars[varnames[0]]).map(s => truncStr(s, 12)),
|
||||
y: responses.map(r => r.vars[varnames[1]]).map(s => truncStr(s, 12)),
|
||||
z: responses.map(r => r.eval_res.mean),
|
||||
z: responses.map(r => get_items(r.eval_res).reduce((acc, val) => (acc + val), 0) / r.eval_res.items.length), // calculates mean
|
||||
mode: 'markers',
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json, os, asyncio, sys, argparse, threading, traceback
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from statistics import mean, median, stdev
|
||||
from flask import Flask, request, jsonify, render_template, send_from_directory
|
||||
from flask_cors import CORS
|
||||
@ -66,35 +67,119 @@ def load_cache_json(filepath: str) -> dict:
|
||||
responses = json.load(f)
|
||||
return responses
|
||||
|
||||
class MetricType(Enum):
    """Kinds of values that evaluation results can hold.

    The member *name* (e.g. 'Numeric', 'KeyValue_Numeric') is what gets
    stored alongside evaluation results, so renaming a member changes the
    stored data. All dict-based kinds share the 'KeyValue' name prefix.
    """
    KeyValue = 0              # dicts, value types not (yet) classified
    KeyValue_Numeric = 1      # dicts whose values are all int/float
    KeyValue_Categorical = 2  # dicts whose values are all str/bool
    KeyValue_Mixed = 3        # dicts with any other combination of value types
    Numeric = 4               # int/float values only
    Categorical = 5           # str/bool values only
    Mixed = 6                 # mixture of int, float, bool and str values
    Unknown = 7               # contains types other than the basic ones above
    Empty = 8                 # no values at all
|
||||
|
||||
def typeof_set(types: set) -> MetricType:
    """Classify a set of Python types into a MetricType.

    Defined at module level because code outside check_typeof_vals also
    calls it by this name.

    Returns:
        Empty for an empty set; KeyValue when the only type is dict;
        Numeric when every type is int or float; Categorical when every
        type is str or bool; Mixed when every type is one of int, float,
        bool, str; Unknown otherwise.
    """
    if len(types) == 0:
        return MetricType.Empty
    if len(types) == 1 and next(iter(types)) == dict:
        return MetricType.KeyValue
    elif all(t in (int, float) for t in types):
        # Numeric metrics only
        return MetricType.Numeric
    elif all(t in (str, bool) for t in types):
        # Categorical metrics only ('bool' is True/False, counts as categorical)
        return MetricType.Categorical
    elif all(t in (int, float, bool, str) for t in types):
        # Mix of numeric and categorical types
        return MetricType.Mixed
    else:
        # Mix of types beyond basic ones
        return MetricType.Unknown


def typeof_dict_vals(d: dict) -> MetricType:
    """Return the KeyValue_* MetricType describing the values of dict `d`.

    Any value-type combination that is neither purely numeric nor purely
    categorical (including an empty dict) is reported as KeyValue_Mixed.
    """
    dict_val_type = typeof_set(set(type(v) for v in d.values()))
    if dict_val_type == MetricType.Numeric:
        return MetricType.KeyValue_Numeric
    elif dict_val_type == MetricType.Categorical:
        return MetricType.KeyValue_Categorical
    else:
        return MetricType.KeyValue_Mixed


def check_typeof_vals(arr: list) -> MetricType:
    """Check the types of all values in `arr` and return their MetricType.

    Args:
        arr: list of evaluation results (numbers, strings, bools or dicts).

    Returns:
        MetricType.Empty for an empty list. For a list of dicts, the more
        specific KeyValue_Numeric / KeyValue_Categorical / KeyValue_Mixed
        type. Otherwise the result of typeof_set over the element types.

    Raises:
        ValueError: if `arr` holds dicts whose key sets differ, or whose
            value types classify differently from one dict to the next.
    """
    if len(arr) == 0:
        return MetricType.Empty

    val_type = typeof_set(set(type(v) for v in arr))
    if val_type != MetricType.KeyValue:
        return val_type

    # This is a 'KeyValue' pair type. We need to find the more specific type
    # of the values in the dicts. First, check that all dicts have the exact
    # same keys (comparing each against the first is enough, since set
    # equality is transitive).
    expected_keys = set(arr[0].keys())
    for d in arr[1:]:
        if set(d.keys()) != expected_keys:
            raise ValueError('The keys and size of dicts for evaluation results must be consistent across evaluations.')

    # Then, check the consistency of the type of dict values:
    first_dict_val_type = typeof_dict_vals(arr[0])
    for d in arr[1:]:
        if first_dict_val_type != typeof_dict_vals(d):
            raise ValueError('Types of values in dicts for evaluation results must be consistent across responses.')

    # All checks passed; return the more specific KeyValue type.
    return first_dict_val_type
|
||||
|
||||
def run_over_responses(eval_func, responses: list, scope: str) -> list:
    """Run an evaluator over each response object and attach the results.

    Each element of `responses` is a dict with (at least) the keys
    'responses', 'prompt', 'vars' and 'llm'. The evaluation outcome is
    written in place to `resp_obj['eval_res']`.

    Args:
        eval_func: the evaluator. With scope == 'response' it is called once
            per response text with a ResponseInfo; with any other scope it is
            called once with the whole list `resp_obj['responses']`.
        responses: list of response objects; mutated in place.
        scope: 'response' to evaluate every response text individually;
            anything else evaluates each response batch as a whole.

    Returns:
        The same `responses` list. Each 'eval_res' always holds 'items' (the
        raw evaluation results) and 'dtype' (the MetricType name). Numeric
        results additionally get 'mean', 'median', 'stdev' and 'range'.

    Raises:
        TypeError: if the evaluation results are empty or contain types that
            cannot be classified.
        ValueError: propagated from check_typeof_vals when dict results are
            inconsistent with one another.
    """
    unsupported_msg = 'Unsupported types found in evaluation results. Only supported types for metrics are: int, float, bool, str.'

    for resp_obj in responses:
        res = resp_obj['responses']
        if scope == 'response':
            # Run evaluator func over every individual response text
            evals = [
                eval_func(
                    ResponseInfo(
                        text=r,
                        prompt=resp_obj['prompt'],
                        var=resp_obj['vars'],
                        llm=resp_obj['llm'])
                ) for r in res
            ]

            # Check the type of evaluation results
            # NOTE: We assume this is consistent across all evaluations, but it may not be.
            eval_res_type = check_typeof_vals(evals)

            if eval_res_type in (MetricType.Unknown, MetricType.Empty):
                raise TypeError(unsupported_msg)

            if eval_res_type == MetricType.Numeric:
                # Store items with summary of mean, median, etc
                resp_obj['eval_res'] = {
                    'mean': mean(evals),
                    'median': median(evals),
                    'stdev': stdev(evals) if len(evals) > 1 else 0,
                    'range': (min(evals), max(evals)),
                    'items': evals,
                    'dtype': eval_res_type.name,
                }
            else:
                # Categorical, KeyValue, etc, we just store the items:
                resp_obj['eval_res'] = {
                    'items': evals,
                    'dtype': eval_res_type.name,
                }
        else:
            # Run evaluator func over the entire response batch
            ev = eval_func(res)

            # Classify the single result through the public checker, so dict
            # results get their specific KeyValue_* type.
            ev_type = check_typeof_vals([ev])
            if ev_type in (MetricType.Unknown, MetricType.Empty):
                raise TypeError(unsupported_msg)

            if ev_type == MetricType.Numeric:
                resp_obj['eval_res'] = {
                    'mean': ev,
                    'median': ev,
                    'stdev': 0,
                    'range': (ev, ev),
                    'items': [ev],
                    'dtype': ev_type.name,
                    'type': ev_type.name,  # kept for readers of the older key
                }
            else:
                resp_obj['eval_res'] = {
                    'items': [ev],
                    'dtype': ev_type.name,
                    'type': ev_type.name,  # kept for readers of the older key
                }
    return responses
|
||||
|
||||
def reduce_responses(responses: list, vars: list) -> list:
|
||||
|
Loading…
x
Reference in New Issue
Block a user