mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-15 00:36:29 +00:00
WIP. Refactor addNodes in App.tsx to be simpler.
This commit is contained in:
parent
00e5d0764b
commit
d98de2b7ca
12
chainforge/react-server/package-lock.json
generated
12
chainforge/react-server/package-lock.json
generated
@ -107,6 +107,7 @@
|
||||
"@types/papaparse": "^5.3.14",
|
||||
"@types/react-beautiful-dnd": "^13.1.8",
|
||||
"@types/react-edit-text": "^5.0.4",
|
||||
"@types/styled-components": "^5.1.34",
|
||||
"eslint": "^8.56.0",
|
||||
"eslint-config-prettier": "^9.1.0",
|
||||
"eslint-config-semistandard": "^17.0.0",
|
||||
@ -7111,6 +7112,17 @@
|
||||
"resolved": "https://registry.npmjs.org/@types/stack-utils/-/stack-utils-2.0.1.tgz",
|
||||
"integrity": "sha512-Hl219/BT5fLAaz6NDkSuhzasy49dwQS/DSdu4MdggFB8zcXv7vflBI3xp7FEmkmdDkBUI2bPUNeMttp2knYdxw=="
|
||||
},
|
||||
"node_modules/@types/styled-components": {
|
||||
"version": "5.1.34",
|
||||
"resolved": "https://registry.npmjs.org/@types/styled-components/-/styled-components-5.1.34.tgz",
|
||||
"integrity": "sha512-mmiVvwpYklFIv9E8qfxuPyIt/OuyIrn6gMOAMOFUO3WJfSrSE+sGUoa4PiZj77Ut7bKZpaa6o1fBKS/4TOEvnA==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@types/hoist-non-react-statics": "*",
|
||||
"@types/react": "*",
|
||||
"csstype": "^3.0.2"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/testing-library__jest-dom": {
|
||||
"version": "5.14.9",
|
||||
"resolved": "https://registry.npmjs.org/@types/testing-library__jest-dom/-/testing-library__jest-dom-5.14.9.tgz",
|
||||
|
@ -133,6 +133,7 @@
|
||||
"@types/papaparse": "^5.3.14",
|
||||
"@types/react-beautiful-dnd": "^13.1.8",
|
||||
"@types/react-edit-text": "^5.0.4",
|
||||
"@types/styled-components": "^5.1.34",
|
||||
"eslint": "^8.56.0",
|
||||
"eslint-config-prettier": "^9.1.0",
|
||||
"eslint-config-semistandard": "^17.0.0",
|
||||
|
@ -37,8 +37,8 @@ import TabularDataNode from "./TabularDataNode";
|
||||
import JoinNode from "./JoinNode";
|
||||
import SplitNode from "./SplitNode";
|
||||
import CommentNode from "./CommentNode";
|
||||
import GlobalSettingsModal from "./GlobalSettingsModal";
|
||||
import ExampleFlowsModal from "./ExampleFlowsModal";
|
||||
import GlobalSettingsModal, { GlobalSettingsModalRef } from "./GlobalSettingsModal";
|
||||
import ExampleFlowsModal, { ExampleFlowsModalRef } from "./ExampleFlowsModal";
|
||||
import AreYouSureModal from "./AreYouSureModal";
|
||||
import LLMEvaluatorNode from "./LLMEvalNode";
|
||||
import SimpleEvalNode from "./SimpleEvalNode";
|
||||
@ -69,7 +69,7 @@ import {
|
||||
isEdgeChromium,
|
||||
isChromium,
|
||||
} from "react-device-detect";
|
||||
import { Dict, LLMSpec } from "./backend/typing";
|
||||
import { Dict, JSONCompatible, LLMSpec } from "./backend/typing";
|
||||
const IS_ACCEPTED_BROWSER =
|
||||
(isChrome ||
|
||||
isChromium ||
|
||||
@ -196,7 +196,7 @@ const MenuTooltip = ({
|
||||
};
|
||||
|
||||
// const connectionLineStyle = { stroke: '#ddd' };
|
||||
const snapGrid = [16, 16];
|
||||
const snapGrid: [number, number] = [16, 16];
|
||||
|
||||
const App = () => {
|
||||
// Get nodes, edges, etc. state from the Zustand store:
|
||||
@ -206,7 +206,7 @@ const App = () => {
|
||||
onNodesChange,
|
||||
onEdgesChange,
|
||||
onConnect,
|
||||
addNode,
|
||||
addNode: addNodeToStore,
|
||||
setNodes,
|
||||
setEdges,
|
||||
resetLLMColors,
|
||||
@ -216,17 +216,17 @@ const App = () => {
|
||||
|
||||
// For saving / loading
|
||||
const [rfInstance, setRfInstance] = useState<ReactFlowInstance | null>(null);
|
||||
const [autosavingInterval, setAutosavingInterval] = useState(null);
|
||||
const [autosavingInterval, setAutosavingInterval] = useState<NodeJS.Timeout | undefined>(undefined);
|
||||
|
||||
// For 'share' button
|
||||
const clipboard = useClipboard({ timeout: 1500 });
|
||||
const [waitingForShare, setWaitingForShare] = useState(false);
|
||||
|
||||
// For modal popup to set global settings like API keys
|
||||
const settingsModal = useRef(null);
|
||||
const settingsModal = useRef<GlobalSettingsModalRef>(null);
|
||||
|
||||
// For modal popup of example flows
|
||||
const examplesModal = useRef(null);
|
||||
const examplesModal = useRef<ExampleFlowsModalRef>(null);
|
||||
|
||||
// For an info pop-up that welcomes new users
|
||||
// const [welcomeModalOpened, { open: openWelcomeModal, close: closeWelcomeModal }] = useDisclosure(false);
|
||||
@ -265,150 +265,44 @@ const App = () => {
|
||||
return { x: -(x / zoom) + centerX / zoom, y: -(y / zoom) + centerY / zoom };
|
||||
};
|
||||
|
||||
const addTextFieldsNode = () => {
|
||||
const addNode = (id: string, type?: string, data?: Dict, offsetX?: number, offsetY?: number) => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "textFieldsNode-" + Date.now(),
|
||||
type: "textfields",
|
||||
data: {},
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addPromptNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "promptNode-" + Date.now(),
|
||||
type: "prompt",
|
||||
data: { prompt: "" },
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addChatTurnNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "chatTurn-" + Date.now(),
|
||||
type: "chat",
|
||||
data: { prompt: "" },
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addSimpleEvalNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "simpleEval-" + Date.now(),
|
||||
type: "simpleval",
|
||||
data: {},
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
addNodeToStore({
|
||||
id: `${id}-` + Date.now(),
|
||||
type: type ?? id,
|
||||
data: data ?? {},
|
||||
position: { x: x - 200 + (offsetX ? offsetX : 0), y: y - 100 + (offsetY ? offsetY : 0)},
|
||||
});
|
||||
};
|
||||
|
||||
const addTextFieldsNode = () => addNode("textFieldsNode", "textfields");
|
||||
const addPromptNode = () => addNode("promptNode", "prompt", { prompt: "" });
|
||||
const addChatTurnNode = () => addNode("chatTurn", "chat", { prompt: "" });
|
||||
const addSimpleEvalNode = () => addNode("simpleEval", "simpleval");
|
||||
const addEvalNode = (progLang: string) => {
|
||||
const { x, y } = getViewportCenter();
|
||||
let code = "";
|
||||
if (progLang === "python")
|
||||
code = "def evaluate(response):\n return len(response.text)";
|
||||
else if (progLang === "javascript")
|
||||
code = "function evaluate(response) {\n return response.text.length;\n}";
|
||||
addNode({
|
||||
id: "evalNode-" + Date.now(),
|
||||
type: "evaluator",
|
||||
data: { language: progLang, code },
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addVisNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "visNode-" + Date.now(),
|
||||
type: "vis",
|
||||
data: {},
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addInspectNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "inspectNode-" + Date.now(),
|
||||
type: "inspect",
|
||||
data: {},
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addScriptNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "scriptNode-" + Date.now(),
|
||||
type: "script",
|
||||
data: {},
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addItemsNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "csvNode-" + Date.now(),
|
||||
type: "csv",
|
||||
data: {},
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addTabularDataNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "table-" + Date.now(),
|
||||
type: "table",
|
||||
data: {},
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addCommentNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "comment-" + Date.now(),
|
||||
type: "comment",
|
||||
data: {},
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addLLMEvalNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "llmeval-" + Date.now(),
|
||||
type: "llmeval",
|
||||
data: {},
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addJoinNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "join-" + Date.now(),
|
||||
type: "join",
|
||||
data: {},
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addSplitNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({
|
||||
id: "split-" + Date.now(),
|
||||
type: "split",
|
||||
data: {},
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
addNode("evalNode", "evaluator", { language: progLang, code });
|
||||
};
|
||||
const addVisNode = () => addNode("visNode", "vis", {});
|
||||
const addInspectNode = () => addNode("inspectNode", "inspect");
|
||||
const addScriptNode = () => addNode("scriptNode", "script");
|
||||
const addItemsNode = () => addNode("csvNode", "csv");
|
||||
const addTabularDataNode = () => addNode("table");
|
||||
const addCommentNode = () => addNode("comment");
|
||||
const addLLMEvalNode = () => addNode("llmeval");
|
||||
const addJoinNode = () => addNode("join");
|
||||
const addSplitNode = () => addNode("split");
|
||||
const addProcessorNode = (progLang: string) => {
|
||||
const { x, y } = getViewportCenter();
|
||||
let code = "";
|
||||
if (progLang === "python")
|
||||
code = "def process(response):\n return response.text;";
|
||||
else if (progLang === "javascript")
|
||||
code = "function process(response) {\n return response.text;\n}";
|
||||
addNode({
|
||||
id: "process-" + Date.now(),
|
||||
type: "processor",
|
||||
data: { language: progLang, code },
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
addNode("process", "processor", { language: progLang, code });
|
||||
};
|
||||
|
||||
const onClickExamples = () => {
|
||||
@ -429,7 +323,7 @@ const App = () => {
|
||||
/**
|
||||
* SAVING / LOADING, IMPORT / EXPORT (from JSON)
|
||||
*/
|
||||
const downloadJSON = (jsonData, filename) => {
|
||||
const downloadJSON = (jsonData: JSONCompatible, filename: string) => {
|
||||
// Convert JSON object to JSON string
|
||||
const jsonString = JSON.stringify(jsonData, null, 2);
|
||||
|
||||
@ -547,7 +441,7 @@ const App = () => {
|
||||
const saved_flow = StorageCache.loadFromLocalStorage(
|
||||
"chainforge-flow",
|
||||
false,
|
||||
);
|
||||
) as Dict;
|
||||
if (saved_flow) {
|
||||
StorageCache.loadFromLocalStorage("chainforge-state");
|
||||
importGlobalStateFromCache();
|
||||
@ -718,7 +612,7 @@ const App = () => {
|
||||
};
|
||||
|
||||
// Load flow from examples modal
|
||||
const onSelectExampleFlow = (name: string, example_category: string) => {
|
||||
const onSelectExampleFlow = (name: string, example_category?: string) => {
|
||||
// Trigger the 'loading' modal
|
||||
setIsLoading(true);
|
||||
|
||||
@ -872,9 +766,9 @@ const App = () => {
|
||||
]);
|
||||
|
||||
// Initialize auto-saving
|
||||
const initAutosaving = (rf_inst) => {
|
||||
if (autosavingInterval !== null) return; // autosaving interval already set
|
||||
console.log("Init autosaving!");
|
||||
const initAutosaving = (rf_inst: ReactFlowInstance) => {
|
||||
if (autosavingInterval !== undefined) return; // autosaving interval already set
|
||||
console.log("Init autosaving");
|
||||
|
||||
// Autosave the flow to localStorage every minute:
|
||||
const interv = setInterval(() => {
|
||||
@ -898,7 +792,7 @@ const App = () => {
|
||||
"Autosaving disabled. The time required to save to localStorage exceeds 1 second. This can happen when there's a lot of data in your flow. Make sure to export frequently to save your work.",
|
||||
);
|
||||
clearInterval(interv);
|
||||
setAutosavingInterval(null);
|
||||
setAutosavingInterval(undefined);
|
||||
}
|
||||
}, 60000); // 60000 milliseconds = 1 minute
|
||||
setAutosavingInterval(interv);
|
||||
@ -938,23 +832,16 @@ const App = () => {
|
||||
if (!response || response.startsWith("Error")) {
|
||||
// Error encountered during the query; alert the user
|
||||
// with the error message:
|
||||
handleError(new Error(response || "Unknown error"));
|
||||
return;
|
||||
throw new Error(response || "Unknown error");
|
||||
}
|
||||
|
||||
// Attempt to parse the response as a compressed flow + import it:
|
||||
try {
|
||||
const cforge_json = JSON.parse(
|
||||
LZString.decompressFromUTF16(response),
|
||||
);
|
||||
importFlowFromJSON(cforge_json, rf_inst);
|
||||
} catch (err) {
|
||||
handleError(err);
|
||||
}
|
||||
const cforge_json = JSON.parse(
|
||||
LZString.decompressFromUTF16(response),
|
||||
);
|
||||
importFlowFromJSON(cforge_json, rf_inst);
|
||||
})
|
||||
.catch((err) => {
|
||||
handleError(err);
|
||||
});
|
||||
.catch(handleError);
|
||||
} catch (err) {
|
||||
// Soft fail
|
||||
setIsLoading(false);
|
||||
@ -1064,7 +951,9 @@ const App = () => {
|
||||
onConnect={onConnect}
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
// @ts-expect-error Node types won't perfectly fit unless we explicitly extend from RF's types; ignoring this for now.
|
||||
nodeTypes={nodeTypes}
|
||||
// @ts-expect-error Edge types won't perfectly fit unless we explicitly extend from RF's types; ignoring this for now.
|
||||
edgeTypes={edgeTypes}
|
||||
zoomOnPinch={false}
|
||||
zoomOnScroll={false}
|
||||
|
@ -7,7 +7,7 @@ import React, {
|
||||
forwardRef,
|
||||
useImperativeHandle,
|
||||
} from "react";
|
||||
import { Handle, Position } from "reactflow";
|
||||
import { Handle, NodeProps, Position } from "reactflow";
|
||||
import {
|
||||
Code,
|
||||
Modal,
|
||||
|
@ -357,16 +357,16 @@ const ExampleFlowCard: React.FC<ExampleFlowCardProps> = ({
|
||||
);
|
||||
};
|
||||
|
||||
interface ExampleFlowsModalHandles {
|
||||
export interface ExampleFlowsModalRef {
|
||||
trigger: () => void;
|
||||
}
|
||||
|
||||
interface ExampleFlowsModalProps {
|
||||
export interface ExampleFlowsModalProps {
|
||||
handleOnSelect: (filename: string, category?: string) => void;
|
||||
}
|
||||
|
||||
const ExampleFlowsModal = forwardRef<
|
||||
ExampleFlowsModalHandles,
|
||||
ExampleFlowsModalRef,
|
||||
ExampleFlowsModalProps
|
||||
>(function ExampleFlowsModal({ handleOnSelect }, ref) {
|
||||
// Mantine modal popover for alerts
|
||||
|
@ -31,12 +31,15 @@ import {
|
||||
IconX,
|
||||
IconSparkles,
|
||||
} from "@tabler/icons-react";
|
||||
import { Dropzone } from "@mantine/dropzone";
|
||||
import { Dropzone, FileWithPath } from "@mantine/dropzone";
|
||||
import useStore from "./store";
|
||||
import { APP_IS_RUNNING_LOCALLY } from "./backend/utils";
|
||||
import fetch_from_backend from "./fetch_from_backend";
|
||||
import { setCustomProviders } from "./ModelSettingSchemas";
|
||||
import { getAIFeaturesModelProviders } from "./backend/ai";
|
||||
import { CustomLLMProviderSpec, Dict } from "./backend/typing";
|
||||
import { initCustomProvider, loadCachedCustomProviders, removeCustomProvider } from "./backend/backend";
|
||||
import { AlertModalRef } from "./AlertModal";
|
||||
|
||||
const _LINK_STYLE = { color: "#1E90FF", textDecoration: "none" };
|
||||
|
||||
@ -44,11 +47,11 @@ const _LINK_STYLE = { color: "#1E90FF", textDecoration: "none" };
|
||||
let LOADED_CUSTOM_PROVIDERS = false;
|
||||
|
||||
// Read a file as text and pass the text to a cb (callback) function
|
||||
const read_file = (file, cb) => {
|
||||
const read_file = (file: FileWithPath, cb: (contents: string | ArrayBuffer | null) => void) => {
|
||||
const reader = new window.FileReader();
|
||||
reader.onload = function (event) {
|
||||
const fileContent = event.target.result;
|
||||
cb(fileContent);
|
||||
const fileContent = event.target?.result;
|
||||
cb(fileContent ?? null);
|
||||
};
|
||||
reader.onerror = function (event) {
|
||||
console.error("Error reading file:", event);
|
||||
@ -56,10 +59,15 @@ const read_file = (file, cb) => {
|
||||
reader.readAsText(file);
|
||||
};
|
||||
|
||||
interface CustomProviderScriptDropzoneProps {
|
||||
onError: (err: string | Error) => void;
|
||||
onSetProviders: (providers: CustomLLMProviderSpec[]) => void;
|
||||
}
|
||||
|
||||
/** A Dropzone to load a Python `.py` script that registers a `CustomModelProvider` in the Flask backend.
|
||||
* If successful, the list of custom model providers in the ChainForge UI dropdown is updated.
|
||||
* */
|
||||
const CustomProviderScriptDropzone = ({ onError, onSetProviders }) => {
|
||||
const CustomProviderScriptDropzone: React.FC<CustomProviderScriptDropzoneProps> = ({ onError, onSetProviders }) => {
|
||||
const theme = useMantineTheme();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
@ -69,23 +77,20 @@ const CustomProviderScriptDropzone = ({ onError, onSetProviders }) => {
|
||||
onDrop={(files) => {
|
||||
if (files.length === 1) {
|
||||
setIsLoading(true);
|
||||
read_file(files[0], (content) => {
|
||||
read_file(files[0], (content: string | ArrayBuffer | null) => {
|
||||
if (typeof content !== "string") {
|
||||
console.error("File unreadable: Contents are not text.");
|
||||
return;
|
||||
}
|
||||
// Read the file into text and then send it to backend
|
||||
fetch_from_backend("initCustomProvider", {
|
||||
code: content,
|
||||
})
|
||||
.then((response) => {
|
||||
initCustomProvider(content)
|
||||
.then((providers) => {
|
||||
setIsLoading(false);
|
||||
|
||||
if (response.error || !response.providers) {
|
||||
onError(response.error);
|
||||
return;
|
||||
}
|
||||
// Successfully loaded custom providers in backend,
|
||||
// now load them into the ChainForge UI:
|
||||
console.log(response.providers);
|
||||
setCustomProviders(response.providers);
|
||||
onSetProviders(response.providers);
|
||||
console.log(providers);
|
||||
setCustomProviders(providers);
|
||||
onSetProviders(providers);
|
||||
})
|
||||
.catch((err) => {
|
||||
setIsLoading(false);
|
||||
@ -102,8 +107,6 @@ const CustomProviderScriptDropzone = ({ onError, onSetProviders }) => {
|
||||
maxSize={3 * 1024 ** 2}
|
||||
>
|
||||
<Flex
|
||||
pos="center"
|
||||
spacing="md"
|
||||
style={{ minHeight: rem(80), pointerEvents: "none" }}
|
||||
>
|
||||
<Center>
|
||||
@ -144,7 +147,15 @@ const CustomProviderScriptDropzone = ({ onError, onSetProviders }) => {
|
||||
);
|
||||
};
|
||||
|
||||
const GlobalSettingsModal = forwardRef(
|
||||
export interface GlobalSettingsModalRef {
|
||||
trigger: () => void;
|
||||
}
|
||||
|
||||
export interface GlobalSettingsModalProps {
|
||||
alertModal?: React.RefObject<AlertModalRef>;
|
||||
}
|
||||
|
||||
const GlobalSettingsModal = forwardRef<GlobalSettingsModalRef, GlobalSettingsModalProps>(
|
||||
function GlobalSettingsModal(props, ref) {
|
||||
const [opened, { open, close }] = useDisclosure(false);
|
||||
const setAPIKeys = useStore((state) => state.setAPIKeys);
|
||||
@ -160,11 +171,15 @@ const GlobalSettingsModal = forwardRef(
|
||||
);
|
||||
const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider);
|
||||
|
||||
const [aiSupportActive, setAISupportActive] = useState(
|
||||
getFlag("aiSupport"),
|
||||
const [aiSupportActive, setAISupportActive] = useState<boolean>(
|
||||
getFlag("aiSupport") as boolean,
|
||||
);
|
||||
const [aiAutocompleteActive, setAIAutocompleteActive] = useState<boolean>(
|
||||
getFlag("aiAutocomplete") as boolean,
|
||||
);
|
||||
|
||||
const handleAISupportChecked = useCallback(
|
||||
(e) => {
|
||||
(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const checked = e.currentTarget.checked;
|
||||
setAISupportActive(checked);
|
||||
setFlag("aiSupport", checked);
|
||||
@ -177,11 +192,8 @@ const GlobalSettingsModal = forwardRef(
|
||||
[setFlag, setAISupportActive],
|
||||
);
|
||||
|
||||
const [aiAutocompleteActive, setAIAutocompleteActive] = useState(
|
||||
getFlag("aiAutocomplete"),
|
||||
);
|
||||
const handleAIAutocompleteChecked = useCallback(
|
||||
(e) => {
|
||||
(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const checked = e.currentTarget.checked;
|
||||
setAIAutocompleteActive(checked);
|
||||
setFlag("aiAutocomplete", checked);
|
||||
@ -190,13 +202,14 @@ const GlobalSettingsModal = forwardRef(
|
||||
);
|
||||
|
||||
const handleError = useCallback(
|
||||
(msg) => {
|
||||
(err: string | Error) => {
|
||||
const msg = typeof err === "string" ? err : err.message;
|
||||
if (alertModal && alertModal.current) alertModal.current.trigger(msg);
|
||||
},
|
||||
[alertModal],
|
||||
);
|
||||
|
||||
const [customProviders, setLocalCustomProviders] = useState([]);
|
||||
const [customProviders, setLocalCustomProviders] = useState<CustomLLMProviderSpec[]>([]);
|
||||
const refreshLLMProviderLists = useCallback(() => {
|
||||
// We unfortunately have to force all prompt/chat nodes to refresh their LLM lists, bc
|
||||
// apparently the update to the AvailableLLMs list is not immediately propagated to them.
|
||||
@ -208,16 +221,9 @@ const GlobalSettingsModal = forwardRef(
|
||||
);
|
||||
}, [nodes, setDataPropsForNode]);
|
||||
|
||||
const removeCustomProvider = useCallback(
|
||||
(name) => {
|
||||
fetch_from_backend("removeCustomProvider", {
|
||||
name,
|
||||
})
|
||||
.then((response) => {
|
||||
if (response.error || !response.success) {
|
||||
handleError(response.error);
|
||||
return;
|
||||
}
|
||||
const handleRemoveCustomProvider = useCallback(
|
||||
(name: string) => {
|
||||
removeCustomProvider(name).then(() => {
|
||||
// Successfully deleted the custom provider from backend;
|
||||
// now updated the front-end UI to reflect this:
|
||||
setAvailableLLMs(AvailableLLMs.filter((p) => p.name !== name));
|
||||
@ -226,7 +232,7 @@ const GlobalSettingsModal = forwardRef(
|
||||
);
|
||||
refreshLLMProviderLists();
|
||||
})
|
||||
.catch((err) => handleError(err.message));
|
||||
.catch(handleError);
|
||||
},
|
||||
[customProviders, handleError, AvailableLLMs, refreshLLMProviderLists],
|
||||
);
|
||||
@ -237,20 +243,13 @@ const GlobalSettingsModal = forwardRef(
|
||||
LOADED_CUSTOM_PROVIDERS = true;
|
||||
// Is running locally; try to load any custom providers.
|
||||
// Soft fails if it encounters error:
|
||||
fetch_from_backend("loadCachedCustomProviders", {}, console.error).then(
|
||||
(json) => {
|
||||
if (json?.error || json?.providers === undefined) {
|
||||
console.error(
|
||||
json?.error ||
|
||||
"Could not load custom provider scripts: Error contacting backend.",
|
||||
);
|
||||
return;
|
||||
}
|
||||
loadCachedCustomProviders().then(
|
||||
(providers) => {
|
||||
// Success; pass custom providers list to store:
|
||||
setCustomProviders(json.providers);
|
||||
setLocalCustomProviders(json.providers);
|
||||
setCustomProviders(providers);
|
||||
setLocalCustomProviders(providers);
|
||||
},
|
||||
);
|
||||
).catch(console.error);
|
||||
}
|
||||
}, []);
|
||||
|
||||
@ -276,7 +275,7 @@ const GlobalSettingsModal = forwardRef(
|
||||
});
|
||||
|
||||
// When the API settings form is submitted
|
||||
const onSubmit = (values) => {
|
||||
const onSubmit = (values: Dict<string>) => {
|
||||
setAPIKeys(values);
|
||||
close();
|
||||
};
|
||||
@ -524,7 +523,7 @@ const GlobalSettingsModal = forwardRef(
|
||||
)}
|
||||
</Group>
|
||||
<Button
|
||||
onClick={() => removeCustomProvider(p.name)}
|
||||
onClick={() => handleRemoveCustomProvider(p.name)}
|
||||
color="red"
|
||||
p="0px"
|
||||
mt="4px"
|
||||
@ -537,7 +536,7 @@ const GlobalSettingsModal = forwardRef(
|
||||
))}
|
||||
<CustomProviderScriptDropzone
|
||||
onError={handleError}
|
||||
onSetProviders={(ps) => {
|
||||
onSetProviders={(ps: CustomLLMProviderSpec[]) => {
|
||||
refreshLLMProviderLists();
|
||||
setLocalCustomProviders(ps);
|
||||
}}
|
@ -1,12 +1,16 @@
|
||||
import React, { Button, Group, RingProgress } from "@mantine/core";
|
||||
import { IconSettings, IconTrash } from "@tabler/icons-react";
|
||||
import { QueryProgress } from "./backend/typing";
|
||||
|
||||
export function GatheringResponsesRingProgress({ progress }) {
|
||||
export function GatheringResponsesRingProgress({
|
||||
progress,
|
||||
}: {
|
||||
progress: QueryProgress | undefined;
|
||||
}) {
|
||||
return progress !== undefined ? (
|
||||
progress.success > 0 || progress.error > 0 ? (
|
||||
<RingProgress
|
||||
size={20}
|
||||
width="16px"
|
||||
thickness={3}
|
||||
sections={[
|
||||
{
|
||||
@ -29,12 +33,19 @@ export function GatheringResponsesRingProgress({ progress }) {
|
||||
);
|
||||
}
|
||||
|
||||
export interface LLMItemButtonGroupProps {
|
||||
onClickTrash?: () => void;
|
||||
onClickSettings?: () => void;
|
||||
ringProgress?: QueryProgress;
|
||||
hideTrashIcon?: boolean;
|
||||
}
|
||||
|
||||
export default function LLMItemButtonGroup({
|
||||
onClickTrash,
|
||||
onClickSettings,
|
||||
ringProgress,
|
||||
hideTrashIcon,
|
||||
}) {
|
||||
}: LLMItemButtonGroupProps) {
|
||||
return (
|
||||
<div>
|
||||
<Group position="right" style={{ float: "right", height: "20px" }}>
|
@ -5,12 +5,14 @@ import LLMItemButtonGroup from "./LLMItemButtonGroup";
|
||||
import { IconTemperature } from "@tabler/icons-react";
|
||||
import { getTemperatureSpecForModel } from "./ModelSettingSchemas";
|
||||
import { Tooltip } from "@mantine/core";
|
||||
import { LLMSpec, QueryProgress } from "./backend/typing";
|
||||
import { DraggableProvided, DraggableStateSnapshot } from "react-beautiful-dnd";
|
||||
|
||||
// == The below function perc2color modified from: ==
|
||||
// License: MIT - https://opensource.org/licenses/MIT
|
||||
// Author: Michele Locati <michele@locati.it>
|
||||
// Source: https://gist.github.com/mlocati/7210513
|
||||
const perc2color = (perc) => {
|
||||
const perc2color = (perc: number) => {
|
||||
let r = 0;
|
||||
let g = 0;
|
||||
let b = 0;
|
||||
@ -27,9 +29,9 @@ const perc2color = (perc) => {
|
||||
return "#" + ("000000" + h.toString(16)).slice(-6);
|
||||
};
|
||||
|
||||
const percTemperature = (llm_item) => {
|
||||
const percTemperature = (llm_item: LLMSpec) => {
|
||||
// Get the temp for this llm item
|
||||
const temp = llm_item.settings?.temperature;
|
||||
const temp = llm_item.settings?.temperature as number;
|
||||
if (temp === undefined) {
|
||||
console.warn(
|
||||
`Did not find temperature setting for model ${llm_item.base_model}.`,
|
||||
@ -74,7 +76,17 @@ export const DragItem = styled.div`
|
||||
flex-direction: column;
|
||||
`;
|
||||
|
||||
const LLMListItem = ({
|
||||
export interface LLMListItemProps {
|
||||
item: LLMSpec;
|
||||
provided: DraggableProvided;
|
||||
snapshot: DraggableStateSnapshot;
|
||||
removeCallback: (key: string) => void;
|
||||
onClickSettings: () => void;
|
||||
progress?: QueryProgress;
|
||||
hideTrashIcon: boolean;
|
||||
}
|
||||
|
||||
const LLMListItem: React.FC<LLMListItemProps> = ({
|
||||
item,
|
||||
provided,
|
||||
snapshot,
|
||||
@ -85,7 +97,7 @@ const LLMListItem = ({
|
||||
}) => {
|
||||
// Set color by temperature only on item change (not every render)
|
||||
const [tempColor, setTempColor] = useState(perc2color(50));
|
||||
const temperature = item.settings?.temperature;
|
||||
const temperature = item.settings?.temperature as number;
|
||||
|
||||
useEffect(() => {
|
||||
if (temperature !== undefined)
|
||||
@ -95,6 +107,7 @@ const LLMListItem = ({
|
||||
return (
|
||||
<DragItem
|
||||
ref={provided.innerRef}
|
||||
// @ts-expect-error This property is used dynamically by the Draggable library.
|
||||
snapshot={snapshot}
|
||||
{...provided.draggableProps}
|
||||
{...provided.dragHandleProps}
|
||||
@ -123,7 +136,7 @@ const LLMListItem = ({
|
||||
)}
|
||||
</CardHeader>
|
||||
<LLMItemButtonGroup
|
||||
onClickTrash={() => removeCallback(item.key)}
|
||||
onClickTrash={() => removeCallback(item.key ?? "undefined")}
|
||||
ringProgress={progress}
|
||||
onClickSettings={onClickSettings}
|
||||
hideTrashIcon={hideTrashIcon}
|
||||
@ -133,7 +146,7 @@ const LLMListItem = ({
|
||||
);
|
||||
};
|
||||
|
||||
export const LLMListItemClone = ({
|
||||
export const LLMListItemClone: React.FC<LLMListItemProps> = ({
|
||||
item,
|
||||
provided,
|
||||
snapshot,
|
||||
@ -141,7 +154,7 @@ export const LLMListItemClone = ({
|
||||
}) => {
|
||||
// Set color by temperature only on item change (not every render)
|
||||
const [tempColor, setTempColor] = useState(perc2color(50));
|
||||
const temperature = item.settings?.temperature;
|
||||
const temperature = item.settings?.temperature as number;
|
||||
|
||||
useEffect(() => {
|
||||
if (temperature !== undefined)
|
||||
@ -153,6 +166,7 @@ export const LLMListItemClone = ({
|
||||
ref={provided.innerRef}
|
||||
{...provided.draggableProps}
|
||||
{...provided.dragHandleProps}
|
||||
// @ts-expect-error This property is used dynamically by the Draggable library.
|
||||
snapshot={snapshot}
|
||||
>
|
||||
<div>
|
@ -2153,9 +2153,9 @@ export const getTemperatureSpecForModel = (modelName: string) => {
|
||||
ModelSettings[modelName].schema?.properties?.temperature;
|
||||
if (temperature_property) {
|
||||
return {
|
||||
minimum: temperature_property.minimum,
|
||||
maximum: temperature_property.maximum,
|
||||
default: temperature_property.default,
|
||||
minimum: temperature_property.minimum as number,
|
||||
maximum: temperature_property.maximum as number,
|
||||
default: temperature_property.default as number,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ import {
|
||||
LLMSpec,
|
||||
EvaluatedResponsesResults,
|
||||
TemplateVarInfo,
|
||||
CustomLLMProviderSpec,
|
||||
} from "./typing";
|
||||
import { LLM, getEnumName } from "./models";
|
||||
import {
|
||||
@ -489,7 +490,9 @@ export async function generatePrompts(
|
||||
vars: Dict<(TemplateVarInfo | string)[]>,
|
||||
): Promise<PromptTemplate[]> {
|
||||
const gen_prompts = new PromptPermutationGenerator(root_prompt);
|
||||
const all_prompt_permutations = Array.from(gen_prompts.generate(deepcopy(vars)));
|
||||
const all_prompt_permutations = Array.from(
|
||||
gen_prompts.generate(deepcopy(vars)),
|
||||
);
|
||||
return all_prompt_permutations;
|
||||
}
|
||||
|
||||
@ -1491,7 +1494,7 @@ export async function fetchOpenAIEval(evalname: string): Promise<Dict> {
|
||||
* @returns a Promise with the JSON of the response. Will include 'error' key if error'd; if success,
|
||||
* a 'providers' key with a list of all loaded custom provider callbacks, as dicts.
|
||||
*/
|
||||
export async function initCustomProvider(code: string): Promise<Dict> {
|
||||
export async function initCustomProvider(code: string): Promise<CustomLLMProviderSpec[]> {
|
||||
// Attempt to fetch the example flow from the local filesystem
|
||||
// by querying the Flask server:
|
||||
return fetch(`${FLASK_BASE_URL}app/initCustomProvider`, {
|
||||
@ -1503,6 +1506,10 @@ export async function initCustomProvider(code: string): Promise<Dict> {
|
||||
body: JSON.stringify({ code }),
|
||||
}).then(function (res) {
|
||||
return res.json();
|
||||
}).then(function (json) {
|
||||
if (!json || json.error || !json.providers)
|
||||
throw new Error(json.error ?? "Unknown error");
|
||||
return json.providers as CustomLLMProviderSpec[];
|
||||
});
|
||||
}
|
||||
|
||||
@ -1513,7 +1520,7 @@ export async function initCustomProvider(code: string): Promise<Dict> {
|
||||
* @returns a Promise with the JSON of the response. Will include 'error' key if error'd; if success,
|
||||
* a 'success' key with a true value.
|
||||
*/
|
||||
export async function removeCustomProvider(name: string): Promise<Dict> {
|
||||
export async function removeCustomProvider(name: string): Promise<boolean> {
|
||||
// Attempt to fetch the example flow from the local filesystem
|
||||
// by querying the Flask server:
|
||||
return fetch(`${FLASK_BASE_URL}app/removeCustomProvider`, {
|
||||
@ -1525,6 +1532,10 @@ export async function removeCustomProvider(name: string): Promise<Dict> {
|
||||
body: JSON.stringify({ name }),
|
||||
}).then(function (res) {
|
||||
return res.json();
|
||||
}).then(function (json) {
|
||||
if (!json || json.error || !json.success)
|
||||
throw new Error(json.error ?? "Unknown error");
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
@ -1534,7 +1545,7 @@ export async function removeCustomProvider(name: string): Promise<Dict> {
|
||||
* @returns a Promise with the JSON of the response. Will include 'error' key if error'd; if success,
|
||||
* a 'providers' key with all loaded custom providers in an array. If there were none, returns empty array.
|
||||
*/
|
||||
export async function loadCachedCustomProviders(): Promise<Dict> {
|
||||
export async function loadCachedCustomProviders(): Promise<CustomLLMProviderSpec[]> {
|
||||
return fetch(`${FLASK_BASE_URL}app/loadCachedCustomProviders`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
@ -1544,5 +1555,9 @@ export async function loadCachedCustomProviders(): Promise<Dict> {
|
||||
body: "{}",
|
||||
}).then(function (res) {
|
||||
return res.json();
|
||||
}).then(function (json) {
|
||||
if (!json || json.error || !json.providers)
|
||||
throw new Error(json.error ?? "Could not load custom provider scripts: Error contacting backend.");
|
||||
return json.providers as CustomLLMProviderSpec[];
|
||||
});
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Dict } from "./typing";
|
||||
import { Dict, JSONCompatible } from "./typing";
|
||||
import LZString from "lz-string";
|
||||
|
||||
/**
|
||||
@ -98,7 +98,7 @@ export default class StorageCache {
|
||||
console.warn("Storage quota exceeded");
|
||||
} else {
|
||||
// Handle other types of storage-related errors
|
||||
console.error("Error storing data in localStorage:", error.message);
|
||||
console.error("Error storing data in localStorage:", (error as Error).message);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -114,7 +114,7 @@ export default class StorageCache {
|
||||
public static loadFromLocalStorage(
|
||||
localStorageKey = "chainforge",
|
||||
setStorageCacheData = true,
|
||||
): boolean {
|
||||
): JSONCompatible | undefined {
|
||||
const compressed = localStorage.getItem(localStorageKey);
|
||||
if (!compressed) {
|
||||
console.error(
|
||||
@ -128,7 +128,7 @@ export default class StorageCache {
|
||||
console.log("loaded", data);
|
||||
return data;
|
||||
} catch (error) {
|
||||
console.error(error.message);
|
||||
console.error((error as Error).message);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import {
|
||||
NodeChange,
|
||||
EdgeChange,
|
||||
MarkerType,
|
||||
Connection,
|
||||
} from "reactflow";
|
||||
import { escapeBraces } from "./backend/template";
|
||||
import {
|
||||
@ -248,7 +249,7 @@ export interface StoreHandles {
|
||||
removeEdge: (id: string) => void;
|
||||
onNodesChange: (changes: NodeChange[]) => void;
|
||||
onEdgesChange: (changes: EdgeChange[]) => void;
|
||||
onConnect: (connection: Edge) => void;
|
||||
onConnect: (connection: Connection | Edge) => void;
|
||||
|
||||
// The LLM providers available in the drop-down list
|
||||
AvailableLLMs: LLMSpec[];
|
||||
@ -717,7 +718,7 @@ const useStore = create<StoreHandles>((set, get) => ({
|
||||
},
|
||||
onConnect: (connection) => {
|
||||
// Get the target node information
|
||||
const target = get().getNode(connection.target);
|
||||
const target = connection.target ? get().getNode(connection.target) : undefined;
|
||||
if (target === undefined) return;
|
||||
|
||||
if (
|
||||
@ -736,6 +737,7 @@ const useStore = create<StoreHandles>((set, get) => ({
|
||||
get().setDataPropsForNode(target.id, { refresh: true });
|
||||
}
|
||||
|
||||
connection = connection as Edge;
|
||||
connection.interactionWidth = 40;
|
||||
connection.markerEnd = { type: MarkerType.Arrow, width: 22, height: 22 }; // 22px
|
||||
connection.type = "default";
|
||||
|
Loading…
x
Reference in New Issue
Block a user