diff --git a/chainforge/react-server/src/LLMListComponent.tsx b/chainforge/react-server/src/LLMListComponent.tsx index b39c11a..78501c4 100644 --- a/chainforge/react-server/src/LLMListComponent.tsx +++ b/chainforge/react-server/src/LLMListComponent.tsx @@ -11,6 +11,10 @@ import React, { import { DragDropContext, Draggable, + DraggableProvided, + DraggableRubric, + DraggableStateSnapshot, + DroppableProvided, OnDragEndResponder, } from "react-beautiful-dnd"; import { Menu } from "@mantine/core"; @@ -205,7 +209,7 @@ export function LLMList({ ( + renderClone={(provided: DraggableProvided, snapshot: DraggableStateSnapshot, rubric: DraggableRubric) => ( )} > - {(provided) => ( + {(provided: DroppableProvided) => (
{items.map((item, index) => ( - + {(provided, snapshot) => ( void; + setZeroPercProgress: () => void; + updateProgress: (itemProcessorFunc: (llm: LLMSpec) => LLMSpec) => void; + ensureLLMItemsErrorProgress: (llm_keys_w_errors: string[]) => void; + getLLMListItemForKey: (key: string) => LLMSpec | undefined; + refreshLLMProviderList: () => void; +} + +export interface LLMListContainerProps { + description: string; + modelSelectButtonText: string; + initLLMItems: LLMSpec[]; + onSelectModel: (llm: LLMSpec, new_llms: LLMSpec[]) => void; + onItemsChange: (new_llms: LLMSpec[], old_llms: LLMSpec[]) => void; + hideTrashIcon: boolean; + bgColor: string; + selectModelAction?: "add" | "replace"; +} + +export const LLMListContainer = forwardRef(function LLMListContainer( { description, modelSelectButtonText, @@ -293,16 +317,16 @@ export const LLMListContainer = forwardRef(function LLMListContainer( ); }, [llmItemsCurrState]); const updateProgress = useCallback( - (itemProcessorFunc: (llms: LLMSpec) => LLMSpec) => { + (itemProcessorFunc: (llm: LLMSpec) => LLMSpec) => { setLLMItems(llmItemsCurrState.map(itemProcessorFunc)); }, [llmItemsCurrState], ); const ensureLLMItemsErrorProgress = useCallback( - (llm_keys_w_errors) => { + (llm_keys_w_errors: string[]) => { setLLMItems( llmItemsCurrState.map((item) => { - if (llm_keys_w_errors.includes(item.key)) { + if (item.key !== undefined && llm_keys_w_errors.includes(item.key)) { if (!item.progress) item.progress = { success: 0, error: 100 }; else { const succ_perc = item.progress.success; @@ -321,7 +345,7 @@ export const LLMListContainer = forwardRef(function LLMListContainer( ); const getLLMListItemForKey = useCallback( - (key) => { + (key: string) => { return llmItemsCurrState.find((item) => item.key === key); }, [llmItemsCurrState], @@ -353,7 +377,7 @@ export const LLMListContainer = forwardRef(function LLMListContainer( item.name = unique_name; item.formData = { shortname: unique_name }; - let new_items; + let new_items: LLMSpec[] = []; if (selectModelAction === "add" || selectModelAction === undefined) { // Add model to the LLM list (regardless of it's present already or not). new_items = llmItemsCurrState.concat([item]); diff --git a/chainforge/react-server/src/LLMListItem.tsx b/chainforge/react-server/src/LLMListItem.tsx index c393dda..fc3c64f 100644 --- a/chainforge/react-server/src/LLMListItem.tsx +++ b/chainforge/react-server/src/LLMListItem.tsx @@ -80,10 +80,10 @@ export interface LLMListItemProps { item: LLMSpec; provided: DraggableProvided; snapshot: DraggableStateSnapshot; - removeCallback: (key: string) => void; - onClickSettings: () => void; + removeCallback?: (key: string) => void; + onClickSettings?: () => void; progress?: QueryProgress; - hideTrashIcon: boolean; + hideTrashIcon?: boolean; } const LLMListItem: React.FC = ({ @@ -136,7 +136,7 @@ const LLMListItem: React.FC = ({ )} removeCallback(item.key ?? "undefined")} + onClickTrash={() => removeCallback && removeCallback(item.key ?? "undefined")} ringProgress={progress} onClickSettings={onClickSettings} hideTrashIcon={hideTrashIcon} diff --git a/chainforge/react-server/src/LLMResponseInspector.js b/chainforge/react-server/src/LLMResponseInspector.tsx similarity index 89% rename from chainforge/react-server/src/LLMResponseInspector.js rename to chainforge/react-server/src/LLMResponseInspector.tsx index 3304d0d..e3b0af1 100644 --- a/chainforge/react-server/src/LLMResponseInspector.js +++ b/chainforge/react-server/src/LLMResponseInspector.tsx @@ -38,13 +38,14 @@ import { genResponseTextsDisplay, } from "./ResponseBoxes"; import { getLabelForResponse } from "./ResponseRatingToolbar"; +import { Dict, LLMResponse } from "./backend/typing"; // Helper funcs -const getLLMName = (resp_obj) => +const getLLMName = (resp_obj: LLMResponse) => typeof resp_obj?.llm === "string" ? resp_obj.llm : resp_obj?.llm?.name; -const escapeRegExp = (txt) => txt.replace(/[-[\]{}()*+?.,\\^$|#\s]/g, "\\$&"); +const escapeRegExp = (txt: string) => txt.replace(/[-[\]{}()*+?.,\\^$|#\s]/g, "\\$&"); -function getIndicesOfSubstringMatches(s, substr, caseSensitive) { +function getIndicesOfSubstringMatches(s: string, substr: string, caseSensitive?: boolean) { const regex = new RegExp( escapeRegExp(substr), "g" + (caseSensitive ? "" : "i"), @@ -56,7 +57,7 @@ function getIndicesOfSubstringMatches(s, substr, caseSensitive) { } // Splits a string by a substring or regex, but includes the delimiter (substring/regex match) elements in the returned array. -function splitAndIncludeDelimiter(s, substr, caseSensitive) { +function splitAndIncludeDelimiter(s: string, substr: string, caseSensitive?: boolean) { const indices = getIndicesOfSubstringMatches(s, substr, caseSensitive); if (indices.length === 0) return [s]; @@ -78,7 +79,7 @@ function splitAndIncludeDelimiter(s, substr, caseSensitive) { } // Returns an HTML version of text where 'searchValue' is highlighted. -function genSpansForHighlightedValue(text, searchValue, caseSensitive) { +function genSpansForHighlightedValue(text: string, searchValue: string, caseSensitive?: boolean) { // Split texts by searchValue and map to and elements return splitAndIncludeDelimiter(text, searchValue, caseSensitive).map( (s, idx) => { @@ -94,7 +95,7 @@ function genSpansForHighlightedValue(text, searchValue, caseSensitive) { } // Export the JSON responses to an excel file (downloads the file): -export const exportToExcel = (jsonResponses, filename) => { +export const exportToExcel = (jsonResponses: LLMResponse[], filename: string) => { if (!filename) filename = "responses.xlsx"; // Check that there are responses to export: @@ -123,7 +124,7 @@ export const exportToExcel = (jsonResponses, filename) => { }; const eval_res_items = res_obj.eval_res ? res_obj.eval_res.items : null; return res_obj.responses.map((r, r_idx) => { - const row = { + const row: Dict = { LLM: llm, Prompt: prompt, Response: r, @@ -173,17 +174,23 @@ export const exportToExcel = (jsonResponses, filename) => { XLSX.writeFile(wb, filename); }; -const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { + +export interface LLMResponseInspectorProps { + jsonResponses: LLMResponse[]; + wideFormat?: boolean; +} + +const LLMResponseInspector: React.FC = ({ jsonResponses, wideFormat }) => { // Responses - const [responses, setResponses] = useState([]); + const [responseDivs, setResponseDivs] = useState([]); const [receivedResponsesOnce, setReceivedResponsesOnce] = useState(false); // The type of view to use to display responses. Can be either hierarchy or table. const [viewFormat, setViewFormat] = useState("hierarchy"); // The MultiSelect so people can dynamically set what vars they care about - const [multiSelectVars, setMultiSelectVars] = useState([]); - const [multiSelectValue, setMultiSelectValue] = useState([]); + const [multiSelectVars, setMultiSelectVars] = useState<{value: string; label: string}[]>([]); + const [multiSelectValue, setMultiSelectValue] = useState([]); // Search bar functionality const [searchValue, setSearchValue] = useState(""); @@ -236,13 +243,13 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { return; // Find all vars in responses - let found_vars = new Set(); - let found_metavars = new Set(); - let found_llms = new Set(); + let found_vars: Array | Set = new Set(); + let found_metavars: Array | Set = new Set(); + let found_llms: Array | Set = new Set(); batchedResponses.forEach((res_obj) => { - Object.keys(res_obj.vars).forEach((v) => found_vars.add(v)); - Object.keys(res_obj.metavars).forEach((v) => found_metavars.add(v)); - found_llms.add(getLLMName(res_obj)); + Object.keys(res_obj.vars).forEach((v) => (found_vars as Set).add(v)); + Object.keys(res_obj.metavars).forEach((v) => (found_metavars as Set).add(v)); + (found_llms as Set).add(getLLMName(res_obj)); }); found_vars = Array.from(found_vars); found_metavars = Array.from(found_metavars).filter(cleanMetavarsFilterFunc); @@ -256,7 +263,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { // Set the variables accessible in the MultiSelect for 'group by' const msvars = found_vars - .map((name) => + .map((name: string) => // We add a $ prefix to mark this as a prompt parameter, and so // in the future we can add special types of variables without name collisions ({ value: name, label: name }), @@ -310,7 +317,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { ); // Functions to associate a color to each LLM in responses - const color_for_llm = (llm) => getColorForLLMAndSetIfNotFound(llm) + "99"; + const color_for_llm = (llm: string) => getColorForLLMAndSetIfNotFound(llm) + "99"; const header_bg_colors = ["#e0f4fa", "#c0def9", "#a9c0f9", "#a6b2ea"]; const response_box_colors = [ "#eee", @@ -321,11 +328,11 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { "#ddd", "#eee", ]; - const rgroup_color = (depth) => + const rgroup_color = (depth: number) => response_box_colors[depth % response_box_colors.length]; - const getHeaderBadge = (key, val, depth) => { - if (val) { + const getHeaderBadge = (key: string, val: string | undefined, depth: number) => { + if (val !== undefined) { const s = truncStr(val.trim(), 1024); return (
{ } }; - const generateResponseBoxes = (resps, eatenvars, fixed_width) => { + const generateResponseBoxes = (resps: LLMResponse[], eatenvars: string[], fixed_width: number) => { const hide_llm_name = eatenvars.includes("LLM"); return resps.map((res_obj, res_idx) => { // If user has searched for something, further filter the response texts by only those that contain the search term - const respsFilterFunc = (responses) => { + const respsFilterFunc = (responses: string[]) => { if (searchValue.length === 0) return responses; const filtered_resps = responses.filter( search_regex.test.bind(search_regex), @@ -399,8 +406,8 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { if (viewFormat === "table") { // Generate a table, with default columns for: input vars, LLMs queried // First get column names as input vars + LLMs: - let var_cols, colnames, getColVal, found_sel_var_vals; - let metavar_cols = []; // found_metavars; -- Disabling this functionality for now, since it is usually annoying. + let var_cols: string[], colnames: string[], getColVal: (r: LLMResponse) => string | number | boolean | undefined, found_sel_var_vals: string[]; + let metavar_cols: string[] = []; // found_metavars; -- Disabling this functionality for now, since it is usually annoying. if (tableColVar === "LLM") { var_cols = found_vars; getColVal = getLLMName; @@ -418,16 +425,16 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { responses.reduce((acc, res_obj) => { acc.add( tableColVar in res_obj.vars - ? res_obj.vars[tableColVar] + ? res_obj.vars[tableColVar] as string : "(unspecified)", ); return acc; - }, new Set()), + }, new Set()), ); colnames = var_cols.concat(found_sel_var_vals); } - const getVar = (r, v) => (v === "LLM" ? getLLMName(r) : r.vars[v]); + const getVar = (r: LLMResponse, v: string) => (v === "LLM" ? getLLMName(r) : r.vars[v]); // Then group responses by prompts. Each prompt will become a separate row of the table (will be treated as unique) const responses_by_prompt = groupResponsesBy(responses, (r) => @@ -486,7 +493,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { }, ); - setResponses([ + setResponseDivs([ @@ -506,7 +513,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { // :: nested divs first grouped by LLM (first level), then by var1, then var2 (deepest level). let leaf_id = 0; let first_opened = false; - const groupByVars = (resps, varnames, eatenvars, header) => { + const groupByVars = (resps: LLMResponse[], varnames: string[], eatenvars: string[], header: React.ReactNode) => { if (resps.length === 0) return []; if (varnames.length === 0) { // Base case. Display n response(s) to each single prompt, back-to-back: @@ -560,16 +567,16 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { ); const get_header = group_name === "LLM" - ? (key, val) => ( + ? (key: string, val?: string) => (
{val}
) - : (key, val) => getHeaderBadge(key, val, eatenvars.length); + : (key: string, val?: string) => getHeaderBadge(key, val, eatenvars.length); // Now produce nested divs corresponding to the groups const remaining_vars = varnames.slice(1); @@ -630,7 +637,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { // Produce DIV elements grouped by selected vars const divs = groupByVars(responses, selected_vars, [], null); - setResponses(divs); + setResponseDivs(divs); } setNumMatches(numResponsesDisplayed); @@ -651,15 +658,14 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { // When the user clicks an item in the drop-down, // we want to autoclose the multiselect drop-down: - const multiSelectRef = useRef(null); - const handleMultiSelectValueChange = (new_val) => { - if (multiSelectRef) { + const multiSelectRef = useRef(null); + const handleMultiSelectValueChange = (new_val: string[]) => { + if (multiSelectRef?.current) multiSelectRef.current.blur(); - } setMultiSelectValue(new_val); }; - const handleSearchValueChange = (content) => { + const handleSearchValueChange = (content: React.ChangeEvent) => { setSearchValue(content.target.value); }; @@ -730,7 +736,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
setViewFormat(val ?? "hierarchy")} styles={{ tabLabel: { fontSize: wideFormat ? "12pt" : "9pt" } }} > @@ -804,7 +810,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { -
{responses}
+
{responseDivs}
); }; diff --git a/chainforge/react-server/src/ModelSettingsModal.tsx b/chainforge/react-server/src/ModelSettingsModal.tsx index 235d6f8..4b39ecd 100644 --- a/chainforge/react-server/src/ModelSettingsModal.tsx +++ b/chainforge/react-server/src/ModelSettingsModal.tsx @@ -29,7 +29,9 @@ export interface ModelSettingsModalRef { } export interface ModelSettingsModalProps { model?: LLMSpec; - onSettingsSubmit?: () => void; + onSettingsSubmit?: (savedItem: LLMSpec, + formData: Dict, + settingsData: Dict) => void; } type FormData = LLMSpec["formData"]; diff --git a/chainforge/react-server/src/ResponseRatingToolbar.js b/chainforge/react-server/src/ResponseRatingToolbar.tsx similarity index 81% rename from chainforge/react-server/src/ResponseRatingToolbar.js rename to chainforge/react-server/src/ResponseRatingToolbar.tsx index d071b00..e987b67 100644 --- a/chainforge/react-server/src/ResponseRatingToolbar.js +++ b/chainforge/react-server/src/ResponseRatingToolbar.tsx @@ -11,8 +11,10 @@ import StorageCache from "./backend/cache"; import useStore from "./store"; import { deepcopy } from "./backend/utils"; -const getRatingKeyForResponse = (uid, label_name) => `r.${uid}.${label_name}`; -const collapse_ratings = (rating_dict, idxs) => { +type RatingDict = Record; + +const getRatingKeyForResponse = (uid: string, label_name: string) => `r.${uid}.${label_name}`; +const collapse_ratings = (rating_dict: RatingDict, idxs: number[]) => { if (rating_dict === undefined) return undefined; for (let j = 0; j < idxs.length; j++) { if (idxs[j] in rating_dict && rating_dict[idxs[j]] !== undefined) @@ -21,14 +23,20 @@ const collapse_ratings = (rating_dict, idxs) => { return undefined; }; -export const getLabelForResponse = (uid, label_name) => { +export const getLabelForResponse = (uid: string, label_name: string) => { return StorageCache.get(getRatingKeyForResponse(uid, label_name)); }; -export const setLabelForResponse = (uid, label_name, payload) => { +export const setLabelForResponse = (uid: string, label_name: string, payload: RatingDict) => { StorageCache.store(getRatingKeyForResponse(uid, label_name), payload); }; -const ToolbarButton = forwardRef(function ToolbarButton( +interface ToolbarButtonProps { + selected: boolean; + onClick: () => void; + children: React.ReactNode; +} + +const ToolbarButton = forwardRef(function ToolbarButton( { selected, onClick, children }, ref, ) { @@ -47,7 +55,14 @@ const ToolbarButton = forwardRef(function ToolbarButton( ); }); -const ResponseRatingToolbar = ({ +export interface ResponseRatingToolbarProps { + uid: string; + wideFormat?: boolean; + innerIdxs: number[]; + onUpdateResponses: () => void; +} + +const ResponseRatingToolbar: React.FC = ({ uid, wideFormat, innerIdxs, @@ -62,10 +77,10 @@ const ResponseRatingToolbar = ({ // :: for this component changes. // const state = useStore((store) => store.state); const setState = useStore((store) => store.setState); - const gradeState = useStore((store) => store.state[gradeKey]); - const noteState = useStore((store) => store.state[noteKey]); + const gradeState = useStore((store) => store.state[gradeKey]); + const noteState = useStore((store) => store.state[noteKey]); const setRating = useCallback( - (uid, label, payload) => { + (uid: string, label: string, payload: RatingDict) => { const key = getRatingKeyForResponse(uid, label); const safe_payload = deepcopy(payload); setState(key, safe_payload); @@ -91,7 +106,7 @@ const ResponseRatingToolbar = ({ // Override the text in the internal textarea whenever upstream annotation changes. useEffect(() => { - setNoteText(note); + setNoteText(note !== undefined ? note.toString() : ""); }, [note]); // The label for the pop-up comment box. @@ -107,22 +122,22 @@ const ResponseRatingToolbar = ({ }, [wideFormat]); // For human labeling of responses in the inspector - const onGrade = (grade) => { + const onGrade = (grade: boolean | undefined) => { if (uid === undefined) return; const new_grades = gradeState ?? {}; - innerIdxs.forEach((idx) => { + innerIdxs.forEach((idx: number) => { new_grades[idx] = grade; }); setRating(uid, "grade", new_grades); if (onUpdateResponses) onUpdateResponses(); }; - const onAnnotate = (label) => { + const onAnnotate = (label?: string) => { if (uid === undefined) return; if (typeof label === "string" && label.trim().length === 0) label = undefined; // empty strings are undefined const new_notes = noteState ?? {}; - innerIdxs.forEach((idx) => { + innerIdxs.forEach((idx: number) => { new_notes[idx] = label; }); setRating(uid, "note", new_notes); diff --git a/chainforge/react-server/src/backend/utils.ts b/chainforge/react-server/src/backend/utils.ts index e2933dd..5dbf92a 100644 --- a/chainforge/react-server/src/backend/utils.ts +++ b/chainforge/react-server/src/backend/utils.ts @@ -23,6 +23,7 @@ import { TemplateVarInfo, BaseLLMResponseObject, LLMSpec, + EvaluationScore, } from "./typing"; import { v4 as uuid } from "uuid"; import { StringTemplate } from "./template"; @@ -1786,10 +1787,10 @@ export const truncStr = ( export const groupResponsesBy = ( responses: LLMResponse[], - keyFunc: (item: LLMResponse) => string | null | undefined, -) => { - const responses_by_key: Dict = {}; - const unspecified_group: Dict[] = []; + keyFunc: (item: LLMResponse) => string | number | null | undefined, +): [Dict, LLMResponse[]] => { + const responses_by_key: Dict = {}; + const unspecified_group: LLMResponse[] = []; responses.forEach((item) => { const key = keyFunc(item); if (key === null || key === undefined) { @@ -1802,22 +1803,28 @@ export const groupResponsesBy = ( return [responses_by_key, unspecified_group]; }; -export const batchResponsesByUID = (responses: LLMResponse[]) => { +/** + * Merges inner .responses and eval_res.items properties for LLMResponses with the same + * uid, returning the (smaller) list of merged items. + * @param responses + * @returns + */ +export const batchResponsesByUID = (responses: LLMResponse[]): LLMResponse[] => { const [batches, unspecified_id_group] = groupResponsesBy( responses, (resp_obj) => resp_obj.uid, ); return Object.values(batches) - .map((resp_objs: Dict[]) => { + .map((resp_objs: LLMResponse[]) => { if (resp_objs.length === 1) { return resp_objs[0]; } else { const batched = deepcopy_and_modify(resp_objs[0], { responses: resp_objs.map((resp_obj) => resp_obj.responses).flat(), - }); - if (batched.eval_res !== undefined) { + }) as LLMResponse; + if (batched.eval_res?.items !== undefined) { batched.eval_res.items = resp_objs - .map((resp_obj) => resp_obj.eval_res.items) + .map((resp_obj) => resp_obj?.eval_res?.items as EvaluationScore[]) .flat(); } return batched; @@ -1911,8 +1918,8 @@ export async function retryAsyncFunc( // Filters internally used keys LLM_{idx} and __{str} from metavar dictionaries. // This method is used to pass around information hidden from the user. -export function cleanMetavarsFilterFunc() { - return (key: string) => !(key.startsWith("LLM_") || key.startsWith("__pt")); +export function cleanMetavarsFilterFunc(key: string) { + return !(key.startsWith("LLM_") || key.startsWith("__pt")); } // Verify data integrity: check that uids are present for all responses.