mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-15 00:36:29 +00:00
wip
This commit is contained in:
parent
3bf4c0d2b9
commit
f45d99e3e2
@ -1,5 +1,5 @@
|
||||
import React, { useState, useCallback, useRef, useEffect } from "react";
|
||||
import ReactFlow, { Controls, Background } from "reactflow";
|
||||
import ReactFlow, { Controls, Background, ReactFlowInstance } from "reactflow";
|
||||
import {
|
||||
Button,
|
||||
Menu,
|
||||
@ -31,7 +31,7 @@ import CodeEvaluatorNode from "./CodeEvaluatorNode";
|
||||
import VisNode from "./VisNode";
|
||||
import InspectNode from "./InspectorNode";
|
||||
import ScriptNode from "./ScriptNode";
|
||||
import AlertModal from "./AlertModal";
|
||||
import AlertModal, { AlertModalHandles } from "./AlertModal";
|
||||
import ItemsNode from "./ItemsNode";
|
||||
import TabularDataNode from "./TabularDataNode";
|
||||
import JoinNode from "./JoinNode";
|
||||
@ -56,7 +56,7 @@ import "./text-fields-node.css"; // project
|
||||
|
||||
// State management (from https://reactflow.dev/docs/guides/state-management/)
|
||||
import { shallow } from "zustand/shallow";
|
||||
import useStore from "./store";
|
||||
import useStore, { StoreHandles } from "./store";
|
||||
import fetch_from_backend from "./fetch_from_backend";
|
||||
import StorageCache from "./backend/cache";
|
||||
import { APP_IS_RUNNING_LOCALLY, browserTabIsActive } from "./backend/utils";
|
||||
@ -69,19 +69,20 @@ import {
|
||||
isEdgeChromium,
|
||||
isChromium,
|
||||
} from "react-device-detect";
|
||||
import { Dict, LLMSpec } from "./backend/typing";
|
||||
const IS_ACCEPTED_BROWSER =
|
||||
(isChrome ||
|
||||
isChromium ||
|
||||
isEdgeChromium ||
|
||||
isFirefox ||
|
||||
navigator?.brave !== undefined) &&
|
||||
(navigator as any)?.brave !== undefined) &&
|
||||
!isMobile;
|
||||
|
||||
// Whether we are running on localhost or not, and hence whether
|
||||
// we have access to the Flask backend for, e.g., Python code evaluation.
|
||||
const IS_RUNNING_LOCALLY = APP_IS_RUNNING_LOCALLY();
|
||||
|
||||
const selector = (state) => ({
|
||||
const selector = (state: StoreHandles) => ({
|
||||
nodes: state.nodes,
|
||||
edges: state.edges,
|
||||
onNodesChange: state.onNodesChange,
|
||||
@ -108,7 +109,7 @@ const INITIAL_LLM = () => {
|
||||
temp: 1.0,
|
||||
settings: getDefaultModelSettings("hf"),
|
||||
formData: getDefaultModelFormData("hf"),
|
||||
};
|
||||
} satisfies LLMSpec;
|
||||
falcon7b.formData.shortname = falcon7b.name;
|
||||
falcon7b.formData.model = falcon7b.model;
|
||||
return falcon7b;
|
||||
@ -123,7 +124,7 @@ const INITIAL_LLM = () => {
|
||||
temp: 1.0,
|
||||
settings: getDefaultModelSettings("gpt-3.5-turbo"),
|
||||
formData: getDefaultModelFormData("gpt-3.5-turbo"),
|
||||
};
|
||||
} satisfies LLMSpec;
|
||||
chatgpt.formData.shortname = chatgpt.name;
|
||||
chatgpt.formData.model = chatgpt.model;
|
||||
return chatgpt;
|
||||
@ -173,7 +174,13 @@ const getSharedFlowURLParam = () => {
|
||||
return undefined;
|
||||
};
|
||||
|
||||
const MenuTooltip = ({ label, children }) => {
|
||||
const MenuTooltip = ({
|
||||
label,
|
||||
children,
|
||||
}: {
|
||||
label: string;
|
||||
children: React.ReactNode;
|
||||
}) => {
|
||||
return (
|
||||
<Tooltip
|
||||
label={label}
|
||||
@ -208,7 +215,7 @@ const App = () => {
|
||||
} = useStore(selector, shallow);
|
||||
|
||||
// For saving / loading
|
||||
const [rfInstance, setRfInstance] = useState(null);
|
||||
const [rfInstance, setRfInstance] = useState<ReactFlowInstance | null>(null);
|
||||
const [autosavingInterval, setAutosavingInterval] = useState(null);
|
||||
|
||||
// For 'share' button
|
||||
@ -236,7 +243,7 @@ const App = () => {
|
||||
const { hideContextMenu } = useContextMenu();
|
||||
|
||||
// For displaying error messages to user
|
||||
const alertModal = useRef(null);
|
||||
const alertModal = useRef<AlertModalHandles>(null);
|
||||
|
||||
// For displaying a pending 'loading' status
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
@ -252,6 +259,7 @@ const App = () => {
|
||||
};
|
||||
const getViewportCenter = () => {
|
||||
const { centerX, centerY } = getWindowCenter();
|
||||
if (rfInstance === null) return { x: centerX, y: centerY };
|
||||
// Support Zoom
|
||||
const { x, y, zoom } = rfInstance.getViewport();
|
||||
return { x: -(x / zoom) + centerX / zoom, y: -(y / zoom) + centerY / zoom };
|
||||
@ -293,7 +301,7 @@ const App = () => {
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addEvalNode = (progLang) => {
|
||||
const addEvalNode = (progLang: string) => {
|
||||
const { x, y } = getViewportCenter();
|
||||
let code = "";
|
||||
if (progLang === "python")
|
||||
@ -388,7 +396,7 @@ const App = () => {
|
||||
position: { x: x - 200, y: y - 100 },
|
||||
});
|
||||
};
|
||||
const addProcessorNode = (progLang) => {
|
||||
const addProcessorNode = (progLang: string) => {
|
||||
const { x, y } = getViewportCenter();
|
||||
let code = "";
|
||||
if (progLang === "python")
|
||||
@ -410,11 +418,12 @@ const App = () => {
|
||||
if (settingsModal && settingsModal.current) settingsModal.current.trigger();
|
||||
};
|
||||
|
||||
const handleError = (err) => {
|
||||
const handleError = (err: Error | string) => {
|
||||
const msg = typeof err === "string" ? err : err.message;
|
||||
setIsLoading(false);
|
||||
setWaitingForShare(false);
|
||||
if (alertModal.current) alertModal.current.trigger(err.message);
|
||||
console.error(err.message);
|
||||
if (alertModal.current) alertModal.current.trigger(msg);
|
||||
console.error(msg);
|
||||
};
|
||||
|
||||
/**
|
||||
@ -468,7 +477,7 @@ const App = () => {
|
||||
const resetFlow = useCallback(() => {
|
||||
resetLLMColors();
|
||||
|
||||
const uid = (id) => `${id}-${Date.now()}`;
|
||||
const uid = (id: string) => `${id}-${Date.now()}`;
|
||||
const starting_nodes = [
|
||||
{
|
||||
id: uid("prompt"),
|
||||
@ -493,39 +502,38 @@ const App = () => {
|
||||
if (rfInstance) rfInstance.setViewport({ x: 200, y: 80, zoom: 1 });
|
||||
}, [setNodes, setEdges, resetLLMColors, rfInstance]);
|
||||
|
||||
const loadFlow = async (flow, rf_inst) => {
|
||||
if (flow) {
|
||||
if (rf_inst) {
|
||||
if (flow.viewport)
|
||||
rf_inst.setViewport({
|
||||
x: flow.viewport.x || 0,
|
||||
y: flow.viewport.y || 0,
|
||||
zoom: flow.viewport.zoom || 1,
|
||||
});
|
||||
else rf_inst.setViewport({ x: 0, y: 0, zoom: 1 });
|
||||
}
|
||||
resetLLMColors();
|
||||
|
||||
// First, clear the ReactFlow state entirely
|
||||
// NOTE: We need to do this so it forgets any node/edge ids, which might have cross-over in the loaded flow.
|
||||
setNodes([]);
|
||||
setEdges([]);
|
||||
|
||||
// After a delay, load in the new state.
|
||||
setTimeout(() => {
|
||||
setNodes(flow.nodes || []);
|
||||
setEdges(flow.edges || []);
|
||||
|
||||
// Save flow that user loaded to autosave cache, in case they refresh the browser
|
||||
StorageCache.saveToLocalStorage("chainforge-flow", flow);
|
||||
|
||||
// Cancel loading spinner
|
||||
setIsLoading(false);
|
||||
}, 10);
|
||||
|
||||
// Start auto-saving, if it's not already enabled
|
||||
if (rf_inst) initAutosaving(rf_inst);
|
||||
const loadFlow = async (flow?: Dict, rf_inst?: ReactFlowInstance | null) => {
|
||||
if (flow === undefined) return;
|
||||
if (rf_inst) {
|
||||
if (flow.viewport)
|
||||
rf_inst.setViewport({
|
||||
x: flow.viewport.x || 0,
|
||||
y: flow.viewport.y || 0,
|
||||
zoom: flow.viewport.zoom || 1,
|
||||
});
|
||||
else rf_inst.setViewport({ x: 0, y: 0, zoom: 1 });
|
||||
}
|
||||
resetLLMColors();
|
||||
|
||||
// First, clear the ReactFlow state entirely
|
||||
// NOTE: We need to do this so it forgets any node/edge ids, which might have cross-over in the loaded flow.
|
||||
setNodes([]);
|
||||
setEdges([]);
|
||||
|
||||
// After a delay, load in the new state.
|
||||
setTimeout(() => {
|
||||
setNodes(flow.nodes || []);
|
||||
setEdges(flow.edges || []);
|
||||
|
||||
// Save flow that user loaded to autosave cache, in case they refresh the browser
|
||||
StorageCache.saveToLocalStorage("chainforge-flow", flow);
|
||||
|
||||
// Cancel loading spinner
|
||||
setIsLoading(false);
|
||||
}, 10);
|
||||
|
||||
// Start auto-saving, if it's not already enabled
|
||||
if (rf_inst) initAutosaving(rf_inst);
|
||||
};
|
||||
|
||||
const importGlobalStateFromCache = useCallback(() => {
|
||||
@ -535,7 +543,7 @@ const App = () => {
|
||||
const autosavedFlowExists = () => {
|
||||
return window.localStorage.getItem("chainforge-flow") !== null;
|
||||
};
|
||||
const loadFlowFromAutosave = async (rf_inst) => {
|
||||
const loadFlowFromAutosave = async (rf_inst: ReactFlowInstance) => {
|
||||
const saved_flow = StorageCache.loadFromLocalStorage(
|
||||
"chainforge-flow",
|
||||
false,
|
||||
@ -602,8 +610,8 @@ const App = () => {
|
||||
);
|
||||
|
||||
const importFlowFromJSON = useCallback(
|
||||
(flowJSON, rf_inst) => {
|
||||
const rf = rf_inst || rfInstance;
|
||||
(flowJSON: Dict, rf_inst?: ReactFlowInstance | null) => {
|
||||
const rf = rf_inst ?? rfInstance;
|
||||
|
||||
setIsLoading(true);
|
||||
|
||||
@ -656,13 +664,16 @@ const App = () => {
|
||||
// Handle file load event
|
||||
reader.addEventListener("load", function () {
|
||||
try {
|
||||
if (typeof reader.result !== "string")
|
||||
throw new Error("File could not be read: Unknown format or empty.");
|
||||
|
||||
// We try to parse the JSON response
|
||||
const flow_and_cache = JSON.parse(reader.result);
|
||||
|
||||
// Import it to React Flow and import cache data on the backend
|
||||
importFlowFromJSON(flow_and_cache);
|
||||
} catch (error) {
|
||||
handleError(error);
|
||||
handleError(error as Error);
|
||||
}
|
||||
});
|
||||
|
||||
@ -675,7 +686,7 @@ const App = () => {
|
||||
};
|
||||
|
||||
// Downloads the selected OpenAI eval file (preconverted to a .cforge flow)
|
||||
const importFlowFromOpenAIEval = (evalname) => {
|
||||
const importFlowFromOpenAIEval = (evalname: string) => {
|
||||
setIsLoading(true);
|
||||
|
||||
fetch_from_backend(
|
||||
@ -707,7 +718,7 @@ const App = () => {
|
||||
};
|
||||
|
||||
// Load flow from examples modal
|
||||
const onSelectExampleFlow = (name, example_category) => {
|
||||
const onSelectExampleFlow = (name: string, example_category: string) => {
|
||||
// Trigger the 'loading' modal
|
||||
setIsLoading(true);
|
||||
|
||||
@ -772,9 +783,9 @@ const App = () => {
|
||||
}
|
||||
|
||||
// Helper function
|
||||
function isFileSizeLessThan5MB(str) {
|
||||
function isFileSizeLessThan5MB(json_str: string) {
|
||||
const encoder = new TextEncoder();
|
||||
const encodedString = encoder.encode(str);
|
||||
const encodedString = encoder.encode(json_str);
|
||||
const fileSizeInBytes = encodedString.length;
|
||||
const fileSizeInMB = fileSizeInBytes / (1024 * 1024); // Convert bytes to megabytes
|
||||
return fileSizeInMB < 5;
|
||||
@ -894,7 +905,7 @@ const App = () => {
|
||||
};
|
||||
|
||||
// Run once upon ReactFlow initialization
|
||||
const onInit = (rf_inst) => {
|
||||
const onInit = (rf_inst: ReactFlowInstance) => {
|
||||
setRfInstance(rf_inst);
|
||||
|
||||
if (IS_RUNNING_LOCALLY) {
|
@ -10,7 +10,14 @@
|
||||
* Descriptions of OpenAI model parameters copied from OpenAI's official chat completions documentation: https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
*/
|
||||
|
||||
import { LLMProvider, RATE_LIMITS, getProvider } from "./backend/models";
|
||||
import { LLM, LLMProvider, RATE_LIMITS, getProvider } from "./backend/models";
|
||||
import {
|
||||
Dict,
|
||||
JSONCompatible,
|
||||
CustomLLMProviderSpec,
|
||||
ModelSettingsDict,
|
||||
LLMSpec,
|
||||
} from "./backend/typing";
|
||||
import { transformDict } from "./backend/utils";
|
||||
import useStore from "./store";
|
||||
|
||||
@ -21,9 +28,9 @@ const UI_SUBMIT_BUTTON_SPEC = {
|
||||
},
|
||||
norender: false,
|
||||
submitText: "Submit",
|
||||
};
|
||||
} satisfies Dict;
|
||||
|
||||
const ChatGPTSettings = {
|
||||
const ChatGPTSettings: ModelSettingsDict = {
|
||||
fullName: "GPT-3.5+ (OpenAI)",
|
||||
schema: {
|
||||
type: "object",
|
||||
@ -224,29 +231,32 @@ const ChatGPTSettings = {
|
||||
},
|
||||
|
||||
postprocessors: {
|
||||
functions: (str) => {
|
||||
functions: (str: string | number | boolean) => {
|
||||
if (typeof str !== "string") return str;
|
||||
if (str.trim().length === 0) return [];
|
||||
return JSON.parse(str); // parse the JSON schema
|
||||
},
|
||||
function_call: (str) => {
|
||||
function_call: (str: string | number | boolean) => {
|
||||
if (typeof str !== "string") return str;
|
||||
const s = str.trim();
|
||||
if (s.length === 0) return "";
|
||||
if (s === "auto" || s === "none") return s;
|
||||
else return { name: s };
|
||||
},
|
||||
stop: (str) => {
|
||||
stop: (str: string | number | boolean) => {
|
||||
if (typeof str !== "string") return str;
|
||||
if (str.trim().length === 0) return [];
|
||||
return str
|
||||
.match(/"((?:[^"\\]|\\.)*)"/g)
|
||||
.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
|
||||
?.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
|
||||
},
|
||||
response_format: (str) => {
|
||||
response_format: (str: string | number | boolean) => {
|
||||
return { type: str };
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const GPT4Settings = {
|
||||
const GPT4Settings: ModelSettingsDict = {
|
||||
fullName: ChatGPTSettings.fullName,
|
||||
schema: {
|
||||
type: "object",
|
||||
@ -275,7 +285,7 @@ const GPT4Settings = {
|
||||
postprocessors: ChatGPTSettings.postprocessors,
|
||||
};
|
||||
|
||||
const ClaudeSettings = {
|
||||
const ClaudeSettings: ModelSettingsDict = {
|
||||
fullName: "Claude (Anthropic)",
|
||||
schema: {
|
||||
type: "object",
|
||||
@ -420,16 +430,17 @@ const ClaudeSettings = {
|
||||
},
|
||||
|
||||
postprocessors: {
|
||||
stop_sequences: (str) => {
|
||||
stop_sequences: (str: string | number | boolean) => {
|
||||
if (typeof str !== "string") return str;
|
||||
if (str.trim().length === 0) return ["\n\nHuman:"];
|
||||
return str
|
||||
.match(/"((?:[^"\\]|\\.)*)"/g)
|
||||
.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
|
||||
?.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const PaLM2Settings = {
|
||||
const PaLM2Settings: ModelSettingsDict = {
|
||||
fullName: "Google AI Models (Gemini & PaLM)",
|
||||
schema: {
|
||||
type: "object",
|
||||
@ -531,16 +542,17 @@ const PaLM2Settings = {
|
||||
},
|
||||
|
||||
postprocessors: {
|
||||
stop_sequences: (str) => {
|
||||
stop_sequences: (str: string | number | boolean) => {
|
||||
if (typeof str !== "string") return str;
|
||||
if (str.trim().length === 0) return [];
|
||||
return str
|
||||
.match(/"((?:[^"\\]|\\.)*)"/g)
|
||||
.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
|
||||
?.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const DalaiModelSettings = {
|
||||
const DalaiModelSettings: ModelSettingsDict = {
|
||||
fullName: "Dalai-hosted local model (Alpaca, Llama)",
|
||||
schema: {
|
||||
type: "object",
|
||||
@ -691,7 +703,7 @@ const DalaiModelSettings = {
|
||||
postprocessors: {},
|
||||
};
|
||||
|
||||
const AzureOpenAISettings = {
|
||||
const AzureOpenAISettings: ModelSettingsDict = {
|
||||
fullName: "Azure OpenAI Model",
|
||||
schema: {
|
||||
type: "object",
|
||||
@ -738,7 +750,7 @@ const AzureOpenAISettings = {
|
||||
postprocessors: ChatGPTSettings.postprocessors,
|
||||
};
|
||||
|
||||
const HuggingFaceTextInferenceSettings = {
|
||||
const HuggingFaceTextInferenceSettings: ModelSettingsDict = {
|
||||
fullName: "HuggingFace-hosted text generation models",
|
||||
schema: {
|
||||
type: "object",
|
||||
@ -910,7 +922,7 @@ const HuggingFaceTextInferenceSettings = {
|
||||
postprocessors: {},
|
||||
};
|
||||
|
||||
const AlephAlphaLuminousSettings = {
|
||||
const AlephAlphaLuminousSettings: ModelSettingsDict = {
|
||||
fullName: "Aleph Alpha Luminous",
|
||||
schema: {
|
||||
type: "object",
|
||||
@ -1111,22 +1123,25 @@ const AlephAlphaLuminousSettings = {
|
||||
},
|
||||
},
|
||||
postprocessors: {
|
||||
stop_sequences: (str) => {
|
||||
stop_sequences: (str: string | number | boolean) => {
|
||||
if (typeof str !== "string") return str;
|
||||
if (str.trim().length === 0) return [];
|
||||
return str
|
||||
.match(/"((?:[^"\\]|\\.)*)"/g)
|
||||
.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
|
||||
?.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
|
||||
},
|
||||
log_probs: (bool) => {
|
||||
log_probs: (bool: boolean | number | string) => {
|
||||
if (typeof bool !== "boolean") return bool;
|
||||
return bool ? 3 : null;
|
||||
},
|
||||
best_of: (a) => {
|
||||
best_of: (a: number | string | boolean) => {
|
||||
if (typeof a !== "number") return a;
|
||||
return a === 1 ? null : a;
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const OllamaSettings = {
|
||||
const OllamaSettings: ModelSettingsDict = {
|
||||
fullName: "Ollama",
|
||||
schema: {
|
||||
type: "object",
|
||||
@ -1234,11 +1249,12 @@ const OllamaSettings = {
|
||||
},
|
||||
},
|
||||
postprocessors: {
|
||||
stop_sequences: (str) => {
|
||||
stop_sequences: (str: string | number | boolean) => {
|
||||
if (typeof str !== "string") return str;
|
||||
if (str.trim().length === 0) return [];
|
||||
return str
|
||||
.match(/"((?:[^"\\]|\\.)*)"/g)
|
||||
.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
|
||||
?.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
|
||||
},
|
||||
},
|
||||
};
|
||||
@ -1929,7 +1945,7 @@ const MetaLlama2ChatSettings = {
|
||||
};
|
||||
|
||||
// A lookup table indexed by base_model.
|
||||
export const ModelSettings = {
|
||||
export const ModelSettings: Dict<ModelSettingsDict> = {
|
||||
"gpt-3.5-turbo": ChatGPTSettings,
|
||||
"gpt-4": GPT4Settings,
|
||||
"claude-v1": ClaudeSettings,
|
||||
@ -1948,7 +1964,9 @@ export const ModelSettings = {
|
||||
"br.meta.llama2": MetaLlama2ChatSettings,
|
||||
};
|
||||
|
||||
export function getSettingsSchemaForLLM(llm_name) {
|
||||
export function getSettingsSchemaForLLM(
|
||||
llm_name: string,
|
||||
): ModelSettingsDict | undefined {
|
||||
const llm_provider = getProvider(llm_name);
|
||||
|
||||
const provider_to_settings_schema = {
|
||||
@ -1963,13 +1981,13 @@ export function getSettingsSchemaForLLM(llm_name) {
|
||||
};
|
||||
|
||||
if (llm_provider === LLMProvider.Custom) return ModelSettings[llm_name];
|
||||
else if (llm_provider in provider_to_settings_schema)
|
||||
else if (llm_provider && llm_provider in provider_to_settings_schema)
|
||||
return provider_to_settings_schema[llm_provider];
|
||||
else if (llm_provider === LLMProvider.Bedrock) {
|
||||
return ModelSettings[llm_name.split("-")[0]];
|
||||
} else {
|
||||
console.error(`Could not find provider for llm ${llm_name}`);
|
||||
return {};
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1978,7 +1996,10 @@ export function getSettingsSchemaForLLM(llm_name) {
|
||||
* @param {*} settings_dict A dict of form setting_name: value (string: string)
|
||||
* @param {*} llm A string of the name of the model to query.
|
||||
*/
|
||||
export function typecastSettingsDict(settings_dict, llm) {
|
||||
export function typecastSettingsDict(
|
||||
settings_dict: ModelSettingsDict,
|
||||
llm: string,
|
||||
) {
|
||||
const settings = getSettingsSchemaForLLM(llm);
|
||||
const schema = settings?.schema?.properties ?? {};
|
||||
const postprocessors = settings?.postprocessors ?? {};
|
||||
@ -2014,16 +2035,16 @@ export function typecastSettingsDict(settings_dict, llm) {
|
||||
* @param {*} settings_schema
|
||||
*/
|
||||
export const setCustomProvider = (
|
||||
name,
|
||||
emoji,
|
||||
models,
|
||||
rate_limit,
|
||||
settings_schema,
|
||||
name: string,
|
||||
emoji: string,
|
||||
models?: string[],
|
||||
rate_limit?: number,
|
||||
settings_schema?: CustomLLMProviderSpec["settings_schema"],
|
||||
) => {
|
||||
if (typeof emoji === "string" && (emoji.length === 0 || emoji.length > 2))
|
||||
throw new Error(`Emoji for a custom provider must have a character.`);
|
||||
|
||||
const new_provider = { name };
|
||||
const new_provider: Dict<JSONCompatible> = { name };
|
||||
new_provider.emoji = emoji || "✨";
|
||||
|
||||
// Each LLM *model* must have a unique name. To avoid name collisions, for custom providers,
|
||||
@ -2036,7 +2057,7 @@ export const setCustomProvider = (
|
||||
(Array.isArray(models) && models.length > 0 ? `${models[0]}` : "");
|
||||
|
||||
// Build the settings form schema for this new custom provider
|
||||
const compiled_schema = {
|
||||
const compiled_schema: ModelSettingsDict = {
|
||||
fullName: `${name} (custom provider)`,
|
||||
schema: {
|
||||
type: "object",
|
||||
@ -2057,6 +2078,7 @@ export const setCustomProvider = (
|
||||
"ui:autofocus": true,
|
||||
},
|
||||
},
|
||||
postprocessors: {},
|
||||
};
|
||||
|
||||
// Add a models selector if there's multiple models
|
||||
@ -2093,8 +2115,9 @@ export const setCustomProvider = (
|
||||
// Add the built provider and its settings to the global lookups:
|
||||
const AvailableLLMs = useStore.getState().AvailableLLMs;
|
||||
const prev_provider_idx = AvailableLLMs.findIndex((d) => d.name === name);
|
||||
if (prev_provider_idx > -1) AvailableLLMs[prev_provider_idx] = new_provider;
|
||||
else AvailableLLMs.push(new_provider);
|
||||
if (prev_provider_idx > -1)
|
||||
AvailableLLMs[prev_provider_idx] = new_provider as LLMSpec;
|
||||
else AvailableLLMs.push(new_provider as LLMSpec);
|
||||
ModelSettings[base_model] = compiled_schema;
|
||||
|
||||
// Add rate limit info, if specified
|
||||
@ -2113,7 +2136,7 @@ export const setCustomProvider = (
|
||||
useStore.getState().setAvailableLLMs(AvailableLLMs);
|
||||
};
|
||||
|
||||
export const setCustomProviders = (providers) => {
|
||||
export const setCustomProviders = (providers: CustomLLMProviderSpec[]) => {
|
||||
for (const p of providers)
|
||||
setCustomProvider(
|
||||
p.name,
|
||||
@ -2124,7 +2147,7 @@ export const setCustomProviders = (providers) => {
|
||||
);
|
||||
};
|
||||
|
||||
export const getTemperatureSpecForModel = (modelName) => {
|
||||
export const getTemperatureSpecForModel = (modelName: string) => {
|
||||
if (modelName in ModelSettings) {
|
||||
const temperature_property =
|
||||
ModelSettings[modelName].schema?.properties?.temperature;
|
||||
@ -2139,11 +2162,14 @@ export const getTemperatureSpecForModel = (modelName) => {
|
||||
return null;
|
||||
};
|
||||
|
||||
export const postProcessFormData = (settingsSpec, formData) => {
|
||||
export const postProcessFormData = (
|
||||
settingsSpec: ModelSettingsDict,
|
||||
formData: Dict<JSONCompatible>,
|
||||
) => {
|
||||
// Strip all 'model' and 'shortname' props in the submitted form, as these are passed elsewhere or unecessary for the backend
|
||||
const skip_keys = { model: true, shortname: true };
|
||||
|
||||
const new_data = {};
|
||||
const new_data: Dict<JSONCompatible> = {};
|
||||
const postprocessors = settingsSpec?.postprocessors
|
||||
? settingsSpec.postprocessors
|
||||
: {};
|
||||
@ -2151,28 +2177,32 @@ export const postProcessFormData = (settingsSpec, formData) => {
|
||||
Object.keys(formData).forEach((key) => {
|
||||
if (key in skip_keys) return;
|
||||
if (key in postprocessors)
|
||||
new_data[key] = postprocessors[key](formData[key]);
|
||||
new_data[key] = postprocessors[key](
|
||||
formData[key] as string | number | boolean,
|
||||
);
|
||||
else new_data[key] = formData[key];
|
||||
});
|
||||
|
||||
return new_data;
|
||||
};
|
||||
|
||||
export const getDefaultModelFormData = (settingsSpec) => {
|
||||
export const getDefaultModelFormData = (
|
||||
settingsSpec: string | ModelSettingsDict,
|
||||
) => {
|
||||
if (typeof settingsSpec === "string")
|
||||
settingsSpec = ModelSettings[settingsSpec];
|
||||
const default_formdata = {};
|
||||
const default_formdata: Dict<JSONCompatible> = {};
|
||||
const schema = settingsSpec.schema;
|
||||
Object.keys(schema.properties).forEach((key) => {
|
||||
default_formdata[key] =
|
||||
"default" in schema.properties[key]
|
||||
? schema.properties[key].default
|
||||
: undefined;
|
||||
: null;
|
||||
});
|
||||
return default_formdata;
|
||||
};
|
||||
|
||||
export const getDefaultModelSettings = (modelName) => {
|
||||
export const getDefaultModelSettings = (modelName: string) => {
|
||||
if (!(modelName in ModelSettings)) {
|
||||
console.warn(
|
||||
`Model ${modelName} not found in list of available model settings.`,
|
@ -6,6 +6,7 @@ import {
|
||||
RawLLMResponseObject,
|
||||
isEqualChatHistory,
|
||||
ChatHistoryInfo,
|
||||
ModelSettingsDict,
|
||||
} from "./typing";
|
||||
import {
|
||||
extract_responses,
|
||||
@ -288,7 +289,13 @@ export class PromptPipeline {
|
||||
num_queries_sent,
|
||||
max_req,
|
||||
wait_secs,
|
||||
{ ...llm_params, ...typecastSettingsDict(settings_params, llm) },
|
||||
{
|
||||
...llm_params,
|
||||
...typecastSettingsDict(
|
||||
settings_params as ModelSettingsDict,
|
||||
llm,
|
||||
),
|
||||
},
|
||||
chat_history,
|
||||
should_cancel,
|
||||
),
|
||||
@ -305,7 +312,13 @@ export class PromptPipeline {
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
{ ...llm_params, ...typecastSettingsDict(settings_params, llm) },
|
||||
{
|
||||
...llm_params,
|
||||
...typecastSettingsDict(
|
||||
settings_params as ModelSettingsDict,
|
||||
llm,
|
||||
),
|
||||
},
|
||||
chat_history,
|
||||
should_cancel,
|
||||
);
|
||||
|
@ -16,6 +16,13 @@ export interface Dict<T = any> {
|
||||
// Function types
|
||||
export type Func<T = void> = (...args: any[]) => T;
|
||||
|
||||
// JSON-compatible leaf types
|
||||
export type JSONLeaf = string | number | boolean | null;
|
||||
export type JSONCompatible<T = JSONLeaf> =
|
||||
| T
|
||||
| JSONCompatible<T>[]
|
||||
| Dict<JSONCompatible<T>>;
|
||||
|
||||
/** OpenAI function call format */
|
||||
export interface OpenAIFunctionCall {
|
||||
name: string;
|
||||
@ -110,6 +117,30 @@ export type LLMSpec = {
|
||||
progress?: Dict<number>; // only used for front-end to display progress collecting responses for this LLM
|
||||
};
|
||||
|
||||
/** A spec for a user-defined custom LLM provider */
|
||||
export type CustomLLMProviderSpec = {
|
||||
name: string;
|
||||
emoji: string;
|
||||
models?: string[];
|
||||
rate_limit?: number;
|
||||
settings_schema?: {
|
||||
settings: Dict<Dict<JSONCompatible>>;
|
||||
ui: Dict<Dict<JSONCompatible>>;
|
||||
};
|
||||
};
|
||||
|
||||
/** Internal description of model settings, passed to react-json-schema */
|
||||
export interface ModelSettingsDict {
|
||||
fullName: string;
|
||||
schema: {
|
||||
type: "object";
|
||||
required: string[];
|
||||
properties: Dict<Dict<JSONCompatible>>;
|
||||
};
|
||||
uiSchema: Dict<JSONCompatible>;
|
||||
postprocessors: Dict<(val: string | number | boolean) => any>;
|
||||
}
|
||||
|
||||
export type ResponseUID = string;
|
||||
|
||||
/** Standard properties that every LLM response object must have. */
|
||||
|
@ -22,8 +22,9 @@ import {
|
||||
PromptVarType,
|
||||
PromptVarsDict,
|
||||
TemplateVarInfo,
|
||||
TabularDataColType,
|
||||
TabularDataRowType,
|
||||
} from "./backend/typing";
|
||||
import { TabularDataColType, TabularDataRowType } from "./TabularDataNode";
|
||||
|
||||
// Initial project settings
|
||||
const initialAPIKeys = {};
|
||||
@ -227,7 +228,7 @@ export const initLLMProviders = initLLMProviderMenu
|
||||
.map((item) => ("group" in item && "items" in item ? item.items : item))
|
||||
.flat();
|
||||
|
||||
interface StoreHandles {
|
||||
export interface StoreHandles {
|
||||
// Nodes and edges
|
||||
nodes: Node[];
|
||||
edges: Edge[];
|
||||
|
Loading…
x
Reference in New Issue
Block a user