From d4e0630564b48165b7a353855f7debf467e1eaa3 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Thu, 4 May 2023 20:29:46 -0400 Subject: [PATCH] Dynamic run tooltip for prompt node --- chain-forge/src/EvaluatorNode.js | 1 + chain-forge/src/LLMItemButtonGroup.js | 8 +- chain-forge/src/NodeLabelComponent.js | 16 ++- chain-forge/src/PromptNode.js | 139 +++++++++++++++++--------- chain-forge/src/text-fields-node.css | 8 ++ python-backend/app.py | 6 +- python-backend/promptengine/query.py | 2 +- python-backend/promptengine/utils.py | 4 +- 8 files changed, 126 insertions(+), 58 deletions(-) diff --git a/chain-forge/src/EvaluatorNode.js b/chain-forge/src/EvaluatorNode.js index aa23e21..8f66b3e 100644 --- a/chain-forge/src/EvaluatorNode.js +++ b/chain-forge/src/EvaluatorNode.js @@ -155,6 +155,7 @@ const EvaluatorNode = ({ data, id }) => { status={status} alertModal={alertModal} handleRunClick={handleRunClick} + runButtonTooltip="Run evaluator over inputs" /> - {(ringProgress) ? - () : - <> + {ringProgress !== undefined ? + (ringProgress > 0 ? + () : + (
)) + : (<>) } diff --git a/chain-forge/src/NodeLabelComponent.js b/chain-forge/src/NodeLabelComponent.js index fe2ff33..81a2686 100644 --- a/chain-forge/src/NodeLabelComponent.js +++ b/chain-forge/src/NodeLabelComponent.js @@ -4,9 +4,9 @@ import 'react-edit-text/dist/index.css'; import StatusIndicator from './StatusIndicatorComponent'; import AlertModal from './AlertModal'; import { useState, useEffect} from 'react'; -import { CloseButton } from '@mantine/core'; +import { Tooltip } from '@mantine/core'; -export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, handleRunClick }) { +export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, handleRunClick, handleRunHover, runButtonTooltip }) { const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); const [statusIndicator, setStatusIndicator] = useState('none'); const [runButton, setRunButton] = useState('none'); @@ -33,12 +33,20 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl useEffect(() => { if(handleRunClick !== undefined) { - setRunButton(); + const run_btn = (); + if (runButtonTooltip) + setRunButton( + + {run_btn} + + ); + else + setRunButton(run_btn); } else { setRunButton(<>); } - }, [handleRunClick]); + }, [handleRunClick, runButtonTooltip]); const handleCloseButtonClick = () => { removeNode(nodeId); diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js index 56b5a01..9466620 100644 --- a/chain-forge/src/PromptNode.js +++ b/chain-forge/src/PromptNode.js @@ -74,6 +74,7 @@ const PromptNode = ({ data, id }) => { // Progress when querying responses const [progress, setProgress] = useState(100); + const [runTooltip, setRunTooltip] = useState(null); const triggerAlert = (msg) => { setProgress(100); @@ -130,32 +131,10 @@ const PromptNode = ({ data, id }) => { } }; - const handleRunClick = (event) => { - // Go through all template hooks (if any) and check they're connected: - const is_fully_connected = templateVars.every(varname => { - // Check that some edge has, as its target, this node and its template hook: - return edges.some(e => (e.target == id && e.targetHandle == varname)); - }); - - if (!is_fully_connected) { - console.log('Not connected! :('); - triggerAlert('Missing inputs to one or more template variables.'); - return; - } - - console.log('Connected!'); - - // Check that there is at least one LLM selected: - if (llmItemsCurrState.length === 0) { - alert('Please select at least one LLM to prompt.') - return; - } - - // Set status indicator - setStatus('loading'); - setReponsePreviews([]); - - // Pull data from each source, recursively: + // Pull all inputs needed to request responses. + // Returns [prompt, vars dict] + const pullInputData = () => { + // Pull data from each source recursively: const pulled_data = {}; const get_outputs = (varnames, nodeId) => { varnames.forEach(varname => { @@ -189,7 +168,78 @@ const PromptNode = ({ data, id }) => { Object.keys(pulled_data).forEach(varname => { pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val)); }); - console.log(pulled_data); + + return [py_prompt_template, pulled_data]; + }; + + // Ask the backend how many responses it needs to collect, given the input data: + const fetchResponseCounts = (prompt, vars, llms, rejected) => { + return fetch(BASE_URL + 'api/countQueriesRequired', { + method: 'POST', + headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, + body: JSON.stringify({ + prompt: prompt, + vars: vars, + llms: llms, + })}, rejected).then(function(response) { + return response.json(); + }, rejected).then(function(json) { + if (!json || !json.counts) { + throw new Error('Request was sent and received by backend server, but there was no response.'); + } + return json.counts; + }, rejected); + }; + + // On hover over the 'Run' button, request how many responses are required and update the tooltip. Soft fails. + const handleRunHover = () => { + // Check if there's at least one model in the list; if not, nothing to run on. + if (!llmItemsCurrState || llmItemsCurrState.length == 0) { + setRunTooltip('No LLMs to query.'); + return; + } + + // Get input data and prompt + const [py_prompt, pulled_vars] = pullInputData(); + const llms = llmItemsCurrState.map(item => item.model); + const num_llms = llms.length; + + // Fetch response counts from backend + fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => { + console.warn(err.message); // soft fail + }).then((counts) => { + const n = counts[Object.keys(counts)[0]]; + const req = n > 1 ? 'requests' : 'request'; + setRunTooltip(`Will send ${n} ${req}` + (num_llms > 1 ? ' per LLM' : '')); + }); + }; + + const handleRunClick = (event) => { + // Go through all template hooks (if any) and check they're connected: + const is_fully_connected = templateVars.every(varname => { + // Check that some edge has, as its target, this node and its template hook: + return edges.some(e => (e.target == id && e.targetHandle == varname)); + }); + + if (!is_fully_connected) { + console.log('Not connected! :('); + triggerAlert('Missing inputs to one or more template variables.'); + return; + } + + console.log('Connected!'); + + // Check that there is at least one LLM selected: + if (llmItemsCurrState.length === 0) { + alert('Please select at least one LLM to prompt.') + return; + } + + // Set status indicator + setStatus('loading'); + setReponsePreviews([]); + + const [py_prompt_template, pulled_data] = pullInputData(); let FINISHED_QUERY = false; const rejected = (err) => { @@ -208,27 +258,12 @@ const PromptNode = ({ data, id }) => { })}, rejected); }; - // Query the backend to ask how many responses it needs to collect, given the input data: - const fetch_resp_count = () => { - return fetch(BASE_URL + 'api/countQueriesRequired', { - method: 'POST', - headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, - body: JSON.stringify({ - prompt: py_prompt_template, - vars: pulled_data, - llms: llmItemsCurrState.map(item => item.model), - })}, rejected).then(function(response) { - return response.json(); - }, rejected).then(function(json) { - if (!json || !json.count) { - throw new Error('Request was sent and received by backend server, but there was no response.'); - } - return json.count; - }, rejected); - }; + // Fetch info about the number of queries we'll need to make + const fetch_resp_count = () => fetchResponseCounts( + py_prompt_template, pulled_data, llmItemsCurrState.map(item => item.model), rejected); // Open a socket to listen for progress - const open_progress_listener_socket = (max_responses) => { + const open_progress_listener_socket = (response_counts) => { // With the counts information we can create progress bars. Now we load a socket connection to // the socketio server that will stream to us the current progress: const socket = io('http://localhost:8001/' + 'queryllm', { @@ -236,9 +271,19 @@ const PromptNode = ({ data, id }) => { cors: {origin: "http://localhost:3000/"}, }); + const max_responses = Object.keys(response_counts).reduce((acc, llm) => acc + response_counts[llm], 0); + // On connect to the server, ask it to give us the current progress // for task 'queryllm' with id 'id', and stop when it reads progress >= 'max'. socket.on("connect", (msg) => { + // Initialize progress bars to small amounts + setProgress(5); + setLLMItems(llmItemsCurrState.map(item => { + item.progress = 0; + return item; + })); + + // Request progress bar updates socket.emit("queryllm", {'id': id, 'max': max_responses}); }); socket.on("disconnect", (msg) => { @@ -380,6 +425,8 @@ const PromptNode = ({ data, id }) => { status={status} alertModal={alertModal} handleRunClick={handleRunClick} + handleRunHover={handleRunHover} + runButtonTooltip={runTooltip} />