From 548e84dba098cde6fb7bdc25ea6970400afa0538 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Mon, 24 Feb 2025 17:15:39 -0500 Subject: [PATCH] Merge toolbar into Mantine React Table toolbar, to save space --- .../react-server/src/LLMResponseInspector.tsx | 137 +++++++++++------- chainforge/react-server/src/PromptNode.tsx | 1 - chainforge/react-server/src/ResponseBoxes.tsx | 50 +++++-- 3 files changed, 117 insertions(+), 71 deletions(-) diff --git a/chainforge/react-server/src/LLMResponseInspector.tsx b/chainforge/react-server/src/LLMResponseInspector.tsx index 6097a22..a81e70b 100644 --- a/chainforge/react-server/src/LLMResponseInspector.tsx +++ b/chainforge/react-server/src/LLMResponseInspector.tsx @@ -24,6 +24,7 @@ import { Stack, ScrollArea, LoadingOverlay, + Button, } from "@mantine/core"; import { useToggle } from "@mantine/hooks"; import { @@ -41,6 +42,9 @@ import { type MRT_SortingState, type MRT_Virtualizer, MRT_Row, + MRT_ShowHideColumnsButton, + MRT_ToggleFiltersButton, + MRT_ToggleDensePaddingButton, } from "mantine-react-table"; import * as XLSX from "xlsx"; import useStore from "./store"; @@ -201,7 +205,7 @@ export const exportToExcel = ( const data = jsonResponses .map((res_obj, res_obj_idx) => { const llm = getLLMName(res_obj); - const prompt = res_obj.prompt; + const prompt = StringLookup.get(res_obj.prompt) ?? ""; const vars = res_obj.vars; const metavars = res_obj.metavars ?? {}; const ratings = { @@ -219,7 +223,7 @@ export const exportToExcel = ( // Add columns for vars Object.entries(vars).forEach(([varname, val]) => { - row[`Var: ${varname}`] = val; + row[`Var: ${varname}`] = StringLookup.get(val) ?? ""; }); // Add column(s) for human ratings, if present @@ -253,7 +257,7 @@ export const exportToExcel = ( // Add columns for metavars, if present Object.entries(metavars).forEach(([varname, val]) => { if (!cleanMetavarsFilterFunc(varname)) return; // skip llm group metavars - row[`Metavar: ${varname}`] = val; + row[`Metavar: ${varname}`] = StringLookup.get(val) ?? ""; }); return row; @@ -281,7 +285,9 @@ const LLMResponseInspector: React.FC = ({ const [receivedResponsesOnce, setReceivedResponsesOnce] = useState(false); // The type of view to use to display responses. Can be either hierarchy or table. - const [viewFormat, setViewFormat] = useState("hierarchy"); + const [viewFormat, setViewFormat] = useState( + wideFormat ? "table" : "hierarchy", + ); // The MultiSelect so people can dynamically set what vars they care about const [multiSelectVars, setMultiSelectVars] = useState< @@ -328,6 +334,30 @@ const LLMResponseInspector: React.FC = ({ columnResizeMode: "onEnd", enableStickyHeader: true, initialState: { density: "md", pagination: { pageSize: 30, pageIndex: 0 } }, + renderToolbarInternalActions: ({ table }) => ( + <> + {/* built-in buttons (must pass in table prop for them to work!) */} + + + + + ), + renderTopToolbarCustomActions: () => ( + + { + setTableColVar(event.currentTarget.value); + setUserSelectedTableCol(true); + }} + data={multiSelectVars} + label="Select main column variable:" + size={sz} + w={wideFormat ? "50%" : "100%"} + /> + {searchBar} + + ), }); // The var name to use for columns in the table view @@ -663,31 +693,25 @@ const LLMResponseInspector: React.FC = ({ const val = resp_objs[0].metavars[v]; return val !== undefined ? val : "(unspecified)"; }); - let eval_cols_vals: string[] = []; + let eval_cols_vals: [string | JSX.Element, string][][] = []; if (eval_res_cols && eval_res_cols.length > 0) { // We can assume that there's only one response object, since to // if eval_res_cols is set, there must be only one LLM. eval_cols_vals = eval_res_cols.map((metric_name, metric_idx) => { const items = resp_objs[0].eval_res?.items; - if (!items) return "(no result)"; - return items - .map((item) => { - if (item === undefined) return "(undefined)"; - if ( - typeof item !== "object" && - metric_idx === 0 && - metric_name === "Score" - ) - return getEvalResultStr(item, true, true) as string; - else if (typeof item === "object" && metric_name in item) - return getEvalResultStr( - item[metric_name], - true, - true, - ) as string; - else return "(unspecified)"; - }) - .join("\n"); // treat n>1 resps per prompt as multi-line results in the column + if (!items) return [["(no result)", "(no result)"]]; + return items.map((item) => { + if (item === undefined) return ["(undefined)", "(undefined)"]; + if ( + typeof item !== "object" && + metric_idx === 0 && + metric_name === "Score" + ) + return getEvalResultStr(item, true); + else if (typeof item === "object" && metric_name in item) + return getEvalResultStr(item[metric_name], true); + else return ["(unspecified)", "(unspecified)"]; + }); // treat n>1 resps per prompt as multi-line results in the column }); } @@ -719,7 +743,11 @@ const LLMResponseInspector: React.FC = ({ }); const row: Dict< - string | undefined | LLMResponse[] | LLMResponseData[] + | string + | undefined + | LLMResponse[] + | LLMResponseData[] + | { type: "eval"; data: (string | JSX.Element)[][] } > = {}; let vals_arr_start_idx = 0; var_cols_vals.forEach((v, i) => { @@ -738,7 +766,10 @@ const LLMResponseInspector: React.FC = ({ }); vals_arr_start_idx += sel_var_cols.length; eval_cols_vals.forEach((v, i) => { - row[`c${i + vals_arr_start_idx}`] = StringLookup.get(v); + row[`c${i + vals_arr_start_idx}`] = { + type: "eval", + data: v, + }; }); return row; @@ -780,12 +811,12 @@ const LLMResponseInspector: React.FC = ({ Object.entries(row).forEach(([cname, val]) => { if (val === undefined || cname[0] === "o") return; if (!(cname in colAvgNumChars)) colAvgNumChars[cname] = 0; - const hasLLMResps = !(typeof val === "string"); + const hasLLMResps = + !(typeof val === "string") && Array.isArray(val); if (hasLLMResps && !colHasLLMResponses.has(cname)) colHasLLMResponses.add(cname); // Count the number of chars in the total text that will be displaced in this cell, // and add it to the count: - const numChars = hasLLMResps ? val .map((r) => @@ -794,7 +825,11 @@ const LLMResponseInspector: React.FC = ({ .join(""), ) .join("").length - : val.length; + : typeof val === "string" + ? val.length + : (val.data as (string | JSX.Element)[][]) + .map((e) => e[1]) + .join("").length; colAvgNumChars[cname] += (numChars * 1.0) / numRows; // we apply the averaging here for speed }); }); @@ -805,7 +840,11 @@ const LLMResponseInspector: React.FC = ({ // Get the text for this row. Used when filtering or sorting. const val = row[`c${i}`]; if (typeof val === "string" || val === undefined) return val; - else + else if ("type" in val && val.type === "eval") { + return (val.data as (string | JSX.Element)[][]) + .map((e) => e[1]) + .join("\n"); + } else return (val as LLMResponse[]) .flatMap((r) => r.responses) .map(llmResponseDataToString) @@ -820,7 +859,15 @@ const LLMResponseInspector: React.FC = ({ Cell: ({ cell, row }: { cell: MRT_Cell; row: any }) => { const val = row.original[`c${i}`]; if (typeof val === "string") return val; - else + else if ("type" in val && val.type === "eval") { + return ( + + {(val.data as [string | JSX.Element, string][]).map( + (e) => e[0], + )} + + ); + } else return ( {generateResponseBoxes( @@ -839,13 +886,14 @@ const LLMResponseInspector: React.FC = ({ ), mantineTableBodyCellProps: (() => { + const fz = wideFormat ? {} : { fontSize: 12 }; // text font size when in drawer should be smaller if (colHasLLMResponses.has(`c${i}`)) return { style: { padding: "4px 2px 0px 2px", verticalAlign: "top" }, // Adjusts overall padding & spacing }; else return { - style: { lineHeight: 1.2 }, + style: { lineHeight: 1.2, ...fz }, }; })(), })) as MRT_ColumnDef[]; @@ -1069,7 +1117,7 @@ const LLMResponseInspector: React.FC = ({ } autoComplete="off" size={sz} - placeholder={"Search keywords"} + placeholder={"Search responses"} w="100%" value={searchValue} onChange={handleSearchValueChange} @@ -1176,28 +1224,7 @@ const LLMResponseInspector: React.FC = ({ - - { - setTableColVar(event.currentTarget.value); - setUserSelectedTableCol(true); - }} - data={multiSelectVars} - label="Select main column variable:" - size={sz} - w={wideFormat ? "50%" : "100%"} - /> - {searchBar} - setOnlyShowScores(e.currentTarget.checked)} - mb="md" - size={sz} - display={showEvalScoreOptions ? "inherit" : "none"} - /> - + <> diff --git a/chainforge/react-server/src/PromptNode.tsx b/chainforge/react-server/src/PromptNode.tsx index 6915614..128f1b7 100644 --- a/chainforge/react-server/src/PromptNode.tsx +++ b/chainforge/react-server/src/PromptNode.tsx @@ -990,7 +990,6 @@ Soft failing by replacing undefined with empty strings.`, ? StringLookup.get(resp_obj.llm) ?? "(LLM lookup failed)" : resp_obj.llm.name; - console.log(o); return o; }), ) diff --git a/chainforge/react-server/src/ResponseBoxes.tsx b/chainforge/react-server/src/ResponseBoxes.tsx index 904d3ce..bf1f339 100644 --- a/chainforge/react-server/src/ResponseBoxes.tsx +++ b/chainforge/react-server/src/ResponseBoxes.tsx @@ -17,31 +17,50 @@ const ResponseRatingToolbar = lazy(() => import("./ResponseRatingToolbar")); /* HELPER FUNCTIONS */ const SUCCESS_EVAL_SCORES = new Set(["true", "yes"]); const FAILURE_EVAL_SCORES = new Set(["false", "no"]); +/** + * Returns an array of JSX elements, and the searchable text underpinning them, + * that represents a concrete version of the Evaluation Scores passed in. + * @param eval_item The evaluation result to visualize. + * @param hide_prefix Whether to hide 'score: ' or '{key}: ' prefixes when printing. + * @param onlyString Whether to only return string values. + * @returns An array [JSX.Element, string] where the latter is a string representation of the eval score, to enable search + */ export const getEvalResultStr = ( eval_item: EvaluationScore, hide_prefix: boolean, onlyString?: boolean, -): JSX.Element | string => { +): [JSX.Element | string, string] => { if (Array.isArray(eval_item)) { - return (hide_prefix ? "" : "scores: ") + eval_item.join(", "); + const items_str = (hide_prefix ? "" : "scores: ") + eval_item.join(", "); + return [items_str, items_str]; } else if (typeof eval_item === "object") { - const strs: (JSX.Element | string)[] = Object.keys(eval_item).map( + const strs: [JSX.Element | string, string][] = Object.keys(eval_item).map( (key, j) => { let val = eval_item[key]; if (typeof val === "number" && val.toString().indexOf(".") > -1) val = val.toFixed(4); // truncate floats to 4 decimal places - if (onlyString) return `${key}: ${getEvalResultStr(val, true, true)}`; + const [recurs_res, recurs_str] = getEvalResultStr(val, true); + if (onlyString) return [`${key}: ${recurs_str}`, recurs_str]; else - return ( + return [
{key}: - {getEvalResultStr(val, true)} -
- ); + {recurs_res} + , + recurs_str, + ]; }, ); - if (onlyString) return strs.join("\n"); - else return {strs}; + const joined_strs = strs.map((s) => s[1]).join("\n"); + if (onlyString) { + return [joined_strs, joined_strs]; + } else + return [ + + {strs} + , + joined_strs, + ]; } else { const eval_str = eval_item.toString().trim().toLowerCase(); const color = SUCCESS_EVAL_SCORES.has(eval_str) @@ -49,14 +68,15 @@ export const getEvalResultStr = ( : FAILURE_EVAL_SCORES.has(eval_str) ? "red" : "black"; - if (onlyString) return `score: ${eval_str}`; + if (onlyString) return [eval_str, eval_str]; else - return ( + return [ <> {!hide_prefix && {"score: "}} {eval_str} - - ); + , + eval_str, + ]; } }; @@ -271,7 +291,7 @@ export const genResponseTextsDisplay = ( )} {eval_res_items ? (

- {getEvalResultStr(resp_str_to_eval_res[r], true)} + {getEvalResultStr(resp_str_to_eval_res[r], true)[0]}

) : ( <>