WIP. Refactor addNodes in App.tsx to be simpler.
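
The dozen near-identical add*Node callbacks in App.tsx collapse into a single helper that derives the node id, type, data, and centered position in one place; each per-type callback becomes a one-liner. Condensed from the App.tsx diff below (reformatted here for readability; Dict comes from ./backend/typing):

const addNode = (id: string, type?: string, data?: Dict, offsetX?: number, offsetY?: number) => {
  const { x, y } = getViewportCenter();
  addNodeToStore({
    id: `${id}-` + Date.now(),        // unique id per node instance
    type: type ?? id,                 // node type defaults to the id prefix
    data: data ?? {},
    position: { x: x - 200 + (offsetX ? offsetX : 0), y: y - 100 + (offsetY ? offsetY : 0) },
  });
};

const addPromptNode = () => addNode("promptNode", "prompt", { prompt: "" });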

Ian Arawjo 2024-03-10 23:16:59 -04:00
parent 00e5d0764b
commit d98de2b7ca
12 changed files with 183 additions and 240 deletions

View File

@ -107,6 +107,7 @@
"@types/papaparse": "^5.3.14",
"@types/react-beautiful-dnd": "^13.1.8",
"@types/react-edit-text": "^5.0.4",
"@types/styled-components": "^5.1.34",
"eslint": "^8.56.0",
"eslint-config-prettier": "^9.1.0",
"eslint-config-semistandard": "^17.0.0",
@ -7111,6 +7112,17 @@
"resolved": "https://registry.npmjs.org/@types/stack-utils/-/stack-utils-2.0.1.tgz",
"integrity": "sha512-Hl219/BT5fLAaz6NDkSuhzasy49dwQS/DSdu4MdggFB8zcXv7vflBI3xp7FEmkmdDkBUI2bPUNeMttp2knYdxw=="
},
"node_modules/@types/styled-components": {
"version": "5.1.34",
"resolved": "https://registry.npmjs.org/@types/styled-components/-/styled-components-5.1.34.tgz",
"integrity": "sha512-mmiVvwpYklFIv9E8qfxuPyIt/OuyIrn6gMOAMOFUO3WJfSrSE+sGUoa4PiZj77Ut7bKZpaa6o1fBKS/4TOEvnA==",
"dev": true,
"dependencies": {
"@types/hoist-non-react-statics": "*",
"@types/react": "*",
"csstype": "^3.0.2"
}
},
"node_modules/@types/testing-library__jest-dom": {
"version": "5.14.9",
"resolved": "https://registry.npmjs.org/@types/testing-library__jest-dom/-/testing-library__jest-dom-5.14.9.tgz",

View File

@ -133,6 +133,7 @@
"@types/papaparse": "^5.3.14",
"@types/react-beautiful-dnd": "^13.1.8",
"@types/react-edit-text": "^5.0.4",
"@types/styled-components": "^5.1.34",
"eslint": "^8.56.0",
"eslint-config-prettier": "^9.1.0",
"eslint-config-semistandard": "^17.0.0",

View File

@ -37,8 +37,8 @@ import TabularDataNode from "./TabularDataNode";
import JoinNode from "./JoinNode";
import SplitNode from "./SplitNode";
import CommentNode from "./CommentNode";
import GlobalSettingsModal from "./GlobalSettingsModal";
import ExampleFlowsModal from "./ExampleFlowsModal";
import GlobalSettingsModal, { GlobalSettingsModalRef } from "./GlobalSettingsModal";
import ExampleFlowsModal, { ExampleFlowsModalRef } from "./ExampleFlowsModal";
import AreYouSureModal from "./AreYouSureModal";
import LLMEvaluatorNode from "./LLMEvalNode";
import SimpleEvalNode from "./SimpleEvalNode";
@ -69,7 +69,7 @@ import {
isEdgeChromium,
isChromium,
} from "react-device-detect";
import { Dict, LLMSpec } from "./backend/typing";
import { Dict, JSONCompatible, LLMSpec } from "./backend/typing";
const IS_ACCEPTED_BROWSER =
(isChrome ||
isChromium ||
@ -196,7 +196,7 @@ const MenuTooltip = ({
};
// const connectionLineStyle = { stroke: '#ddd' };
const snapGrid = [16, 16];
const snapGrid: [number, number] = [16, 16];
const App = () => {
// Get nodes, edges, etc. state from the Zustand store:
@ -206,7 +206,7 @@ const App = () => {
onNodesChange,
onEdgesChange,
onConnect,
addNode,
addNode: addNodeToStore,
setNodes,
setEdges,
resetLLMColors,
@ -216,17 +216,17 @@ const App = () => {
// For saving / loading
const [rfInstance, setRfInstance] = useState<ReactFlowInstance | null>(null);
const [autosavingInterval, setAutosavingInterval] = useState(null);
const [autosavingInterval, setAutosavingInterval] = useState<NodeJS.Timeout | undefined>(undefined);
// For 'share' button
const clipboard = useClipboard({ timeout: 1500 });
const [waitingForShare, setWaitingForShare] = useState(false);
// For modal popup to set global settings like API keys
const settingsModal = useRef(null);
const settingsModal = useRef<GlobalSettingsModalRef>(null);
// For modal popup of example flows
const examplesModal = useRef(null);
const examplesModal = useRef<ExampleFlowsModalRef>(null);
// For an info pop-up that welcomes new users
// const [welcomeModalOpened, { open: openWelcomeModal, close: closeWelcomeModal }] = useDisclosure(false);
@ -265,150 +265,44 @@ const App = () => {
return { x: -(x / zoom) + centerX / zoom, y: -(y / zoom) + centerY / zoom };
};
const addTextFieldsNode = () => {
const addNode = (id: string, type?: string, data?: Dict, offsetX?: number, offsetY?: number) => {
const { x, y } = getViewportCenter();
addNode({
id: "textFieldsNode-" + Date.now(),
type: "textfields",
data: {},
position: { x: x - 200, y: y - 100 },
});
};
const addPromptNode = () => {
const { x, y } = getViewportCenter();
addNode({
id: "promptNode-" + Date.now(),
type: "prompt",
data: { prompt: "" },
position: { x: x - 200, y: y - 100 },
});
};
const addChatTurnNode = () => {
const { x, y } = getViewportCenter();
addNode({
id: "chatTurn-" + Date.now(),
type: "chat",
data: { prompt: "" },
position: { x: x - 200, y: y - 100 },
});
};
const addSimpleEvalNode = () => {
const { x, y } = getViewportCenter();
addNode({
id: "simpleEval-" + Date.now(),
type: "simpleval",
data: {},
position: { x: x - 200, y: y - 100 },
addNodeToStore({
id: `${id}-` + Date.now(),
type: type ?? id,
data: data ?? {},
position: { x: x - 200 + (offsetX ? offsetX : 0), y: y - 100 + (offsetY ? offsetY : 0)},
});
};
const addTextFieldsNode = () => addNode("textFieldsNode", "textfields");
const addPromptNode = () => addNode("promptNode", "prompt", { prompt: "" });
const addChatTurnNode = () => addNode("chatTurn", "chat", { prompt: "" });
const addSimpleEvalNode = () => addNode("simpleEval", "simpleval");
const addEvalNode = (progLang: string) => {
const { x, y } = getViewportCenter();
let code = "";
if (progLang === "python")
code = "def evaluate(response):\n return len(response.text)";
else if (progLang === "javascript")
code = "function evaluate(response) {\n return response.text.length;\n}";
addNode({
id: "evalNode-" + Date.now(),
type: "evaluator",
data: { language: progLang, code },
position: { x: x - 200, y: y - 100 },
});
};
const addVisNode = () => {
const { x, y } = getViewportCenter();
addNode({
id: "visNode-" + Date.now(),
type: "vis",
data: {},
position: { x: x - 200, y: y - 100 },
});
};
const addInspectNode = () => {
const { x, y } = getViewportCenter();
addNode({
id: "inspectNode-" + Date.now(),
type: "inspect",
data: {},
position: { x: x - 200, y: y - 100 },
});
};
const addScriptNode = () => {
const { x, y } = getViewportCenter();
addNode({
id: "scriptNode-" + Date.now(),
type: "script",
data: {},
position: { x: x - 200, y: y - 100 },
});
};
const addItemsNode = () => {
const { x, y } = getViewportCenter();
addNode({
id: "csvNode-" + Date.now(),
type: "csv",
data: {},
position: { x: x - 200, y: y - 100 },
});
};
const addTabularDataNode = () => {
const { x, y } = getViewportCenter();
addNode({
id: "table-" + Date.now(),
type: "table",
data: {},
position: { x: x - 200, y: y - 100 },
});
};
const addCommentNode = () => {
const { x, y } = getViewportCenter();
addNode({
id: "comment-" + Date.now(),
type: "comment",
data: {},
position: { x: x - 200, y: y - 100 },
});
};
const addLLMEvalNode = () => {
const { x, y } = getViewportCenter();
addNode({
id: "llmeval-" + Date.now(),
type: "llmeval",
data: {},
position: { x: x - 200, y: y - 100 },
});
};
const addJoinNode = () => {
const { x, y } = getViewportCenter();
addNode({
id: "join-" + Date.now(),
type: "join",
data: {},
position: { x: x - 200, y: y - 100 },
});
};
const addSplitNode = () => {
const { x, y } = getViewportCenter();
addNode({
id: "split-" + Date.now(),
type: "split",
data: {},
position: { x: x - 200, y: y - 100 },
});
addNode("evalNode", "evaluator", { language: progLang, code });
};
const addVisNode = () => addNode("visNode", "vis", {});
const addInspectNode = () => addNode("inspectNode", "inspect");
const addScriptNode = () => addNode("scriptNode", "script");
const addItemsNode = () => addNode("csvNode", "csv");
const addTabularDataNode = () => addNode("table");
const addCommentNode = () => addNode("comment");
const addLLMEvalNode = () => addNode("llmeval");
const addJoinNode = () => addNode("join");
const addSplitNode = () => addNode("split");
const addProcessorNode = (progLang: string) => {
const { x, y } = getViewportCenter();
let code = "";
if (progLang === "python")
code = "def process(response):\n return response.text;";
else if (progLang === "javascript")
code = "function process(response) {\n return response.text;\n}";
addNode({
id: "process-" + Date.now(),
type: "processor",
data: { language: progLang, code },
position: { x: x - 200, y: y - 100 },
});
addNode("process", "processor", { language: progLang, code });
};
const onClickExamples = () => {
@ -429,7 +323,7 @@ const App = () => {
/**
* SAVING / LOADING, IMPORT / EXPORT (from JSON)
*/
const downloadJSON = (jsonData, filename) => {
const downloadJSON = (jsonData: JSONCompatible, filename: string) => {
// Convert JSON object to JSON string
const jsonString = JSON.stringify(jsonData, null, 2);
@ -547,7 +441,7 @@ const App = () => {
const saved_flow = StorageCache.loadFromLocalStorage(
"chainforge-flow",
false,
);
) as Dict;
if (saved_flow) {
StorageCache.loadFromLocalStorage("chainforge-state");
importGlobalStateFromCache();
@ -718,7 +612,7 @@ const App = () => {
};
// Load flow from examples modal
const onSelectExampleFlow = (name: string, example_category: string) => {
const onSelectExampleFlow = (name: string, example_category?: string) => {
// Trigger the 'loading' modal
setIsLoading(true);
@ -872,9 +766,9 @@ const App = () => {
]);
// Initialize auto-saving
const initAutosaving = (rf_inst) => {
if (autosavingInterval !== null) return; // autosaving interval already set
console.log("Init autosaving!");
const initAutosaving = (rf_inst: ReactFlowInstance) => {
if (autosavingInterval !== undefined) return; // autosaving interval already set
console.log("Init autosaving");
// Autosave the flow to localStorage every minute:
const interv = setInterval(() => {
@ -898,7 +792,7 @@ const App = () => {
"Autosaving disabled. The time required to save to localStorage exceeds 1 second. This can happen when there's a lot of data in your flow. Make sure to export frequently to save your work.",
);
clearInterval(interv);
setAutosavingInterval(null);
setAutosavingInterval(undefined);
}
}, 60000); // 60000 milliseconds = 1 minute
setAutosavingInterval(interv);
@ -938,23 +832,16 @@ const App = () => {
if (!response || response.startsWith("Error")) {
// Error encountered during the query; alert the user
// with the error message:
handleError(new Error(response || "Unknown error"));
return;
throw new Error(response || "Unknown error");
}
// Attempt to parse the response as a compressed flow + import it:
try {
const cforge_json = JSON.parse(
LZString.decompressFromUTF16(response),
);
importFlowFromJSON(cforge_json, rf_inst);
} catch (err) {
handleError(err);
}
const cforge_json = JSON.parse(
LZString.decompressFromUTF16(response),
);
importFlowFromJSON(cforge_json, rf_inst);
})
.catch((err) => {
handleError(err);
});
.catch(handleError);
} catch (err) {
// Soft fail
setIsLoading(false);
@ -1064,7 +951,9 @@ const App = () => {
onConnect={onConnect}
nodes={nodes}
edges={edges}
// @ts-expect-error Node types won't perfectly fit unless we explicitly extend from RF's types; ignoring this for now.
nodeTypes={nodeTypes}
// @ts-expect-error Edge types won't perfectly fit unless we explicitly extend from RF's types; ignoring this for now.
edgeTypes={edgeTypes}
zoomOnPinch={false}
zoomOnScroll={false}

View File

@ -7,7 +7,7 @@ import React, {
forwardRef,
useImperativeHandle,
} from "react";
import { Handle, Position } from "reactflow";
import { Handle, NodeProps, Position } from "reactflow";
import {
Code,
Modal,

View File

@ -357,16 +357,16 @@ const ExampleFlowCard: React.FC<ExampleFlowCardProps> = ({
);
};
interface ExampleFlowsModalHandles {
export interface ExampleFlowsModalRef {
trigger: () => void;
}
interface ExampleFlowsModalProps {
export interface ExampleFlowsModalProps {
handleOnSelect: (filename: string, category?: string) => void;
}
const ExampleFlowsModal = forwardRef<
ExampleFlowsModalHandles,
ExampleFlowsModalRef,
ExampleFlowsModalProps
>(function ExampleFlowsModal({ handleOnSelect }, ref) {
// Mantine modal popover for alerts

View File

@ -31,12 +31,15 @@ import {
IconX,
IconSparkles,
} from "@tabler/icons-react";
import { Dropzone } from "@mantine/dropzone";
import { Dropzone, FileWithPath } from "@mantine/dropzone";
import useStore from "./store";
import { APP_IS_RUNNING_LOCALLY } from "./backend/utils";
import fetch_from_backend from "./fetch_from_backend";
import { setCustomProviders } from "./ModelSettingSchemas";
import { getAIFeaturesModelProviders } from "./backend/ai";
import { CustomLLMProviderSpec, Dict } from "./backend/typing";
import { initCustomProvider, loadCachedCustomProviders, removeCustomProvider } from "./backend/backend";
import { AlertModalRef } from "./AlertModal";
const _LINK_STYLE = { color: "#1E90FF", textDecoration: "none" };
@ -44,11 +47,11 @@ const _LINK_STYLE = { color: "#1E90FF", textDecoration: "none" };
let LOADED_CUSTOM_PROVIDERS = false;
// Read a file as text and pass the text to a cb (callback) function
const read_file = (file, cb) => {
const read_file = (file: FileWithPath, cb: (contents: string | ArrayBuffer | null) => void) => {
const reader = new window.FileReader();
reader.onload = function (event) {
const fileContent = event.target.result;
cb(fileContent);
const fileContent = event.target?.result;
cb(fileContent ?? null);
};
reader.onerror = function (event) {
console.error("Error reading file:", event);
@ -56,10 +59,15 @@ const read_file = (file, cb) => {
reader.readAsText(file);
};
interface CustomProviderScriptDropzoneProps {
onError: (err: string | Error) => void;
onSetProviders: (providers: CustomLLMProviderSpec[]) => void;
}
/** A Dropzone to load a Python `.py` script that registers a `CustomModelProvider` in the Flask backend.
* If successful, the list of custom model providers in the ChainForge UI dropdown is updated.
* */
const CustomProviderScriptDropzone = ({ onError, onSetProviders }) => {
const CustomProviderScriptDropzone: React.FC<CustomProviderScriptDropzoneProps> = ({ onError, onSetProviders }) => {
const theme = useMantineTheme();
const [isLoading, setIsLoading] = useState(false);
@ -69,23 +77,20 @@ const CustomProviderScriptDropzone = ({ onError, onSetProviders }) => {
onDrop={(files) => {
if (files.length === 1) {
setIsLoading(true);
read_file(files[0], (content) => {
read_file(files[0], (content: string | ArrayBuffer | null) => {
if (typeof content !== "string") {
console.error("File unreadable: Contents are not text.");
return;
}
// Read the file into text and then send it to backend
fetch_from_backend("initCustomProvider", {
code: content,
})
.then((response) => {
initCustomProvider(content)
.then((providers) => {
setIsLoading(false);
if (response.error || !response.providers) {
onError(response.error);
return;
}
// Successfully loaded custom providers in backend,
// now load them into the ChainForge UI:
console.log(response.providers);
setCustomProviders(response.providers);
onSetProviders(response.providers);
console.log(providers);
setCustomProviders(providers);
onSetProviders(providers);
})
.catch((err) => {
setIsLoading(false);
@ -102,8 +107,6 @@ const CustomProviderScriptDropzone = ({ onError, onSetProviders }) => {
maxSize={3 * 1024 ** 2}
>
<Flex
pos="center"
spacing="md"
style={{ minHeight: rem(80), pointerEvents: "none" }}
>
<Center>
@ -144,7 +147,15 @@ const CustomProviderScriptDropzone = ({ onError, onSetProviders }) => {
);
};
const GlobalSettingsModal = forwardRef(
export interface GlobalSettingsModalRef {
trigger: () => void;
}
export interface GlobalSettingsModalProps {
alertModal?: React.RefObject<AlertModalRef>;
}
const GlobalSettingsModal = forwardRef<GlobalSettingsModalRef, GlobalSettingsModalProps>(
function GlobalSettingsModal(props, ref) {
const [opened, { open, close }] = useDisclosure(false);
const setAPIKeys = useStore((state) => state.setAPIKeys);
@ -160,11 +171,15 @@ const GlobalSettingsModal = forwardRef(
);
const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider);
const [aiSupportActive, setAISupportActive] = useState(
getFlag("aiSupport"),
const [aiSupportActive, setAISupportActive] = useState<boolean>(
getFlag("aiSupport") as boolean,
);
const [aiAutocompleteActive, setAIAutocompleteActive] = useState<boolean>(
getFlag("aiAutocomplete") as boolean,
);
const handleAISupportChecked = useCallback(
(e) => {
(e: React.ChangeEvent<HTMLInputElement>) => {
const checked = e.currentTarget.checked;
setAISupportActive(checked);
setFlag("aiSupport", checked);
@ -177,11 +192,8 @@ const GlobalSettingsModal = forwardRef(
[setFlag, setAISupportActive],
);
const [aiAutocompleteActive, setAIAutocompleteActive] = useState(
getFlag("aiAutocomplete"),
);
const handleAIAutocompleteChecked = useCallback(
(e) => {
(e: React.ChangeEvent<HTMLInputElement>) => {
const checked = e.currentTarget.checked;
setAIAutocompleteActive(checked);
setFlag("aiAutocomplete", checked);
@ -190,13 +202,14 @@ const GlobalSettingsModal = forwardRef(
);
const handleError = useCallback(
(msg) => {
(err: string | Error) => {
const msg = typeof err === "string" ? err : err.message;
if (alertModal && alertModal.current) alertModal.current.trigger(msg);
},
[alertModal],
);
const [customProviders, setLocalCustomProviders] = useState([]);
const [customProviders, setLocalCustomProviders] = useState<CustomLLMProviderSpec[]>([]);
const refreshLLMProviderLists = useCallback(() => {
// We unfortunately have to force all prompt/chat nodes to refresh their LLM lists, bc
// apparently the update to the AvailableLLMs list is not immediately propagated to them.
@ -208,16 +221,9 @@ const GlobalSettingsModal = forwardRef(
);
}, [nodes, setDataPropsForNode]);
const removeCustomProvider = useCallback(
(name) => {
fetch_from_backend("removeCustomProvider", {
name,
})
.then((response) => {
if (response.error || !response.success) {
handleError(response.error);
return;
}
const handleRemoveCustomProvider = useCallback(
(name: string) => {
removeCustomProvider(name).then(() => {
// Successfully deleted the custom provider from backend;
// now updated the front-end UI to reflect this:
setAvailableLLMs(AvailableLLMs.filter((p) => p.name !== name));
@ -226,7 +232,7 @@ const GlobalSettingsModal = forwardRef(
);
refreshLLMProviderLists();
})
.catch((err) => handleError(err.message));
.catch(handleError);
},
[customProviders, handleError, AvailableLLMs, refreshLLMProviderLists],
);
@ -237,20 +243,13 @@ const GlobalSettingsModal = forwardRef(
LOADED_CUSTOM_PROVIDERS = true;
// Is running locally; try to load any custom providers.
// Soft fails if it encounters error:
fetch_from_backend("loadCachedCustomProviders", {}, console.error).then(
(json) => {
if (json?.error || json?.providers === undefined) {
console.error(
json?.error ||
"Could not load custom provider scripts: Error contacting backend.",
);
return;
}
loadCachedCustomProviders().then(
(providers) => {
// Success; pass custom providers list to store:
setCustomProviders(json.providers);
setLocalCustomProviders(json.providers);
setCustomProviders(providers);
setLocalCustomProviders(providers);
},
);
).catch(console.error);
}
}, []);
@ -276,7 +275,7 @@ const GlobalSettingsModal = forwardRef(
});
// When the API settings form is submitted
const onSubmit = (values) => {
const onSubmit = (values: Dict<string>) => {
setAPIKeys(values);
close();
};
@ -524,7 +523,7 @@ const GlobalSettingsModal = forwardRef(
)}
</Group>
<Button
onClick={() => removeCustomProvider(p.name)}
onClick={() => handleRemoveCustomProvider(p.name)}
color="red"
p="0px"
mt="4px"
@ -537,7 +536,7 @@ const GlobalSettingsModal = forwardRef(
))}
<CustomProviderScriptDropzone
onError={handleError}
onSetProviders={(ps) => {
onSetProviders={(ps: CustomLLMProviderSpec[]) => {
refreshLLMProviderLists();
setLocalCustomProviders(ps);
}}
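
The exported GlobalSettingsModalRef/GlobalSettingsModalProps types above let App.tsx type its modal ref. A rough consumer-side sketch (trigger() and the types come from this commit; the handler name and JSX placement are illustrative):

const settingsModal = useRef<GlobalSettingsModalRef>(null);

// e.g., from a toolbar button handler (name illustrative):
const onClickSettings = () => settingsModal.current?.trigger();

// ...and inside the component's JSX:
<GlobalSettingsModal ref={settingsModal} />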

View File

@ -1,12 +1,16 @@
import React, { Button, Group, RingProgress } from "@mantine/core";
import { IconSettings, IconTrash } from "@tabler/icons-react";
import { QueryProgress } from "./backend/typing";
export function GatheringResponsesRingProgress({ progress }) {
export function GatheringResponsesRingProgress({
progress,
}: {
progress: QueryProgress | undefined;
}) {
return progress !== undefined ? (
progress.success > 0 || progress.error > 0 ? (
<RingProgress
size={20}
width="16px"
thickness={3}
sections={[
{
@ -29,12 +33,19 @@ export function GatheringResponsesRingProgress({ progress }) {
);
}
export interface LLMItemButtonGroupProps {
onClickTrash?: () => void;
onClickSettings?: () => void;
ringProgress?: QueryProgress;
hideTrashIcon?: boolean;
}
export default function LLMItemButtonGroup({
onClickTrash,
onClickSettings,
ringProgress,
hideTrashIcon,
}) {
}: LLMItemButtonGroupProps) {
return (
<div>
<Group position="right" style={{ float: "right", height: "20px" }}>

View File

@ -5,12 +5,14 @@ import LLMItemButtonGroup from "./LLMItemButtonGroup";
import { IconTemperature } from "@tabler/icons-react";
import { getTemperatureSpecForModel } from "./ModelSettingSchemas";
import { Tooltip } from "@mantine/core";
import { LLMSpec, QueryProgress } from "./backend/typing";
import { DraggableProvided, DraggableStateSnapshot } from "react-beautiful-dnd";
// == The below function perc2color modified from: ==
// License: MIT - https://opensource.org/licenses/MIT
// Author: Michele Locati <michele@locati.it>
// Source: https://gist.github.com/mlocati/7210513
const perc2color = (perc) => {
const perc2color = (perc: number) => {
let r = 0;
let g = 0;
let b = 0;
@ -27,9 +29,9 @@ const perc2color = (perc) => {
return "#" + ("000000" + h.toString(16)).slice(-6);
};
const percTemperature = (llm_item) => {
const percTemperature = (llm_item: LLMSpec) => {
// Get the temp for this llm item
const temp = llm_item.settings?.temperature;
const temp = llm_item.settings?.temperature as number;
if (temp === undefined) {
console.warn(
`Did not find temperature setting for model ${llm_item.base_model}.`,
@ -74,7 +76,17 @@ export const DragItem = styled.div`
flex-direction: column;
`;
const LLMListItem = ({
export interface LLMListItemProps {
item: LLMSpec;
provided: DraggableProvided;
snapshot: DraggableStateSnapshot;
removeCallback: (key: string) => void;
onClickSettings: () => void;
progress?: QueryProgress;
hideTrashIcon: boolean;
}
const LLMListItem: React.FC<LLMListItemProps> = ({
item,
provided,
snapshot,
@ -85,7 +97,7 @@ const LLMListItem = ({
}) => {
// Set color by temperature only on item change (not every render)
const [tempColor, setTempColor] = useState(perc2color(50));
const temperature = item.settings?.temperature;
const temperature = item.settings?.temperature as number;
useEffect(() => {
if (temperature !== undefined)
@ -95,6 +107,7 @@ const LLMListItem = ({
return (
<DragItem
ref={provided.innerRef}
// @ts-expect-error This property is used dynamically by the Draggable library.
snapshot={snapshot}
{...provided.draggableProps}
{...provided.dragHandleProps}
@ -123,7 +136,7 @@ const LLMListItem = ({
)}
</CardHeader>
<LLMItemButtonGroup
onClickTrash={() => removeCallback(item.key)}
onClickTrash={() => removeCallback(item.key ?? "undefined")}
ringProgress={progress}
onClickSettings={onClickSettings}
hideTrashIcon={hideTrashIcon}
@ -133,7 +146,7 @@ const LLMListItem = ({
);
};
export const LLMListItemClone = ({
export const LLMListItemClone: React.FC<LLMListItemProps> = ({
item,
provided,
snapshot,
@ -141,7 +154,7 @@ export const LLMListItemClone = ({
}) => {
// Set color by temperature only on item change (not every render)
const [tempColor, setTempColor] = useState(perc2color(50));
const temperature = item.settings?.temperature;
const temperature = item.settings?.temperature as number;
useEffect(() => {
if (temperature !== undefined)
@ -153,6 +166,7 @@ export const LLMListItemClone = ({
ref={provided.innerRef}
{...provided.draggableProps}
{...provided.dragHandleProps}
// @ts-expect-error This property is used dynamically by the Draggable library.
snapshot={snapshot}
>
<div>

View File

@ -2153,9 +2153,9 @@ export const getTemperatureSpecForModel = (modelName: string) => {
ModelSettings[modelName].schema?.properties?.temperature;
if (temperature_property) {
return {
minimum: temperature_property.minimum,
maximum: temperature_property.maximum,
default: temperature_property.default,
minimum: temperature_property.minimum as number,
maximum: temperature_property.maximum as number,
default: temperature_property.default as number,
};
}
}

View File

@ -13,6 +13,7 @@ import {
LLMSpec,
EvaluatedResponsesResults,
TemplateVarInfo,
CustomLLMProviderSpec,
} from "./typing";
import { LLM, getEnumName } from "./models";
import {
@ -489,7 +490,9 @@ export async function generatePrompts(
vars: Dict<(TemplateVarInfo | string)[]>,
): Promise<PromptTemplate[]> {
const gen_prompts = new PromptPermutationGenerator(root_prompt);
const all_prompt_permutations = Array.from(gen_prompts.generate(deepcopy(vars)));
const all_prompt_permutations = Array.from(
gen_prompts.generate(deepcopy(vars)),
);
return all_prompt_permutations;
}
@ -1491,7 +1494,7 @@ export async function fetchOpenAIEval(evalname: string): Promise<Dict> {
* @returns a Promise with the JSON of the response. Will include 'error' key if error'd; if success,
* a 'providers' key with a list of all loaded custom provider callbacks, as dicts.
*/
export async function initCustomProvider(code: string): Promise<Dict> {
export async function initCustomProvider(code: string): Promise<CustomLLMProviderSpec[]> {
// Attempt to fetch the example flow from the local filesystem
// by querying the Flask server:
return fetch(`${FLASK_BASE_URL}app/initCustomProvider`, {
@ -1503,6 +1506,10 @@ export async function initCustomProvider(code: string): Promise<Dict> {
body: JSON.stringify({ code }),
}).then(function (res) {
return res.json();
}).then(function (json) {
if (!json || json.error || !json.providers)
throw new Error(json.error ?? "Unknown error");
return json.providers as CustomLLMProviderSpec[];
});
}
@ -1513,7 +1520,7 @@ export async function initCustomProvider(code: string): Promise<Dict> {
* @returns a Promise with the JSON of the response. Will include 'error' key if error'd; if success,
* a 'success' key with a true value.
*/
export async function removeCustomProvider(name: string): Promise<Dict> {
export async function removeCustomProvider(name: string): Promise<boolean> {
// Attempt to fetch the example flow from the local filesystem
// by querying the Flask server:
return fetch(`${FLASK_BASE_URL}app/removeCustomProvider`, {
@ -1525,6 +1532,10 @@ export async function removeCustomProvider(name: string): Promise<Dict> {
body: JSON.stringify({ name }),
}).then(function (res) {
return res.json();
}).then(function (json) {
if (!json || json.error || !json.success)
throw new Error(json.error ?? "Unknown error");
return true;
});
}
@ -1534,7 +1545,7 @@ export async function removeCustomProvider(name: string): Promise<Dict> {
* @returns a Promise with the JSON of the response. Will include 'error' key if error'd; if success,
* a 'providers' key with all loaded custom providers in an array. If there were none, returns empty array.
*/
export async function loadCachedCustomProviders(): Promise<Dict> {
export async function loadCachedCustomProviders(): Promise<CustomLLMProviderSpec[]> {
return fetch(`${FLASK_BASE_URL}app/loadCachedCustomProviders`, {
method: "POST",
headers: {
@ -1544,5 +1555,9 @@ export async function loadCachedCustomProviders(): Promise<Dict> {
body: "{}",
}).then(function (res) {
return res.json();
}).then(function (json) {
if (!json || json.error || !json.providers)
throw new Error(json.error ?? "Could not load custom provider scripts: Error contacting backend.");
return json.providers as CustomLLMProviderSpec[];
});
}
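
With this change, initCustomProvider, removeCustomProvider, and loadCachedCustomProviders resolve to typed values and reject with an Error when the Flask backend reports one, so callers no longer inspect response.error themselves. A hypothetical call site (customProviderScript is an assumed string of Python source):

initCustomProvider(customProviderScript)
  .then((providers: CustomLLMProviderSpec[]) => {
    setCustomProviders(providers); // push the new providers into the UI's model list
  })
  .catch((err) => console.error(err.message));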

View File

@ -1,4 +1,4 @@
import { Dict } from "./typing";
import { Dict, JSONCompatible } from "./typing";
import LZString from "lz-string";
/**
@ -98,7 +98,7 @@ export default class StorageCache {
console.warn("Storage quota exceeded");
} else {
// Handle other types of storage-related errors
console.error("Error storing data in localStorage:", error.message);
console.error("Error storing data in localStorage:", (error as Error).message);
}
return false;
}
@ -114,7 +114,7 @@ export default class StorageCache {
public static loadFromLocalStorage(
localStorageKey = "chainforge",
setStorageCacheData = true,
): boolean {
): JSONCompatible | undefined {
const compressed = localStorage.getItem(localStorageKey);
if (!compressed) {
console.error(
@ -128,7 +128,7 @@ export default class StorageCache {
console.log("loaded", data);
return data;
} catch (error) {
console.error(error.message);
console.error((error as Error).message);
return undefined;
}
}
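
loadFromLocalStorage was annotated as returning boolean even though it returns the decompressed data; the annotation now matches (JSONCompatible | undefined), which is why the App.tsx call site earlier in this commit casts the result. A minimal caller sketch (the follow-up import call is shown for illustration):

const saved_flow = StorageCache.loadFromLocalStorage("chainforge-flow", false) as Dict;
if (saved_flow) importFlowFromJSON(saved_flow, rf_inst); // illustrative follow-up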

View File

@ -8,6 +8,7 @@ import {
NodeChange,
EdgeChange,
MarkerType,
Connection,
} from "reactflow";
import { escapeBraces } from "./backend/template";
import {
@ -248,7 +249,7 @@ export interface StoreHandles {
removeEdge: (id: string) => void;
onNodesChange: (changes: NodeChange[]) => void;
onEdgesChange: (changes: EdgeChange[]) => void;
onConnect: (connection: Edge) => void;
onConnect: (connection: Connection | Edge) => void;
// The LLM providers available in the drop-down list
AvailableLLMs: LLMSpec[];
@ -717,7 +718,7 @@ const useStore = create<StoreHandles>((set, get) => ({
},
onConnect: (connection) => {
// Get the target node information
const target = get().getNode(connection.target);
const target = connection.target ? get().getNode(connection.target) : undefined;
if (target === undefined) return;
if (
@ -736,6 +737,7 @@ const useStore = create<StoreHandles>((set, get) => ({
get().setDataPropsForNode(target.id, { refresh: true });
}
connection = connection as Edge;
connection.interactionWidth = 40;
connection.markerEnd = { type: MarkerType.Arrow, width: 22, height: 22 }; // 22px
connection.type = "default";