diff --git a/chain-forge/src/InspectorNode.js b/chain-forge/src/InspectorNode.js index 4a1720b..f107dc0 100644 --- a/chain-forge/src/InspectorNode.js +++ b/chain-forge/src/InspectorNode.js @@ -1,9 +1,24 @@ import React, { useState, useEffect } from 'react'; import { Handle } from 'react-flow-renderer'; +import { Badge } from '@mantine/core'; import useStore from './store'; import NodeLabel from './NodeLabelComponent' import {BASE_URL} from './store'; +// Helper funcs +const truncStr = (s, maxLen) => { + if (s.length > maxLen) // Cut the name short if it's long + return s.substring(0, maxLen) + '...' + else + return s; +} +const vars_to_str = (vars) => { + const pairs = Object.keys(vars).map(varname => { + const s = truncStr(vars[varname].trim(), 12); + return `${varname} = '${s}'`; + }); + return pairs; +}; const bucketResponsesByLLM = (responses) => { let responses_by_llm = {}; responses.forEach(item => { @@ -58,27 +73,19 @@ const InspectorNode = ({ data, id }) => { // // Create a dict version // let tempvars = {}; - const vars_to_str = (vars) => { - const pairs = Object.keys(vars).map(varname => { - let s = vars[varname].trim(); - if (s.length > 12) - s = s.substring(0, 12) + '...' - return `${varname} = '${s}'`; - }); - return pairs.join('; '); - }; - - const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9']; + const colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1']; setResponses(Object.keys(responses_by_llm).map((llm, llm_idx) => { const res_divs = responses_by_llm[llm].map((res_obj, res_idx) => { const ps = res_obj.responses.map((r, idx) => (
{r}
) ); - // Object.keys(res_obj.vars).forEach(v => {tempvars[v].add(res_obj.vars[v])}); const vars = vars_to_str(res_obj.vars); + const var_tags = vars.map((v) => ( + {v} + )); return (
-

{vars}

+ {var_tags} {ps}
); diff --git a/chain-forge/src/NodeLabelComponent.js b/chain-forge/src/NodeLabelComponent.js index 81a2686..8715a00 100644 --- a/chain-forge/src/NodeLabelComponent.js +++ b/chain-forge/src/NodeLabelComponent.js @@ -36,7 +36,7 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl const run_btn = (); if (runButtonTooltip) setRunButton( - + {run_btn} ); diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js index 0aa588c..134d987 100644 --- a/chain-forge/src/PromptNode.js +++ b/chain-forge/src/PromptNode.js @@ -54,17 +54,17 @@ const PromptNode = ({ data, id }) => { const getNode = useStore((state) => state.getNode); const [templateVars, setTemplateVars] = useState(data.vars || []); - const [promptText, setPromptText] = useState(data.prompt); + const [promptText, setPromptText] = useState(data.prompt || ""); const [promptTextOnLastRun, setPromptTextOnLastRun] = useState(null); const [status, setStatus] = useState('none'); - const [responsePreviews, setReponsePreviews] = useState([]); + const [responsePreviews, setResponsePreviews] = useState([]); const [numGenerations, setNumGenerations] = useState(data.n || 1); // For displaying error messages to user const alertModal = useRef(null); // Selecting LLM models to prompt - const [llmItems, setLLMItems] = useState(initLLMs.map((i, idx) => ({key: uuid(), ...i}))); + const [llmItems, setLLMItems] = useState(data.llms || initLLMs.map((i) => ({key: uuid(), ...i}))); const [llmItemsCurrState, setLLMItemsCurrState] = useState([]); const resetLLMItemsProgress = useCallback(() => { setLLMItems(llmItemsCurrState.map(item => { @@ -101,8 +101,23 @@ const PromptNode = ({ data, id }) => { const onLLMListItemsChange = useCallback((new_items) => { setLLMItemsCurrState(new_items); + setDataPropsForNode(id, { llms: new_items }); }, [setLLMItemsCurrState]); + const refreshTemplateHooks = (text) => { + // Update template var fields + handles + const braces_regex = /(? 0) { + const temp_var_names = found_template_vars.map( + name => name.substring(1, name.length-1) // remove brackets {} + ) + setTemplateVars(temp_var_names); + } else { + setTemplateVars([]); + } + }; + const handleInputChange = (event) => { const value = event.target.value; @@ -119,19 +134,14 @@ const PromptNode = ({ data, id }) => { } } - // Update template var fields + handles - const braces_regex = /(? 0) { - const temp_var_names = found_template_vars.map( - name => name.substring(1, name.length-1) // remove brackets {} - ) - setTemplateVars(temp_var_names); - } else { - setTemplateVars([]); - } + refreshTemplateHooks(value); }; + // On initialization + useEffect(() => { + refreshTemplateHooks(promptText); + }, []); + // Pull all inputs needed to request responses. // Returns [prompt, vars dict] const pullInputData = () => { @@ -144,6 +154,7 @@ const PromptNode = ({ data, id }) => { if (e.target == nodeId && e.targetHandle == varname) { // Get the immediate output: let out = output(e.source, e.sourceHandle); + if (!out) return; // Save the var data from the pulled output if (varname in pulled_data) @@ -284,7 +295,7 @@ const PromptNode = ({ data, id }) => { // Set status indicator setStatus('loading'); - setReponsePreviews([]); + setResponsePreviews([]); const [py_prompt_template, pulled_data] = pullInputData(); @@ -408,12 +419,15 @@ const PromptNode = ({ data, id }) => { // Save prompt text so we remember what prompt we have responses cache'd for: setPromptTextOnLastRun(promptText); + // Save response texts as 'fields' of data, for any prompt nodes pulling the outputs + setDataPropsForNode(id, {fields: json.responses.map(r => r['responses']).flat()}); + // Save preview strings of responses, for quick glance // Bucket responses by LLM: const responses_by_llm = bucketResponsesByLLM(json.responses); // const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9']; // const colors = ['green', 'yellow', 'orange', 'red', 'pink', 'grape', 'violet', 'indigo', 'blue', 'gray', 'cyan', 'lime']; - setReponsePreviews(Object.keys(responses_by_llm).map((llm, llm_idx) => { + setResponsePreviews(Object.keys(responses_by_llm).map((llm, llm_idx) => { const resp_boxes = responses_by_llm[llm].map((res_obj, idx) => { const num_resp = res_obj['responses'].length; const resp_prevs = res_obj['responses'].map((r, i) => @@ -421,7 +435,7 @@ const PromptNode = ({ data, id }) => { ); const vars = vars_to_str(res_obj.vars); const var_tags = vars.map((v, i) => ( - {v} + {v} )); return (
diff --git a/chain-forge/src/store.js b/chain-forge/src/store.js index 4108387..360206e 100644 --- a/chain-forge/src/store.js +++ b/chain-forge/src/store.js @@ -60,7 +60,10 @@ const useStore = create((set, get) => ({ if (src_node) { // Get the data related to that handle: if ("fields" in src_node.data) { - return Object.values(src_node.data["fields"]); + if (Array.isArray(src_node.data["fields"])) + return src_node.data["fields"]; + else + return Object.values(src_node.data["fields"]); } // NOTE: This assumes it's on the 'data' prop, with the same id as the handle: else return src_node.data[sourceHandleKey]; diff --git a/chain-forge/src/text-fields-node.css b/chain-forge/src/text-fields-node.css index 7f00cc1..a5a53ec 100644 --- a/chain-forge/src/text-fields-node.css +++ b/chain-forge/src/text-fields-node.css @@ -163,7 +163,8 @@ } .inspect-response-container { overflow-y: auto; - width: 450px; + min-width: 150px; + max-width: 450px; max-height: 350px; resize: both; } @@ -172,9 +173,9 @@ font-size: 8pt; font-family: monospace; border-style: dotted; - border-color: #aaa; + border-color: #fff; padding: 2px; - margin: 0px; + margin: 2px 1px; background-color: rgba(255, 255, 255, 0.4); white-space: pre-wrap; } @@ -199,6 +200,7 @@ .response-box { padding: 2px; margin: 0px 2px 4px 2px; + border-radius: 5px; } .response-tag { font-size: 9pt; diff --git a/python-backend/promptengine/query.py b/python-backend/promptengine/query.py index 56f336e..0633df7 100644 --- a/python-backend/promptengine/query.py +++ b/python-backend/promptengine/query.py @@ -70,7 +70,9 @@ class PromptPipeline: "responses": extracted_resps[:n], "raw_response": cached_resp["raw_response"], "llm": cached_resp["llm"] if "llm" in cached_resp else LLM.ChatGPT.value, - "info": cached_resp["info"], + # We want to use the new info, since 'vars' could have changed even though + # the prompt text is the same (e.g., "this is a tool -> this is a {x} where x='tool'") + "info": prompt.fill_history, } continue