diff --git a/chainforge/react-server/src/BaseNode.tsx b/chainforge/react-server/src/BaseNode.tsx index 0f51d26..78a46de 100644 --- a/chainforge/react-server/src/BaseNode.tsx +++ b/chainforge/react-server/src/BaseNode.tsx @@ -4,11 +4,10 @@ */ import React, { useCallback, useMemo, useState, useRef } from "react"; -import { Menu } from "@mantine/core"; +import { Menu, MenuStylesNames, Styles } from "@mantine/core"; import { IconCopy, IconX } from "@tabler/icons-react"; import AreYouSureModal from "./AreYouSureModal"; import useStore from "./store"; -import { Dict } from "./backend/typing"; interface BaseNodeProps { children: React.ReactNode; // For components, HTML elements, text, etc. @@ -26,10 +25,9 @@ export const BaseNode: React.FC = ({ const removeNode = useStore((state) => state.removeNode); const duplicateNode = useStore((state) => state.duplicateNode); - const [contextMenuStyle, setContextMenuStyle] = useState({ - left: -100, - top: 0, - }); + const [contextMenuStyle, setContextMenuStyle] = useState< + Styles + >({}); const [contextMenuOpened, setContextMenuOpened] = useState(false); // For 'delete node' confirmation popup diff --git a/chainforge/react-server/src/JoinNode.js b/chainforge/react-server/src/JoinNode.js index db66de2..b14314c 100644 --- a/chainforge/react-server/src/JoinNode.js +++ b/chainforge/react-server/src/JoinNode.js @@ -30,6 +30,7 @@ import { cleanMetavarsFilterFunc, } from "./backend/utils"; import StorageCache from "./backend/cache"; +import { ResponseBox } from "./ResponseBoxes"; const formattingOptions = [ { value: "\n\n", label: "double newline \\n\\n" }, @@ -57,43 +58,19 @@ const DEFAULT_GROUPBY_VAR_ALL = { label: "all text", value: "A" }; const displayJoinedTexts = (textInfos, getColorForLLM) => { const color_for_llm = (llm) => getColorForLLM(llm) + "99"; return textInfos.map((info, idx) => { - const vars = info.fill_history; - const var_tags = - vars === undefined - ? [] - : Object.keys(vars).map((varname) => { - const v = truncStr(vars[varname].trim(), 72); - return ( -
- - {varname} =  - - {v} -
- ); - }); - const ps =
{info.text || info}
; return ( -
-
{var_tags}
- {info.llm === undefined ? ( - ps - ) : ( -
-

{info.llm?.name}

- {ps} -
- )} -
+ {ps} + ); }); }; diff --git a/chainforge/react-server/src/LLMResponseInspector.js b/chainforge/react-server/src/LLMResponseInspector.js index dcfc5b1..e59e58f 100644 --- a/chainforge/react-server/src/LLMResponseInspector.js +++ b/chainforge/react-server/src/LLMResponseInspector.js @@ -13,7 +13,6 @@ import React, { Suspense, } from "react"; import { - Collapse, MultiSelect, Table, NativeSelect, @@ -24,7 +23,7 @@ import { Tooltip, TextInput, } from "@mantine/core"; -import { useDisclosure, useToggle } from "@mantine/hooks"; +import { useToggle } from "@mantine/hooks"; import { IconTable, IconLayoutList, @@ -40,56 +39,17 @@ import { batchResponsesByUID, cleanMetavarsFilterFunc, } from "./backend/utils"; -import { getLabelForResponse } from "./ResponseRatingToolbar"; - -// Lazy load the response toolbars -const ResponseRatingToolbar = lazy(() => import("./ResponseRatingToolbar.js")); +import { + ResponseBox, + ResponseGroup, + genResponseTextsDisplay, +} from "./ResponseBoxes"; // Helper funcs -const countResponsesBy = (responses, keyFunc) => { - const responses_by_key = {}; - const unspecified_group = []; - responses.forEach((item, idx) => { - const key = keyFunc(item); - const d = key !== null ? responses_by_key : unspecified_group; - if (key in d) d[key].push(idx); - else d[key] = [idx]; - }); - return [responses_by_key, unspecified_group]; -}; const getLLMName = (resp_obj) => typeof resp_obj?.llm === "string" ? resp_obj.llm : resp_obj?.llm?.name; const escapeRegExp = (txt) => txt.replace(/[-[\]{}()*+?.,\\^$|#\s]/g, "\\$&"); -const SUCCESS_EVAL_SCORES = new Set(["true", "yes"]); -const FAILURE_EVAL_SCORES = new Set(["false", "no"]); -const getEvalResultStr = (eval_item) => { - if (Array.isArray(eval_item)) { - return "scores: " + eval_item.join(", "); - } else if (typeof eval_item === "object") { - const strs = Object.keys(eval_item).map((key) => { - let val = eval_item[key]; - if (typeof val === "number" && val.toString().indexOf(".") > -1) - val = val.toFixed(4); // truncate floats to 4 decimal places - return `${key}: ${val}`; - }); - return strs.join(", "); - } else { - const eval_str = eval_item.toString().trim().toLowerCase(); - const color = SUCCESS_EVAL_SCORES.has(eval_str) - ? "black" - : FAILURE_EVAL_SCORES.has(eval_str) - ? "red" - : "black"; - return ( - <> - {"score: "} - {eval_str} - - ); - } -}; - function getIndicesOfSubstringMatches(s, substr, caseSensitive) { const regex = new RegExp( escapeRegExp(substr), @@ -219,37 +179,6 @@ export const exportToExcel = (jsonResponses, filename) => { XLSX.writeFile(wb, filename); }; -const ResponseGroup = ({ - header, - responseBoxes, - responseBoxesWrapperClass, - displayStyle, - defaultState, -}) => { - const [opened, { toggle }] = useDisclosure(defaultState); - - return ( -
-
- {header} -
- -
- {responseBoxes} -
-
-
- ); -}; - const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { // Responses const [responses, setResponses] = useState([]); @@ -426,88 +355,27 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { const generateResponseBoxes = (resps, eatenvars, fixed_width) => { const hide_llm_name = eatenvars.includes("LLM"); return resps.map((res_obj, res_idx) => { - const eval_res_items = res_obj.eval_res ? res_obj.eval_res.items : null; - - // Bucket responses that have the same text, and sort by the - // number of same responses so that the top div is the most prevalent response. - let responses = res_obj.responses; - // If user has searched for something, further filter the response texts by only those that contain the search term - if (searchValue.length > 0) { + const respsFilterFunc = (responses) => { + if (searchValue.length === 0) return responses; const filtered_resps = responses.filter( search_regex.test.bind(search_regex), ); numResponsesDisplayed += filtered_resps.length; - if (filterBySearchValue) responses = filtered_resps; - } + if (filterBySearchValue) return filtered_resps; + else return responses; + }; - // We need to keep track of the original evaluation result per response str: - const resp_str_to_eval_res = {}; - if (eval_res_items) - responses.forEach((r, idx) => { - resp_str_to_eval_res[r] = eval_res_items[idx]; - }); - - // Counts the responses with the same keys - const same_resp_text_counts = countResponsesBy(responses, (r) => r)[0]; - const same_resp_keys = Object.keys(same_resp_text_counts).sort( - (key1, key2) => - same_resp_text_counts[key2].length - - same_resp_text_counts[key1].length, - ); - - const ps = same_resp_keys.map((r, idx) => { - const origIdxs = same_resp_text_counts[r]; - const textToShow = searchValue - ? genSpansForHighlightedValue(r, searchValue, caseSensitive) - : r; - return ( -
- - {!hide_llm_name && - idx === 0 && - same_resp_keys.length > 1 && - wideFormat === true ? ( -

{getLLMName(res_obj)}

- ) : ( - <> - )} - - - - {!hide_llm_name && - idx === 0 && - (same_resp_keys.length === 1 || !wideFormat) ? ( -

{getLLMName(res_obj)}

- ) : ( - <> - )} -
- {same_resp_text_counts[r].length > 1 ? ( - - {same_resp_text_counts[r].length} times - - ) : ( - <> - )} - {eval_res_items ? ( -

- {getEvalResultStr(resp_str_to_eval_res[r])} -

- ) : ( - <> - )} - {contains_eval_res && onlyShowScores ? ( -
{}
- ) : ( -
{textToShow}
- )} -
- ); + const innerTextsDisplay = genResponseTextsDisplay({ + res_obj: res_obj, + onlyShowScores: contains_eval_res && onlyShowScores, + filterFunc: respsFilterFunc, + customTextDisplay: (txt) => + searchValue + ? genSpansForHighlightedValue(txt, searchValue, caseSensitive) + : txt, + hideLLMName: hide_llm_name, + wideFormat: wideFormat, }); // At the deepest level, there may still be some vars left over. We want to display these @@ -517,31 +385,18 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => { res_obj.vars, (v) => !eatenvars.includes(v), ); - const var_tags = Object.keys(unused_vars).map((varname) => { - const v = truncStr(unused_vars[varname].trim(), wideFormat ? 72 : 18); - return ( -
- {varname} =  - {v} -
- ); - }); + const llmName = getLLMName(res_obj); return ( -
-
{var_tags}
- {hide_llm_name ? ( - ps - ) : ( -
{ps}
- )} -
+ {innerTextsDisplay} + ); }); }; diff --git a/chainforge/react-server/src/ResponseBoxes.js b/chainforge/react-server/src/ResponseBoxes.js new file mode 100644 index 0000000..cc3c88e --- /dev/null +++ b/chainforge/react-server/src/ResponseBoxes.js @@ -0,0 +1,220 @@ +import React, { Suspense, useMemo, useState } from "react"; +import { Collapse } from "@mantine/core"; +import { useDisclosure } from "@mantine/hooks"; +import { truncStr } from "./backend/utils"; +import { getLabelForResponse } from "./ResponseRatingToolbar"; + +// Lazy load the response toolbars +const ResponseRatingToolbar = lazy(() => import("./ResponseRatingToolbar.js")); + +/* HELPER FUNCTIONS */ +const SUCCESS_EVAL_SCORES = new Set(["true", "yes"]); +const FAILURE_EVAL_SCORES = new Set(["false", "no"]); +const getEvalResultStr = (eval_item) => { + if (Array.isArray(eval_item)) { + return "scores: " + eval_item.join(", "); + } else if (typeof eval_item === "object") { + const strs = Object.keys(eval_item).map((key) => { + let val = eval_item[key]; + if (typeof val === "number" && val.toString().indexOf(".") > -1) + val = val.toFixed(4); // truncate floats to 4 decimal places + return `${key}: ${val}`; + }); + return strs.join(", "); + } else { + const eval_str = eval_item.toString().trim().toLowerCase(); + const color = SUCCESS_EVAL_SCORES.has(eval_str) + ? "black" + : FAILURE_EVAL_SCORES.has(eval_str) + ? "red" + : "black"; + return ( + <> + {"score: "} + {eval_str} + + ); + } +}; +const countResponsesBy = (responses, keyFunc) => { + const responses_by_key = {}; + const unspecified_group = []; + responses.forEach((item) => { + const key = keyFunc(item); + const d = key !== null ? responses_by_key : unspecified_group; + if (key in d) d[key] += 1; + else d[key] = 1; + }); + return [responses_by_key, unspecified_group]; +}; + +/** + * A ResponseGroup is used in the Grouped List view to display clickable, collapseable groups of responses. + * These groups may also be ResponseGroups (nested). + */ +export const ResponseGroup = ({ + header, + responseBoxes, + responseBoxesWrapperClass, + displayStyle, + defaultState, +}) => { + const [opened, { toggle }] = useDisclosure(defaultState); + + return ( +
+
+ {header} +
+ +
+ {responseBoxes} +
+
+
+ ); +}; + +/** + * A ResponseBox is the display of an LLM's response(s) for a single prompt. + * It is the colored boxes that appear in the response inspector when you are inspecting responses. + * Note that a ResponseBox could list multiple textual responses if num responses per prompt > 1. + */ +export const ResponseBox = ({ + children, + boxColor, + width, + vars, + truncLenForVars, + llmName, +}) => { + const var_tags = useMemo(() => { + return Object.entries(vars).map(([varname, val]) => { + const v = truncStr(val.trim(), truncLenForVars ?? 18); + return ( +
+ {varname} =  + {v} +
+ ); + }); + }, [vars, truncLenForVars]); + + return ( +
+
{var_tags}
+ {llmName !== undefined ? ( + children + ) : ( +
+ {children} +
+ )} +
+ ); +}; + +/** + * Given a response object, generates the inner divs to put inside a ResponseBox. + * This is the lowest level display for response texts in ChainForge. + */ +export const genResponseTextsDisplay = ({ + res_obj, + filterFunc, + customTextDisplay, + onlyShowScores, + hideLLMName, + wideFormat, +}) => { + if (!res_obj) return <>; + + const eval_res_items = res_obj.eval_res ? res_obj.eval_res.items : null; + + // Bucket responses that have the same text, and sort by the + // number of same responses so that the top div is the most prevalent response. + let responses = res_obj.responses; + + // Perform any post-processing of responses. For instance, + // when searching for a response, we mark up the response texts to spans + // and may filter out some responses, removing them from display. + if (filterFunc) responses = filterFunc(responses); + + // Collapse responses with the same texts. + // We need to keep track of the original evaluation result per response str: + const resp_str_to_eval_res = {}; + if (eval_res_items) + responses.forEach((r, idx) => { + resp_str_to_eval_res[r] = eval_res_items[idx]; + }); + + const same_resp_text_counts = countResponsesBy(responses, (r) => r)[0]; + const same_resp_keys = Object.keys(same_resp_text_counts).sort( + (key1, key2) => same_resp_text_counts[key2] - same_resp_text_counts[key1], + ); + + return same_resp_keys.map((r, idx) => { + const origIdxs = same_resp_text_counts[r]; + const txt = customTextDisplay ? customTextDisplay(r) : r; + return ( +
+ + {!hideLLMName && + idx === 0 && + same_resp_keys.length > 1 && + wideFormat === true ? ( +

{getLLMName(res_obj)}

+ ) : ( + <> + )} + + + + {!hideLLMName && + idx === 0 && + (same_resp_keys.length === 1 || !wideFormat) ? ( +

{getLLMName(res_obj)}

+ ) : ( + <> + )} +
+ {same_resp_text_counts[r] > 1 ? ( + + {same_resp_text_counts[r]} times + + ) : ( + <> + )} + {eval_res_items ? ( +

+ {getEvalResultStr(resp_str_to_eval_res[r])} +

+ ) : ( + <> + )} + {onlyShowScores ? ( +
{}
+ ) : ( +
{txt}
+ )} +
+ ); + }); +}; diff --git a/chainforge/react-server/src/SplitNode.js b/chainforge/react-server/src/SplitNode.js index cde7287..101fd15 100644 --- a/chainforge/react-server/src/SplitNode.js +++ b/chainforge/react-server/src/SplitNode.js @@ -31,6 +31,7 @@ import { import { fromMarkdown } from "mdast-util-from-markdown"; import StorageCache from "./backend/cache"; +import { ResponseBox } from "./ResponseBoxes"; const formattingOptions = [ { value: "list", label: "- list items" }, @@ -102,43 +103,18 @@ export const splitText = (s, format, shouldEscapeBraces) => { const displaySplitTexts = (textInfos, getColorForLLM) => { const color_for_llm = (llm) => getColorForLLM(llm) + "99"; return textInfos.map((info, idx) => { - const vars = info.fill_history; - const var_tags = - vars === undefined - ? [] - : Object.keys(vars).map((varname) => { - const v = truncStr(vars[varname].trim(), 72); - return ( -
- - {varname} =  - - {v} -
- ); - }); - const ps =
{info.text ?? info}
; - return ( -
-
{var_tags}
- {info.llm === undefined ? ( - ps - ) : ( -
-

{info.llm?.name}

- {ps} -
- )} -
+ {ps} + ); }); };