Ian Arawjo 2024-03-23 20:00:35 -04:00
parent 67432f9853
commit a6309ad444
6 changed files with 135 additions and 81 deletions

View File

@@ -11,6 +11,10 @@ import React, {
 import {
   DragDropContext,
   Draggable,
+  DraggableProvided,
+  DraggableRubric,
+  DraggableStateSnapshot,
+  DroppableProvided,
   OnDragEndResponder,
 } from "react-beautiful-dnd";
 import { Menu } from "@mantine/core";
@@ -205,7 +209,7 @@ export function LLMList({
     <DragDropContext onDragEnd={onDragEnd}>
       <StrictModeDroppable
         droppableId="llm-list-droppable"
-        renderClone={(provided, snapshot, rubric) => (
+        renderClone={(provided: DraggableProvided, snapshot: DraggableStateSnapshot, rubric: DraggableRubric) => (
           <LLMListItemClone
             provided={provided}
             snapshot={snapshot}
@@ -214,10 +218,10 @@ export function LLMList({
           />
         )}
       >
-        {(provided) => (
+        {(provided: DroppableProvided) => (
           <div {...provided.droppableProps} ref={provided.innerRef}>
             {items.map((item, index) => (
-              <Draggable key={item.key} draggableId={item.key} index={index}>
+              <Draggable key={item.key} draggableId={item.key ?? index.toString()} index={index}>
                 {(provided, snapshot) => (
                   <LLMListItem
                     provided={provided}
@@ -240,7 +244,27 @@ export function LLMList({
   );
 }
 
-export const LLMListContainer = forwardRef(function LLMListContainer(
+export interface LLMListContainerHandle {
+  resetLLMItemsProgress: () => void;
+  setZeroPercProgress: () => void;
+  updateProgress: (itemProcessorFunc: (llm: LLMSpec) => LLMSpec) => void;
+  ensureLLMItemsErrorProgress: (llm_keys_w_errors: string[]) => void;
+  getLLMListItemForKey: (key: string) => LLMSpec | undefined;
+  refreshLLMProviderList: () => void;
+}
+
+export interface LLMListContainerProps {
+  description: string;
+  modelSelectButtonText: string;
+  initLLMItems: LLMSpec[];
+  onSelectModel: (llm: LLMSpec, new_llms: LLMSpec[]) => void;
+  onItemsChange: (new_llms: LLMSpec[], old_llms: LLMSpec[]) => void;
+  hideTrashIcon: boolean;
+  bgColor: string;
+  selectModelAction?: "add" | "replace";
+}
+
+export const LLMListContainer = forwardRef<LLMListContainerHandle, LLMListContainerProps>(function LLMListContainer(
   {
     description,
     modelSelectButtonText,
@@ -293,16 +317,16 @@ export const LLMListContainer = forwardRef(function LLMListContainer(
     );
   }, [llmItemsCurrState]);
   const updateProgress = useCallback(
-    (itemProcessorFunc: (llms: LLMSpec) => LLMSpec) => {
+    (itemProcessorFunc: (llm: LLMSpec) => LLMSpec) => {
       setLLMItems(llmItemsCurrState.map(itemProcessorFunc));
     },
     [llmItemsCurrState],
   );
   const ensureLLMItemsErrorProgress = useCallback(
-    (llm_keys_w_errors) => {
+    (llm_keys_w_errors: string[]) => {
       setLLMItems(
         llmItemsCurrState.map((item) => {
-          if (llm_keys_w_errors.includes(item.key)) {
+          if (item.key !== undefined && llm_keys_w_errors.includes(item.key)) {
             if (!item.progress) item.progress = { success: 0, error: 100 };
             else {
               const succ_perc = item.progress.success;
@@ -321,7 +345,7 @@ export const LLMListContainer = forwardRef(function LLMListContainer(
   );
   const getLLMListItemForKey = useCallback(
-    (key) => {
+    (key: string) => {
       return llmItemsCurrState.find((item) => item.key === key);
     },
     [llmItemsCurrState],
   );
@@ -353,7 +377,7 @@ export const LLMListContainer = forwardRef(function LLMListContainer(
       item.name = unique_name;
       item.formData = { shortname: unique_name };
 
-      let new_items;
+      let new_items: LLMSpec[] = [];
       if (selectModelAction === "add" || selectModelAction === undefined) {
         // Add model to the LLM list (regardless of whether it's already present).
         new_items = llmItemsCurrState.concat([item]);

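For reference, a minimal sketch of how a parent node can drive the container through the now-typed imperative handle (hypothetical caller and item key, not part of this commit):

  const llmListRef = useRef<LLMListContainerHandle>(null);
  // ... rendered elsewhere as <LLMListContainer ref={llmListRef} ...props... /> ...
  // Reset all progress rings before a new query; typos in method names now fail to compile:
  llmListRef.current?.resetLLMItemsProgress();
  // Mark models that errored, by item key:
  llmListRef.current?.ensureLLMItemsErrorProgress(["gpt-4-abc123"]);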
View File

@@ -80,10 +80,10 @@ export interface LLMListItemProps {
   item: LLMSpec;
   provided: DraggableProvided;
   snapshot: DraggableStateSnapshot;
-  removeCallback: (key: string) => void;
-  onClickSettings: () => void;
+  removeCallback?: (key: string) => void;
+  onClickSettings?: () => void;
   progress?: QueryProgress;
-  hideTrashIcon: boolean;
+  hideTrashIcon?: boolean;
 }
 
 const LLMListItem: React.FC<LLMListItemProps> = ({
@@ -136,7 +136,7 @@ const LLMListItem: React.FC<LLMListItemProps> = ({
         )}
       </CardHeader>
       <LLMItemButtonGroup
-        onClickTrash={() => removeCallback(item.key ?? "undefined")}
+        onClickTrash={() => removeCallback && removeCallback(item.key ?? "undefined")}
         ringProgress={progress}
         onClickSettings={onClickSettings}
         hideTrashIcon={hideTrashIcon}

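Since removeCallback, onClickSettings, and hideTrashIcon are now optional, a read-only item can omit them entirely (hypothetical usage sketch):

  <LLMListItem item={llm} provided={provided} snapshot={snapshot} progress={progress} />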
View File

@@ -38,13 +38,14 @@ import {
   genResponseTextsDisplay,
 } from "./ResponseBoxes";
 import { getLabelForResponse } from "./ResponseRatingToolbar";
+import { Dict, LLMResponse } from "./backend/typing";
 
 // Helper funcs
-const getLLMName = (resp_obj) =>
+const getLLMName = (resp_obj: LLMResponse) =>
   typeof resp_obj?.llm === "string" ? resp_obj.llm : resp_obj?.llm?.name;
-const escapeRegExp = (txt) => txt.replace(/[-[\]{}()*+?.,\\^$|#\s]/g, "\\$&");
+const escapeRegExp = (txt: string) => txt.replace(/[-[\]{}()*+?.,\\^$|#\s]/g, "\\$&");
 
-function getIndicesOfSubstringMatches(s, substr, caseSensitive) {
+function getIndicesOfSubstringMatches(s: string, substr: string, caseSensitive?: boolean) {
   const regex = new RegExp(
     escapeRegExp(substr),
     "g" + (caseSensitive ? "" : "i"),
@@ -56,7 +57,7 @@ function getIndicesOfSubstringMatches(s, substr, caseSensitive) {
 }
 
 // Splits a string by a substring or regex, but includes the delimiter (substring/regex match) elements in the returned array.
-function splitAndIncludeDelimiter(s, substr, caseSensitive) {
+function splitAndIncludeDelimiter(s: string, substr: string, caseSensitive?: boolean) {
   const indices = getIndicesOfSubstringMatches(s, substr, caseSensitive);
 
   if (indices.length === 0) return [s];
@@ -78,7 +79,7 @@ function splitAndIncludeDelimiter(s, substr, caseSensitive) {
 }
 
 // Returns an HTML version of text where 'searchValue' is highlighted.
-function genSpansForHighlightedValue(text, searchValue, caseSensitive) {
+function genSpansForHighlightedValue(text: string, searchValue: string, caseSensitive?: boolean) {
   // Split texts by searchValue and map to <span> and <mark> elements
   return splitAndIncludeDelimiter(text, searchValue, caseSensitive).map(
     (s, idx) => {
@@ -94,7 +95,7 @@ function genSpansForHighlightedValue(text, searchValue, caseSensitive) {
 }
 
 // Export the JSON responses to an excel file (downloads the file):
-export const exportToExcel = (jsonResponses, filename) => {
+export const exportToExcel = (jsonResponses: LLMResponse[], filename: string) => {
   if (!filename) filename = "responses.xlsx";
 
   // Check that there are responses to export:
@@ -123,7 +124,7 @@ export const exportToExcel = (jsonResponses, filename) => {
     };
     const eval_res_items = res_obj.eval_res ? res_obj.eval_res.items : null;
     return res_obj.responses.map((r, r_idx) => {
-      const row = {
+      const row: Dict<string | number | boolean> = {
         LLM: llm,
         Prompt: prompt,
         Response: r,
@@ -173,17 +174,23 @@ export const exportToExcel = (jsonResponses, filename) => {
   XLSX.writeFile(wb, filename);
 };
 
-const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
+export interface LLMResponseInspectorProps {
+  jsonResponses: LLMResponse[];
+  wideFormat?: boolean;
+}
+
+const LLMResponseInspector: React.FC<LLMResponseInspectorProps> = ({ jsonResponses, wideFormat }) => {
   // Responses
-  const [responses, setResponses] = useState([]);
+  const [responseDivs, setResponseDivs] = useState<React.ReactNode>([]);
   const [receivedResponsesOnce, setReceivedResponsesOnce] = useState(false);
 
   // The type of view to use to display responses. Can be either hierarchy or table.
   const [viewFormat, setViewFormat] = useState("hierarchy");
 
   // The MultiSelect so people can dynamically set what vars they care about
-  const [multiSelectVars, setMultiSelectVars] = useState([]);
-  const [multiSelectValue, setMultiSelectValue] = useState([]);
+  const [multiSelectVars, setMultiSelectVars] = useState<{value: string; label: string}[]>([]);
+  const [multiSelectValue, setMultiSelectValue] = useState<string[]>([]);
 
   // Search bar functionality
   const [searchValue, setSearchValue] = useState("");
@@ -236,13 +243,13 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
       return;
 
     // Find all vars in responses
-    let found_vars = new Set();
-    let found_metavars = new Set();
-    let found_llms = new Set();
+    let found_vars: Array<string> | Set<string> = new Set<string>();
+    let found_metavars: Array<string> | Set<string> = new Set<string>();
+    let found_llms: Array<string> | Set<string> = new Set<string>();
     batchedResponses.forEach((res_obj) => {
-      Object.keys(res_obj.vars).forEach((v) => found_vars.add(v));
-      Object.keys(res_obj.metavars).forEach((v) => found_metavars.add(v));
-      found_llms.add(getLLMName(res_obj));
+      Object.keys(res_obj.vars).forEach((v) => (found_vars as Set<string>).add(v));
+      Object.keys(res_obj.metavars).forEach((v) => (found_metavars as Set<string>).add(v));
+      (found_llms as Set<string>).add(getLLMName(res_obj));
     });
     found_vars = Array.from(found_vars);
     found_metavars = Array.from(found_metavars).filter(cleanMetavarsFilterFunc);
@@ -256,7 +263,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
 
     // Set the variables accessible in the MultiSelect for 'group by'
     const msvars = found_vars
-      .map((name) =>
+      .map((name: string) =>
         // We add a $ prefix to mark this as a prompt parameter, and so
         // in the future we can add special types of variables without name collisions
         ({ value: name, label: name }),
@@ -310,7 +317,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
   );
 
   // Functions to associate a color to each LLM in responses
-  const color_for_llm = (llm) => getColorForLLMAndSetIfNotFound(llm) + "99";
+  const color_for_llm = (llm: string) => getColorForLLMAndSetIfNotFound(llm) + "99";
   const header_bg_colors = ["#e0f4fa", "#c0def9", "#a9c0f9", "#a6b2ea"];
   const response_box_colors = [
     "#eee",
@@ -321,11 +328,11 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
     "#ddd",
     "#eee",
   ];
-  const rgroup_color = (depth) =>
+  const rgroup_color = (depth: number) =>
     response_box_colors[depth % response_box_colors.length];
 
-  const getHeaderBadge = (key, val, depth) => {
-    if (val) {
+  const getHeaderBadge = (key: string, val: string | undefined, depth: number) => {
+    if (val !== undefined) {
       const s = truncStr(val.trim(), 1024);
       return (
         <div
@@ -346,11 +353,11 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
     }
   };
 
-  const generateResponseBoxes = (resps, eatenvars, fixed_width) => {
+  const generateResponseBoxes = (resps: LLMResponse[], eatenvars: string[], fixed_width: number) => {
     const hide_llm_name = eatenvars.includes("LLM");
     return resps.map((res_obj, res_idx) => {
       // If user has searched for something, further filter the response texts by only those that contain the search term
-      const respsFilterFunc = (responses) => {
+      const respsFilterFunc = (responses: string[]) => {
         if (searchValue.length === 0) return responses;
         const filtered_resps = responses.filter(
           search_regex.test.bind(search_regex),
@@ -399,8 +406,8 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
     if (viewFormat === "table") {
       // Generate a table, with default columns for: input vars, LLMs queried
       // First get column names as input vars + LLMs:
-      let var_cols, colnames, getColVal, found_sel_var_vals;
-      let metavar_cols = []; // found_metavars; -- Disabling this functionality for now, since it is usually annoying.
+      let var_cols: string[], colnames: string[], getColVal: (r: LLMResponse) => string | number | boolean | undefined, found_sel_var_vals: string[];
+      let metavar_cols: string[] = []; // found_metavars; -- Disabling this functionality for now, since it is usually annoying.
       if (tableColVar === "LLM") {
         var_cols = found_vars;
         getColVal = getLLMName;
@@ -418,16 +425,16 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
           responses.reduce((acc, res_obj) => {
             acc.add(
               tableColVar in res_obj.vars
-                ? res_obj.vars[tableColVar]
+                ? res_obj.vars[tableColVar] as string
                 : "(unspecified)",
             );
             return acc;
-          }, new Set()),
+          }, new Set<string>()),
         );
         colnames = var_cols.concat(found_sel_var_vals);
       }
 
-      const getVar = (r, v) => (v === "LLM" ? getLLMName(r) : r.vars[v]);
+      const getVar = (r: LLMResponse, v: string) => (v === "LLM" ? getLLMName(r) : r.vars[v]);
 
       // Then group responses by prompts. Each prompt will become a separate row of the table (will be treated as unique)
       const responses_by_prompt = groupResponsesBy(responses, (r) =>
@@ -486,7 +493,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
         },
       );
 
-      setResponses([
+      setResponseDivs([
         <Table key="table" fontSize={wideFormat ? "sm" : "xs"}>
           <thead>
             <tr>
@@ -506,7 +513,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
       // :: nested divs first grouped by LLM (first level), then by var1, then var2 (deepest level).
       let leaf_id = 0;
       let first_opened = false;
-      const groupByVars = (resps, varnames, eatenvars, header) => {
+      const groupByVars = (resps: LLMResponse[], varnames: string[], eatenvars: string[], header: React.ReactNode) => {
         if (resps.length === 0) return [];
         if (varnames.length === 0) {
           // Base case. Display n response(s) to each single prompt, back-to-back:
@@ -560,16 +567,16 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
           );
           const get_header =
             group_name === "LLM"
-              ? (key, val) => (
+              ? (key: string, val?: string) => (
                   <div
                     key={val}
-                    style={{ backgroundColor: color_for_llm(val) }}
+                    style={{ backgroundColor: val ? color_for_llm(val) : "#eee" }}
                     className="response-llm-header"
                   >
                     {val}
                   </div>
                 )
-              : (key, val) => getHeaderBadge(key, val, eatenvars.length);
+              : (key: string, val?: string) => getHeaderBadge(key, val, eatenvars.length);
 
           // Now produce nested divs corresponding to the groups
           const remaining_vars = varnames.slice(1);
@@ -630,7 +637,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
 
       // Produce DIV elements grouped by selected vars
       const divs = groupByVars(responses, selected_vars, [], null);
-      setResponses(divs);
+      setResponseDivs(divs);
     }
 
     setNumMatches(numResponsesDisplayed);
@@ -651,15 +658,14 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
 
   // When the user clicks an item in the drop-down,
   // we want to autoclose the multiselect drop-down:
-  const multiSelectRef = useRef(null);
-  const handleMultiSelectValueChange = (new_val) => {
-    if (multiSelectRef) {
+  const multiSelectRef = useRef<HTMLInputElement>(null);
+  const handleMultiSelectValueChange = (new_val: string[]) => {
+    if (multiSelectRef?.current)
       multiSelectRef.current.blur();
-    }
     setMultiSelectValue(new_val);
   };
 
-  const handleSearchValueChange = (content) => {
+  const handleSearchValueChange = (content: React.ChangeEvent<HTMLInputElement>) => {
     setSearchValue(content.target.value);
   };
 
@@ -730,7 +736,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
     <div style={{ height: "100%" }}>
       <Tabs
         value={viewFormat}
-        onTabChange={setViewFormat}
+        onTabChange={(val) => setViewFormat(val ?? "hierarchy")}
         styles={{ tabLabel: { fontSize: wideFormat ? "12pt" : "9pt" } }}
       >
         <Tabs.List>
@@ -804,7 +810,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
         </Tabs.Panel>
       </Tabs>
 
-      <div className="nowheel nodrag">{responses}</div>
+      <div className="nowheel nodrag">{responseDivs}</div>
     </div>
   );
 };

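For reference, minimal usage sketches against the new typings (hypothetical callers, not part of this commit):

  // The inspector now takes typed props; wideFormat is optional:
  <LLMResponseInspector jsonResponses={responses} wideFormat />
  // Export the same list of LLMResponse objects to a spreadsheet:
  exportToExcel(responses, "responses.xlsx");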
View File

@@ -29,7 +29,9 @@ export interface ModelSettingsModalRef {
 }
 
 export interface ModelSettingsModalProps {
   model?: LLMSpec;
-  onSettingsSubmit?: () => void;
+  onSettingsSubmit?: (savedItem: LLMSpec,
+    formData: Dict<JSONCompatible>,
+    settingsData: Dict<JSONCompatible>) => void;
 }
 type FormData = LLMSpec["formData"];

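A sketch of a caller consuming the richer callback signature (hypothetical handler, not part of this commit):

  <ModelSettingsModal
    model={selectedModel}
    onSettingsSubmit={(savedItem, formData, settingsData) => {
      // savedItem is the updated LLMSpec; formData/settingsData carry the raw form values
      console.log("Saved settings for", savedItem.name, formData, settingsData);
    }}
  />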
View File

@ -11,8 +11,10 @@ import StorageCache from "./backend/cache";
import useStore from "./store";
import { deepcopy } from "./backend/utils";
const getRatingKeyForResponse = (uid, label_name) => `r.${uid}.${label_name}`;
const collapse_ratings = (rating_dict, idxs) => {
type RatingDict = Record<number, boolean | string | undefined>;
const getRatingKeyForResponse = (uid: string, label_name: string) => `r.${uid}.${label_name}`;
const collapse_ratings = (rating_dict: RatingDict, idxs: number[]) => {
if (rating_dict === undefined) return undefined;
for (let j = 0; j < idxs.length; j++) {
if (idxs[j] in rating_dict && rating_dict[idxs[j]] !== undefined)
@@ -21,14 +23,20 @@ const collapse_ratings = (rating_dict, idxs) => {
   return undefined;
 };
-export const getLabelForResponse = (uid, label_name) => {
+export const getLabelForResponse = (uid: string, label_name: string) => {
   return StorageCache.get(getRatingKeyForResponse(uid, label_name));
 };
 
-export const setLabelForResponse = (uid, label_name, payload) => {
+export const setLabelForResponse = (uid: string, label_name: string, payload: RatingDict) => {
   StorageCache.store(getRatingKeyForResponse(uid, label_name), payload);
 };
 
-const ToolbarButton = forwardRef(function ToolbarButton(
+interface ToolbarButtonProps {
+  selected: boolean;
+  onClick: () => void;
+  children: React.ReactNode;
+}
+
+const ToolbarButton = forwardRef<HTMLButtonElement, ToolbarButtonProps>(function ToolbarButton(
   { selected, onClick, children },
   ref,
 ) {
@@ -47,7 +55,14 @@ const ToolbarButton = forwardRef(function ToolbarButton(
   );
 });
 
-const ResponseRatingToolbar = ({
+export interface ResponseRatingToolbarProps {
+  uid: string;
+  wideFormat?: boolean;
+  innerIdxs: number[];
+  onUpdateResponses: () => void;
+}
+
+const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({
   uid,
   wideFormat,
   innerIdxs,
@@ -62,10 +77,10 @@ const ResponseRatingToolbar = ({
   // :: for this component changes.
   // const state = useStore((store) => store.state);
   const setState = useStore((store) => store.setState);
-  const gradeState = useStore((store) => store.state[gradeKey]);
-  const noteState = useStore((store) => store.state[noteKey]);
+  const gradeState = useStore<RatingDict>((store) => store.state[gradeKey]);
+  const noteState = useStore<RatingDict>((store) => store.state[noteKey]);
   const setRating = useCallback(
-    (uid, label, payload) => {
+    (uid: string, label: string, payload: RatingDict) => {
       const key = getRatingKeyForResponse(uid, label);
       const safe_payload = deepcopy(payload);
       setState(key, safe_payload);
@@ -91,7 +106,7 @@ const ResponseRatingToolbar = ({
 
   // Override the text in the internal textarea whenever upstream annotation changes.
   useEffect(() => {
-    setNoteText(note);
+    setNoteText(note !== undefined ? note.toString() : "");
   }, [note]);
 
   // The label for the pop-up comment box.
@@ -107,22 +122,22 @@ const ResponseRatingToolbar = ({
   }, [wideFormat]);
 
   // For human labeling of responses in the inspector
-  const onGrade = (grade) => {
+  const onGrade = (grade: boolean | undefined) => {
     if (uid === undefined) return;
     const new_grades = gradeState ?? {};
-    innerIdxs.forEach((idx) => {
+    innerIdxs.forEach((idx: number) => {
       new_grades[idx] = grade;
     });
     setRating(uid, "grade", new_grades);
     if (onUpdateResponses) onUpdateResponses();
   };
 
-  const onAnnotate = (label) => {
+  const onAnnotate = (label?: string) => {
     if (uid === undefined) return;
     if (typeof label === "string" && label.trim().length === 0)
       label = undefined; // empty strings are undefined
     const new_notes = noteState ?? {};
-    innerIdxs.forEach((idx) => {
+    innerIdxs.forEach((idx: number) => {
       new_notes[idx] = label;
     });
     setRating(uid, "note", new_notes);

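For reference, how the module-internal rating helpers fit together, with a hypothetical uid (not part of this commit):

  // Keys take the form "r.<uid>.<label>", e.g. getRatingKeyForResponse("3f2a", "grade") === "r.3f2a.grade".
  // A RatingDict maps response indices to a grade (boolean) or note (string):
  const grades: RatingDict = { 0: true, 2: false };
  setLabelForResponse("3f2a", "grade", grades);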
View File

@@ -23,6 +23,7 @@ import {
   TemplateVarInfo,
   BaseLLMResponseObject,
   LLMSpec,
+  EvaluationScore,
 } from "./typing";
 import { v4 as uuid } from "uuid";
 import { StringTemplate } from "./template";
@@ -1786,10 +1787,10 @@ export const truncStr = (
 
 export const groupResponsesBy = (
   responses: LLMResponse[],
-  keyFunc: (item: LLMResponse) => string | null | undefined,
-) => {
-  const responses_by_key: Dict = {};
-  const unspecified_group: Dict[] = [];
+  keyFunc: (item: LLMResponse) => string | number | null | undefined,
+): [Dict<LLMResponse[]>, LLMResponse[]] => {
+  const responses_by_key: Dict<LLMResponse[]> = {};
+  const unspecified_group: LLMResponse[] = [];
   responses.forEach((item) => {
     const key = keyFunc(item);
     if (key === null || key === undefined) {
@@ -1802,22 +1803,28 @@
   return [responses_by_key, unspecified_group];
 };
 
-export const batchResponsesByUID = (responses: LLMResponse[]) => {
+/**
+ * Merges inner .responses and eval_res.items properties for LLMResponses with the same
+ * uid, returning the (smaller) list of merged items.
+ * @param responses
+ * @returns
+ */
+export const batchResponsesByUID = (responses: LLMResponse[]): LLMResponse[] => {
   const [batches, unspecified_id_group] = groupResponsesBy(
     responses,
     (resp_obj) => resp_obj.uid,
   );
   return Object.values(batches)
-    .map((resp_objs: Dict[]) => {
+    .map((resp_objs: LLMResponse[]) => {
       if (resp_objs.length === 1) {
         return resp_objs[0];
       } else {
         const batched = deepcopy_and_modify(resp_objs[0], {
           responses: resp_objs.map((resp_obj) => resp_obj.responses).flat(),
-        });
-        if (batched.eval_res !== undefined) {
+        }) as LLMResponse;
+        if (batched.eval_res?.items !== undefined) {
           batched.eval_res.items = resp_objs
-            .map((resp_obj) => resp_obj.eval_res.items)
+            .map((resp_obj) => resp_obj?.eval_res?.items as EvaluationScore[])
             .flat();
         }
         return batched;
@ -1911,8 +1918,8 @@ export async function retryAsyncFunc<T>(
// Filters internally used keys LLM_{idx} and __{str} from metavar dictionaries.
// This method is used to pass around information hidden from the user.
export function cleanMetavarsFilterFunc() {
return (key: string) => !(key.startsWith("LLM_") || key.startsWith("__pt"));
export function cleanMetavarsFilterFunc(key: string) {
return !(key.startsWith("LLM_") || key.startsWith("__pt"));
}
// Verify data integrity: check that uids are present for all responses.
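For reference, a sketch of the newly typed utilities in use (hypothetical caller, not part of this commit):

  // Group by LLM name; responses whose key is null/undefined land in the second element:
  const [byLLM, ungrouped] = groupResponsesBy(responses, (r) =>
    typeof r.llm === "string" ? r.llm : r.llm?.name,
  );
  // Merge responses sharing a uid into single LLMResponse objects:
  const merged = batchResponsesByUID(responses);
  // cleanMetavarsFilterFunc is now a plain predicate, usable directly with .filter():
  const userKeys = Object.keys(metavars).filter(cleanMetavarsFilterFunc);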