mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-15 00:36:29 +00:00
wip
This commit is contained in:
parent
67432f9853
commit
a6309ad444
@ -11,6 +11,10 @@ import React, {
|
||||
import {
|
||||
DragDropContext,
|
||||
Draggable,
|
||||
DraggableProvided,
|
||||
DraggableRubric,
|
||||
DraggableStateSnapshot,
|
||||
DroppableProvided,
|
||||
OnDragEndResponder,
|
||||
} from "react-beautiful-dnd";
|
||||
import { Menu } from "@mantine/core";
|
||||
@ -205,7 +209,7 @@ export function LLMList({
|
||||
<DragDropContext onDragEnd={onDragEnd}>
|
||||
<StrictModeDroppable
|
||||
droppableId="llm-list-droppable"
|
||||
renderClone={(provided, snapshot, rubric) => (
|
||||
renderClone={(provided: DraggableProvided, snapshot: DraggableStateSnapshot, rubric: DraggableRubric) => (
|
||||
<LLMListItemClone
|
||||
provided={provided}
|
||||
snapshot={snapshot}
|
||||
@ -214,10 +218,10 @@ export function LLMList({
|
||||
/>
|
||||
)}
|
||||
>
|
||||
{(provided) => (
|
||||
{(provided: DroppableProvided) => (
|
||||
<div {...provided.droppableProps} ref={provided.innerRef}>
|
||||
{items.map((item, index) => (
|
||||
<Draggable key={item.key} draggableId={item.key} index={index}>
|
||||
<Draggable key={item.key} draggableId={item.key ?? index.toString()} index={index}>
|
||||
{(provided, snapshot) => (
|
||||
<LLMListItem
|
||||
provided={provided}
|
||||
@ -240,7 +244,27 @@ export function LLMList({
|
||||
);
|
||||
}
|
||||
|
||||
export const LLMListContainer = forwardRef(function LLMListContainer(
|
||||
export interface LLMListContainerHandle {
|
||||
resetLLMItemsProgress: () => void;
|
||||
setZeroPercProgress: () => void;
|
||||
updateProgress: (itemProcessorFunc: (llm: LLMSpec) => LLMSpec) => void;
|
||||
ensureLLMItemsErrorProgress: (llm_keys_w_errors: string[]) => void;
|
||||
getLLMListItemForKey: (key: string) => LLMSpec | undefined;
|
||||
refreshLLMProviderList: () => void;
|
||||
}
|
||||
|
||||
export interface LLMListContainerProps {
|
||||
description: string;
|
||||
modelSelectButtonText: string;
|
||||
initLLMItems: LLMSpec[];
|
||||
onSelectModel: (llm: LLMSpec, new_llms: LLMSpec[]) => void;
|
||||
onItemsChange: (new_llms: LLMSpec[], old_llms: LLMSpec[]) => void;
|
||||
hideTrashIcon: boolean;
|
||||
bgColor: string;
|
||||
selectModelAction?: "add" | "replace";
|
||||
}
|
||||
|
||||
export const LLMListContainer = forwardRef<LLMListContainerHandle, LLMListContainerProps>(function LLMListContainer(
|
||||
{
|
||||
description,
|
||||
modelSelectButtonText,
|
||||
@ -293,16 +317,16 @@ export const LLMListContainer = forwardRef(function LLMListContainer(
|
||||
);
|
||||
}, [llmItemsCurrState]);
|
||||
const updateProgress = useCallback(
|
||||
(itemProcessorFunc: (llms: LLMSpec) => LLMSpec) => {
|
||||
(itemProcessorFunc: (llm: LLMSpec) => LLMSpec) => {
|
||||
setLLMItems(llmItemsCurrState.map(itemProcessorFunc));
|
||||
},
|
||||
[llmItemsCurrState],
|
||||
);
|
||||
const ensureLLMItemsErrorProgress = useCallback(
|
||||
(llm_keys_w_errors) => {
|
||||
(llm_keys_w_errors: string[]) => {
|
||||
setLLMItems(
|
||||
llmItemsCurrState.map((item) => {
|
||||
if (llm_keys_w_errors.includes(item.key)) {
|
||||
if (item.key !== undefined && llm_keys_w_errors.includes(item.key)) {
|
||||
if (!item.progress) item.progress = { success: 0, error: 100 };
|
||||
else {
|
||||
const succ_perc = item.progress.success;
|
||||
@ -321,7 +345,7 @@ export const LLMListContainer = forwardRef(function LLMListContainer(
|
||||
);
|
||||
|
||||
const getLLMListItemForKey = useCallback(
|
||||
(key) => {
|
||||
(key: string) => {
|
||||
return llmItemsCurrState.find((item) => item.key === key);
|
||||
},
|
||||
[llmItemsCurrState],
|
||||
@ -353,7 +377,7 @@ export const LLMListContainer = forwardRef(function LLMListContainer(
|
||||
item.name = unique_name;
|
||||
item.formData = { shortname: unique_name };
|
||||
|
||||
let new_items;
|
||||
let new_items: LLMSpec[] = [];
|
||||
if (selectModelAction === "add" || selectModelAction === undefined) {
|
||||
// Add model to the LLM list (regardless of it's present already or not).
|
||||
new_items = llmItemsCurrState.concat([item]);
|
||||
|
@ -80,10 +80,10 @@ export interface LLMListItemProps {
|
||||
item: LLMSpec;
|
||||
provided: DraggableProvided;
|
||||
snapshot: DraggableStateSnapshot;
|
||||
removeCallback: (key: string) => void;
|
||||
onClickSettings: () => void;
|
||||
removeCallback?: (key: string) => void;
|
||||
onClickSettings?: () => void;
|
||||
progress?: QueryProgress;
|
||||
hideTrashIcon: boolean;
|
||||
hideTrashIcon?: boolean;
|
||||
}
|
||||
|
||||
const LLMListItem: React.FC<LLMListItemProps> = ({
|
||||
@ -136,7 +136,7 @@ const LLMListItem: React.FC<LLMListItemProps> = ({
|
||||
)}
|
||||
</CardHeader>
|
||||
<LLMItemButtonGroup
|
||||
onClickTrash={() => removeCallback(item.key ?? "undefined")}
|
||||
onClickTrash={() => removeCallback && removeCallback(item.key ?? "undefined")}
|
||||
ringProgress={progress}
|
||||
onClickSettings={onClickSettings}
|
||||
hideTrashIcon={hideTrashIcon}
|
||||
|
@ -38,13 +38,14 @@ import {
|
||||
genResponseTextsDisplay,
|
||||
} from "./ResponseBoxes";
|
||||
import { getLabelForResponse } from "./ResponseRatingToolbar";
|
||||
import { Dict, LLMResponse } from "./backend/typing";
|
||||
|
||||
// Helper funcs
|
||||
const getLLMName = (resp_obj) =>
|
||||
const getLLMName = (resp_obj: LLMResponse) =>
|
||||
typeof resp_obj?.llm === "string" ? resp_obj.llm : resp_obj?.llm?.name;
|
||||
const escapeRegExp = (txt) => txt.replace(/[-[\]{}()*+?.,\\^$|#\s]/g, "\\$&");
|
||||
const escapeRegExp = (txt: string) => txt.replace(/[-[\]{}()*+?.,\\^$|#\s]/g, "\\$&");
|
||||
|
||||
function getIndicesOfSubstringMatches(s, substr, caseSensitive) {
|
||||
function getIndicesOfSubstringMatches(s: string, substr: string, caseSensitive?: boolean) {
|
||||
const regex = new RegExp(
|
||||
escapeRegExp(substr),
|
||||
"g" + (caseSensitive ? "" : "i"),
|
||||
@ -56,7 +57,7 @@ function getIndicesOfSubstringMatches(s, substr, caseSensitive) {
|
||||
}
|
||||
|
||||
// Splits a string by a substring or regex, but includes the delimiter (substring/regex match) elements in the returned array.
|
||||
function splitAndIncludeDelimiter(s, substr, caseSensitive) {
|
||||
function splitAndIncludeDelimiter(s: string, substr: string, caseSensitive?: boolean) {
|
||||
const indices = getIndicesOfSubstringMatches(s, substr, caseSensitive);
|
||||
if (indices.length === 0) return [s];
|
||||
|
||||
@ -78,7 +79,7 @@ function splitAndIncludeDelimiter(s, substr, caseSensitive) {
|
||||
}
|
||||
|
||||
// Returns an HTML version of text where 'searchValue' is highlighted.
|
||||
function genSpansForHighlightedValue(text, searchValue, caseSensitive) {
|
||||
function genSpansForHighlightedValue(text: string, searchValue: string, caseSensitive?: boolean) {
|
||||
// Split texts by searchValue and map to <span> and <mark> elements
|
||||
return splitAndIncludeDelimiter(text, searchValue, caseSensitive).map(
|
||||
(s, idx) => {
|
||||
@ -94,7 +95,7 @@ function genSpansForHighlightedValue(text, searchValue, caseSensitive) {
|
||||
}
|
||||
|
||||
// Export the JSON responses to an excel file (downloads the file):
|
||||
export const exportToExcel = (jsonResponses, filename) => {
|
||||
export const exportToExcel = (jsonResponses: LLMResponse[], filename: string) => {
|
||||
if (!filename) filename = "responses.xlsx";
|
||||
|
||||
// Check that there are responses to export:
|
||||
@ -123,7 +124,7 @@ export const exportToExcel = (jsonResponses, filename) => {
|
||||
};
|
||||
const eval_res_items = res_obj.eval_res ? res_obj.eval_res.items : null;
|
||||
return res_obj.responses.map((r, r_idx) => {
|
||||
const row = {
|
||||
const row: Dict<string | number | boolean> = {
|
||||
LLM: llm,
|
||||
Prompt: prompt,
|
||||
Response: r,
|
||||
@ -173,17 +174,23 @@ export const exportToExcel = (jsonResponses, filename) => {
|
||||
XLSX.writeFile(wb, filename);
|
||||
};
|
||||
|
||||
const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
|
||||
export interface LLMResponseInspectorProps {
|
||||
jsonResponses: LLMResponse[];
|
||||
wideFormat?: boolean;
|
||||
}
|
||||
|
||||
const LLMResponseInspector: React.FC<LLMResponseInspectorProps> = ({ jsonResponses, wideFormat }) => {
|
||||
// Responses
|
||||
const [responses, setResponses] = useState([]);
|
||||
const [responseDivs, setResponseDivs] = useState<React.ReactNode>([]);
|
||||
const [receivedResponsesOnce, setReceivedResponsesOnce] = useState(false);
|
||||
|
||||
// The type of view to use to display responses. Can be either hierarchy or table.
|
||||
const [viewFormat, setViewFormat] = useState("hierarchy");
|
||||
|
||||
// The MultiSelect so people can dynamically set what vars they care about
|
||||
const [multiSelectVars, setMultiSelectVars] = useState([]);
|
||||
const [multiSelectValue, setMultiSelectValue] = useState([]);
|
||||
const [multiSelectVars, setMultiSelectVars] = useState<{value: string; label: string}[]>([]);
|
||||
const [multiSelectValue, setMultiSelectValue] = useState<string[]>([]);
|
||||
|
||||
// Search bar functionality
|
||||
const [searchValue, setSearchValue] = useState("");
|
||||
@ -236,13 +243,13 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
return;
|
||||
|
||||
// Find all vars in responses
|
||||
let found_vars = new Set();
|
||||
let found_metavars = new Set();
|
||||
let found_llms = new Set();
|
||||
let found_vars: Array<string> | Set<string> = new Set<string>();
|
||||
let found_metavars: Array<string> | Set<string> = new Set<string>();
|
||||
let found_llms: Array<string> | Set<string> = new Set<string>();
|
||||
batchedResponses.forEach((res_obj) => {
|
||||
Object.keys(res_obj.vars).forEach((v) => found_vars.add(v));
|
||||
Object.keys(res_obj.metavars).forEach((v) => found_metavars.add(v));
|
||||
found_llms.add(getLLMName(res_obj));
|
||||
Object.keys(res_obj.vars).forEach((v) => (found_vars as Set<string>).add(v));
|
||||
Object.keys(res_obj.metavars).forEach((v) => (found_metavars as Set<string>).add(v));
|
||||
(found_llms as Set<string>).add(getLLMName(res_obj));
|
||||
});
|
||||
found_vars = Array.from(found_vars);
|
||||
found_metavars = Array.from(found_metavars).filter(cleanMetavarsFilterFunc);
|
||||
@ -256,7 +263,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
|
||||
// Set the variables accessible in the MultiSelect for 'group by'
|
||||
const msvars = found_vars
|
||||
.map((name) =>
|
||||
.map((name: string) =>
|
||||
// We add a $ prefix to mark this as a prompt parameter, and so
|
||||
// in the future we can add special types of variables without name collisions
|
||||
({ value: name, label: name }),
|
||||
@ -310,7 +317,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
);
|
||||
|
||||
// Functions to associate a color to each LLM in responses
|
||||
const color_for_llm = (llm) => getColorForLLMAndSetIfNotFound(llm) + "99";
|
||||
const color_for_llm = (llm: string) => getColorForLLMAndSetIfNotFound(llm) + "99";
|
||||
const header_bg_colors = ["#e0f4fa", "#c0def9", "#a9c0f9", "#a6b2ea"];
|
||||
const response_box_colors = [
|
||||
"#eee",
|
||||
@ -321,11 +328,11 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
"#ddd",
|
||||
"#eee",
|
||||
];
|
||||
const rgroup_color = (depth) =>
|
||||
const rgroup_color = (depth: number) =>
|
||||
response_box_colors[depth % response_box_colors.length];
|
||||
|
||||
const getHeaderBadge = (key, val, depth) => {
|
||||
if (val) {
|
||||
const getHeaderBadge = (key: string, val: string | undefined, depth: number) => {
|
||||
if (val !== undefined) {
|
||||
const s = truncStr(val.trim(), 1024);
|
||||
return (
|
||||
<div
|
||||
@ -346,11 +353,11 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
}
|
||||
};
|
||||
|
||||
const generateResponseBoxes = (resps, eatenvars, fixed_width) => {
|
||||
const generateResponseBoxes = (resps: LLMResponse[], eatenvars: string[], fixed_width: number) => {
|
||||
const hide_llm_name = eatenvars.includes("LLM");
|
||||
return resps.map((res_obj, res_idx) => {
|
||||
// If user has searched for something, further filter the response texts by only those that contain the search term
|
||||
const respsFilterFunc = (responses) => {
|
||||
const respsFilterFunc = (responses: string[]) => {
|
||||
if (searchValue.length === 0) return responses;
|
||||
const filtered_resps = responses.filter(
|
||||
search_regex.test.bind(search_regex),
|
||||
@ -399,8 +406,8 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
if (viewFormat === "table") {
|
||||
// Generate a table, with default columns for: input vars, LLMs queried
|
||||
// First get column names as input vars + LLMs:
|
||||
let var_cols, colnames, getColVal, found_sel_var_vals;
|
||||
let metavar_cols = []; // found_metavars; -- Disabling this functionality for now, since it is usually annoying.
|
||||
let var_cols: string[], colnames: string[], getColVal: (r: LLMResponse) => string | number | boolean | undefined, found_sel_var_vals: string[];
|
||||
let metavar_cols: string[] = []; // found_metavars; -- Disabling this functionality for now, since it is usually annoying.
|
||||
if (tableColVar === "LLM") {
|
||||
var_cols = found_vars;
|
||||
getColVal = getLLMName;
|
||||
@ -418,16 +425,16 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
responses.reduce((acc, res_obj) => {
|
||||
acc.add(
|
||||
tableColVar in res_obj.vars
|
||||
? res_obj.vars[tableColVar]
|
||||
? res_obj.vars[tableColVar] as string
|
||||
: "(unspecified)",
|
||||
);
|
||||
return acc;
|
||||
}, new Set()),
|
||||
}, new Set<string>()),
|
||||
);
|
||||
colnames = var_cols.concat(found_sel_var_vals);
|
||||
}
|
||||
|
||||
const getVar = (r, v) => (v === "LLM" ? getLLMName(r) : r.vars[v]);
|
||||
const getVar = (r: LLMResponse, v: string) => (v === "LLM" ? getLLMName(r) : r.vars[v]);
|
||||
|
||||
// Then group responses by prompts. Each prompt will become a separate row of the table (will be treated as unique)
|
||||
const responses_by_prompt = groupResponsesBy(responses, (r) =>
|
||||
@ -486,7 +493,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
},
|
||||
);
|
||||
|
||||
setResponses([
|
||||
setResponseDivs([
|
||||
<Table key="table" fontSize={wideFormat ? "sm" : "xs"}>
|
||||
<thead>
|
||||
<tr>
|
||||
@ -506,7 +513,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
// :: nested divs first grouped by LLM (first level), then by var1, then var2 (deepest level).
|
||||
let leaf_id = 0;
|
||||
let first_opened = false;
|
||||
const groupByVars = (resps, varnames, eatenvars, header) => {
|
||||
const groupByVars = (resps: LLMResponse[], varnames: string[], eatenvars: string[], header: React.ReactNode) => {
|
||||
if (resps.length === 0) return [];
|
||||
if (varnames.length === 0) {
|
||||
// Base case. Display n response(s) to each single prompt, back-to-back:
|
||||
@ -560,16 +567,16 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
);
|
||||
const get_header =
|
||||
group_name === "LLM"
|
||||
? (key, val) => (
|
||||
? (key: string, val?: string) => (
|
||||
<div
|
||||
key={val}
|
||||
style={{ backgroundColor: color_for_llm(val) }}
|
||||
style={{ backgroundColor: val ? color_for_llm(val) : "#eee" }}
|
||||
className="response-llm-header"
|
||||
>
|
||||
{val}
|
||||
</div>
|
||||
)
|
||||
: (key, val) => getHeaderBadge(key, val, eatenvars.length);
|
||||
: (key: string, val?: string) => getHeaderBadge(key, val, eatenvars.length);
|
||||
|
||||
// Now produce nested divs corresponding to the groups
|
||||
const remaining_vars = varnames.slice(1);
|
||||
@ -630,7 +637,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
|
||||
// Produce DIV elements grouped by selected vars
|
||||
const divs = groupByVars(responses, selected_vars, [], null);
|
||||
setResponses(divs);
|
||||
setResponseDivs(divs);
|
||||
}
|
||||
|
||||
setNumMatches(numResponsesDisplayed);
|
||||
@ -651,15 +658,14 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
|
||||
// When the user clicks an item in the drop-down,
|
||||
// we want to autoclose the multiselect drop-down:
|
||||
const multiSelectRef = useRef(null);
|
||||
const handleMultiSelectValueChange = (new_val) => {
|
||||
if (multiSelectRef) {
|
||||
const multiSelectRef = useRef<HTMLInputElement>(null);
|
||||
const handleMultiSelectValueChange = (new_val: string[]) => {
|
||||
if (multiSelectRef?.current)
|
||||
multiSelectRef.current.blur();
|
||||
}
|
||||
setMultiSelectValue(new_val);
|
||||
};
|
||||
|
||||
const handleSearchValueChange = (content) => {
|
||||
const handleSearchValueChange = (content: React.ChangeEvent<HTMLInputElement>) => {
|
||||
setSearchValue(content.target.value);
|
||||
};
|
||||
|
||||
@ -730,7 +736,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
<div style={{ height: "100%" }}>
|
||||
<Tabs
|
||||
value={viewFormat}
|
||||
onTabChange={setViewFormat}
|
||||
onTabChange={(val) => setViewFormat(val ?? "hierarchy")}
|
||||
styles={{ tabLabel: { fontSize: wideFormat ? "12pt" : "9pt" } }}
|
||||
>
|
||||
<Tabs.List>
|
||||
@ -804,7 +810,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
</Tabs.Panel>
|
||||
</Tabs>
|
||||
|
||||
<div className="nowheel nodrag">{responses}</div>
|
||||
<div className="nowheel nodrag">{responseDivs}</div>
|
||||
</div>
|
||||
);
|
||||
};
|
@ -29,7 +29,9 @@ export interface ModelSettingsModalRef {
|
||||
}
|
||||
export interface ModelSettingsModalProps {
|
||||
model?: LLMSpec;
|
||||
onSettingsSubmit?: () => void;
|
||||
onSettingsSubmit?: (savedItem: LLMSpec,
|
||||
formData: Dict<JSONCompatible>,
|
||||
settingsData: Dict<JSONCompatible>) => void;
|
||||
}
|
||||
type FormData = LLMSpec["formData"];
|
||||
|
||||
|
@ -11,8 +11,10 @@ import StorageCache from "./backend/cache";
|
||||
import useStore from "./store";
|
||||
import { deepcopy } from "./backend/utils";
|
||||
|
||||
const getRatingKeyForResponse = (uid, label_name) => `r.${uid}.${label_name}`;
|
||||
const collapse_ratings = (rating_dict, idxs) => {
|
||||
type RatingDict = Record<number, boolean | string | undefined>;
|
||||
|
||||
const getRatingKeyForResponse = (uid: string, label_name: string) => `r.${uid}.${label_name}`;
|
||||
const collapse_ratings = (rating_dict: RatingDict, idxs: number[]) => {
|
||||
if (rating_dict === undefined) return undefined;
|
||||
for (let j = 0; j < idxs.length; j++) {
|
||||
if (idxs[j] in rating_dict && rating_dict[idxs[j]] !== undefined)
|
||||
@ -21,14 +23,20 @@ const collapse_ratings = (rating_dict, idxs) => {
|
||||
return undefined;
|
||||
};
|
||||
|
||||
export const getLabelForResponse = (uid, label_name) => {
|
||||
export const getLabelForResponse = (uid: string, label_name: string) => {
|
||||
return StorageCache.get(getRatingKeyForResponse(uid, label_name));
|
||||
};
|
||||
export const setLabelForResponse = (uid, label_name, payload) => {
|
||||
export const setLabelForResponse = (uid: string, label_name: string, payload: RatingDict) => {
|
||||
StorageCache.store(getRatingKeyForResponse(uid, label_name), payload);
|
||||
};
|
||||
|
||||
const ToolbarButton = forwardRef(function ToolbarButton(
|
||||
interface ToolbarButtonProps {
|
||||
selected: boolean;
|
||||
onClick: () => void;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
const ToolbarButton = forwardRef<HTMLButtonElement, ToolbarButtonProps>(function ToolbarButton(
|
||||
{ selected, onClick, children },
|
||||
ref,
|
||||
) {
|
||||
@ -47,7 +55,14 @@ const ToolbarButton = forwardRef(function ToolbarButton(
|
||||
);
|
||||
});
|
||||
|
||||
const ResponseRatingToolbar = ({
|
||||
export interface ResponseRatingToolbarProps {
|
||||
uid: string;
|
||||
wideFormat?: boolean;
|
||||
innerIdxs: number[];
|
||||
onUpdateResponses: () => void;
|
||||
}
|
||||
|
||||
const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({
|
||||
uid,
|
||||
wideFormat,
|
||||
innerIdxs,
|
||||
@ -62,10 +77,10 @@ const ResponseRatingToolbar = ({
|
||||
// :: for this component changes.
|
||||
// const state = useStore((store) => store.state);
|
||||
const setState = useStore((store) => store.setState);
|
||||
const gradeState = useStore((store) => store.state[gradeKey]);
|
||||
const noteState = useStore((store) => store.state[noteKey]);
|
||||
const gradeState = useStore<RatingDict>((store) => store.state[gradeKey]);
|
||||
const noteState = useStore<RatingDict>((store) => store.state[noteKey]);
|
||||
const setRating = useCallback(
|
||||
(uid, label, payload) => {
|
||||
(uid: string, label: string, payload: RatingDict) => {
|
||||
const key = getRatingKeyForResponse(uid, label);
|
||||
const safe_payload = deepcopy(payload);
|
||||
setState(key, safe_payload);
|
||||
@ -91,7 +106,7 @@ const ResponseRatingToolbar = ({
|
||||
|
||||
// Override the text in the internal textarea whenever upstream annotation changes.
|
||||
useEffect(() => {
|
||||
setNoteText(note);
|
||||
setNoteText(note !== undefined ? note.toString() : "");
|
||||
}, [note]);
|
||||
|
||||
// The label for the pop-up comment box.
|
||||
@ -107,22 +122,22 @@ const ResponseRatingToolbar = ({
|
||||
}, [wideFormat]);
|
||||
|
||||
// For human labeling of responses in the inspector
|
||||
const onGrade = (grade) => {
|
||||
const onGrade = (grade: boolean | undefined) => {
|
||||
if (uid === undefined) return;
|
||||
const new_grades = gradeState ?? {};
|
||||
innerIdxs.forEach((idx) => {
|
||||
innerIdxs.forEach((idx: number) => {
|
||||
new_grades[idx] = grade;
|
||||
});
|
||||
setRating(uid, "grade", new_grades);
|
||||
if (onUpdateResponses) onUpdateResponses();
|
||||
};
|
||||
|
||||
const onAnnotate = (label) => {
|
||||
const onAnnotate = (label?: string) => {
|
||||
if (uid === undefined) return;
|
||||
if (typeof label === "string" && label.trim().length === 0)
|
||||
label = undefined; // empty strings are undefined
|
||||
const new_notes = noteState ?? {};
|
||||
innerIdxs.forEach((idx) => {
|
||||
innerIdxs.forEach((idx: number) => {
|
||||
new_notes[idx] = label;
|
||||
});
|
||||
setRating(uid, "note", new_notes);
|
@ -23,6 +23,7 @@ import {
|
||||
TemplateVarInfo,
|
||||
BaseLLMResponseObject,
|
||||
LLMSpec,
|
||||
EvaluationScore,
|
||||
} from "./typing";
|
||||
import { v4 as uuid } from "uuid";
|
||||
import { StringTemplate } from "./template";
|
||||
@ -1786,10 +1787,10 @@ export const truncStr = (
|
||||
|
||||
export const groupResponsesBy = (
|
||||
responses: LLMResponse[],
|
||||
keyFunc: (item: LLMResponse) => string | null | undefined,
|
||||
) => {
|
||||
const responses_by_key: Dict = {};
|
||||
const unspecified_group: Dict[] = [];
|
||||
keyFunc: (item: LLMResponse) => string | number | null | undefined,
|
||||
): [Dict<LLMResponse[]>, LLMResponse[]] => {
|
||||
const responses_by_key: Dict<LLMResponse[]> = {};
|
||||
const unspecified_group: LLMResponse[] = [];
|
||||
responses.forEach((item) => {
|
||||
const key = keyFunc(item);
|
||||
if (key === null || key === undefined) {
|
||||
@ -1802,22 +1803,28 @@ export const groupResponsesBy = (
|
||||
return [responses_by_key, unspecified_group];
|
||||
};
|
||||
|
||||
export const batchResponsesByUID = (responses: LLMResponse[]) => {
|
||||
/**
|
||||
* Merges inner .responses and eval_res.items properties for LLMResponses with the same
|
||||
* uid, returning the (smaller) list of merged items.
|
||||
* @param responses
|
||||
* @returns
|
||||
*/
|
||||
export const batchResponsesByUID = (responses: LLMResponse[]): LLMResponse[] => {
|
||||
const [batches, unspecified_id_group] = groupResponsesBy(
|
||||
responses,
|
||||
(resp_obj) => resp_obj.uid,
|
||||
);
|
||||
return Object.values(batches)
|
||||
.map((resp_objs: Dict[]) => {
|
||||
.map((resp_objs: LLMResponse[]) => {
|
||||
if (resp_objs.length === 1) {
|
||||
return resp_objs[0];
|
||||
} else {
|
||||
const batched = deepcopy_and_modify(resp_objs[0], {
|
||||
responses: resp_objs.map((resp_obj) => resp_obj.responses).flat(),
|
||||
});
|
||||
if (batched.eval_res !== undefined) {
|
||||
}) as LLMResponse;
|
||||
if (batched.eval_res?.items !== undefined) {
|
||||
batched.eval_res.items = resp_objs
|
||||
.map((resp_obj) => resp_obj.eval_res.items)
|
||||
.map((resp_obj) => resp_obj?.eval_res?.items as EvaluationScore[])
|
||||
.flat();
|
||||
}
|
||||
return batched;
|
||||
@ -1911,8 +1918,8 @@ export async function retryAsyncFunc<T>(
|
||||
|
||||
// Filters internally used keys LLM_{idx} and __{str} from metavar dictionaries.
|
||||
// This method is used to pass around information hidden from the user.
|
||||
export function cleanMetavarsFilterFunc() {
|
||||
return (key: string) => !(key.startsWith("LLM_") || key.startsWith("__pt"));
|
||||
export function cleanMetavarsFilterFunc(key: string) {
|
||||
return !(key.startsWith("LLM_") || key.startsWith("__pt"));
|
||||
}
|
||||
|
||||
// Verify data integrity: check that uids are present for all responses.
|
||||
|
Loading…
x
Reference in New Issue
Block a user