Dynamic run tooltip for prompt node

Ian Arawjo 2023-05-04 20:29:46 -04:00
parent ccf6880da4
commit d4e0630564
8 changed files with 126 additions and 58 deletions

View File

@@ -155,6 +155,7 @@ const EvaluatorNode = ({ data, id }) => {
status={status}
alertModal={alertModal}
handleRunClick={handleRunClick}
runButtonTooltip="Run evaluator over inputs"
/>
<Handle
type="target"

View File

@@ -12,9 +12,11 @@ export default function LLMItemButtonGroup( {onClickTrash, onClickSettings, ring
</Modal>
<Group position="right" style={{float: 'right', height:'20px'}}>
{(ringProgress) ?
(<RingProgress size={20} thickness={3} sections={[{ value: ringProgress, color: ringProgress < 99 ? 'blue' : 'green' }]} width='16px' />) :
<></>
{ringProgress !== undefined ?
(ringProgress > 0 ?
(<RingProgress size={20} thickness={3} sections={[{ value: ringProgress, color: ringProgress < 99 ? 'blue' : 'green' }]} width='16px' />) :
(<div className="lds-ring"><div></div><div></div><div></div><div></div></div>))
: (<></>)
}
<Button onClick={onClickTrash} size="xs" variant="light" compact color="red" style={{padding: '0px'}} ><IconTrash size={"95%"} /></Button>
<Button onClick={onClickSettings} size="xs" variant="light" compact>Settings&nbsp;<IconSettings size={"110%"} /></Button>
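
For readability, a minimal sketch of the three states this nested conditional encodes; the helper name renderProgress is illustrative and not part of the commit:

import { RingProgress } from '@mantine/core';

// Sketch only: the same three-way decision, written out as a helper.
function renderProgress(ringProgress) {
  if (ringProgress === undefined) return <></>;   // no run in flight: render nothing
  if (!(ringProgress > 0))                        // run started, no responses yet: indeterminate spinner
    return (<div className="lds-ring"><div></div><div></div><div></div><div></div></div>);
  // Determinate ring; switches from blue to green as progress approaches 100.
  return (<RingProgress size={20} thickness={3}
            sections={[{ value: ringProgress, color: ringProgress < 99 ? 'blue' : 'green' }]} />);
}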

View File

@@ -4,9 +4,9 @@ import 'react-edit-text/dist/index.css';
import StatusIndicator from './StatusIndicatorComponent';
import AlertModal from './AlertModal';
import { useState, useEffect} from 'react';
import { CloseButton } from '@mantine/core';
import { Tooltip } from '@mantine/core';
export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, handleRunClick }) {
export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, handleRunClick, handleRunHover, runButtonTooltip }) {
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const [statusIndicator, setStatusIndicator] = useState('none');
const [runButton, setRunButton] = useState('none');
@@ -33,12 +33,20 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl
useEffect(() => {
if(handleRunClick !== undefined) {
setRunButton(<button className="AmitSahoo45-button-3 nodrag" onClick={handleRunClick}>&#9654;</button>);
const run_btn = (<button className="AmitSahoo45-button-3 nodrag" onClick={handleRunClick} onPointerEnter={handleRunHover}>&#9654;</button>);
if (runButtonTooltip)
setRunButton(
<Tooltip label={runButtonTooltip} withArrow arrowSize={6} arrowRadius={2}>
{run_btn}
</Tooltip>
);
else
setRunButton(run_btn);
}
else {
setRunButton(<></>);
}
}, [handleRunClick]);
}, [handleRunClick, runButtonTooltip]);
const handleCloseButtonClick = () => {
removeNode(nodeId);

View File

@@ -74,6 +74,7 @@ const PromptNode = ({ data, id }) => {
// Progress when querying responses
const [progress, setProgress] = useState(100);
const [runTooltip, setRunTooltip] = useState(null);
const triggerAlert = (msg) => {
setProgress(100);
@@ -130,32 +131,10 @@ const PromptNode = ({ data, id }) => {
}
};
const handleRunClick = (event) => {
// Go through all template hooks (if any) and check they're connected:
const is_fully_connected = templateVars.every(varname => {
// Check that some edge has, as its target, this node and its template hook:
return edges.some(e => (e.target == id && e.targetHandle == varname));
});
if (!is_fully_connected) {
console.log('Not connected! :(');
triggerAlert('Missing inputs to one or more template variables.');
return;
}
console.log('Connected!');
// Check that there is at least one LLM selected:
if (llmItemsCurrState.length === 0) {
alert('Please select at least one LLM to prompt.')
return;
}
// Set status indicator
setStatus('loading');
setReponsePreviews([]);
// Pull data from each source, recursively:
// Pull all inputs needed to request responses.
// Returns [prompt, vars dict]
const pullInputData = () => {
// Pull data from each source recursively:
const pulled_data = {};
const get_outputs = (varnames, nodeId) => {
varnames.forEach(varname => {
@@ -189,7 +168,78 @@ const PromptNode = ({ data, id }) => {
Object.keys(pulled_data).forEach(varname => {
pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
});
console.log(pulled_data);
return [py_prompt_template, pulled_data];
};
// Ask the backend how many responses it needs to collect, given the input data:
const fetchResponseCounts = (prompt, vars, llms, rejected) => {
return fetch(BASE_URL + 'api/countQueriesRequired', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
prompt: prompt,
vars: vars,
llms: llms,
})}, rejected).then(function(response) {
return response.json();
}, rejected).then(function(json) {
if (!json || !json.counts) {
throw new Error('Request was sent and received by backend server, but there was no response.');
}
return json.counts;
}, rejected);
};
// On hover over the 'Run' button, request how many responses are required and update the tooltip. Fails softly on error.
const handleRunHover = () => {
// Check if there's at least one model in the list; if not, nothing to run on.
if (!llmItemsCurrState || llmItemsCurrState.length == 0) {
setRunTooltip('No LLMs to query.');
return;
}
// Get input data and prompt
const [py_prompt, pulled_vars] = pullInputData();
const llms = llmItemsCurrState.map(item => item.model);
const num_llms = llms.length;
// Fetch response counts from backend
fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => {
console.warn(err.message); // soft fail
}).then((counts) => {
const n = counts[Object.keys(counts)[0]];
const req = n > 1 ? 'requests' : 'request';
setRunTooltip(`Will send ${n} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
});
};
const handleRunClick = (event) => {
// Go through all template hooks (if any) and check they're connected:
const is_fully_connected = templateVars.every(varname => {
// Check that some edge has, as its target, this node and its template hook:
return edges.some(e => (e.target == id && e.targetHandle == varname));
});
if (!is_fully_connected) {
console.log('Not connected! :(');
triggerAlert('Missing inputs to one or more template variables.');
return;
}
console.log('Connected!');
// Check that there is at least one LLM selected:
if (llmItemsCurrState.length === 0) {
alert('Please select at least one LLM to prompt.')
return;
}
// Set status indicator
setStatus('loading');
setReponsePreviews([]);
const [py_prompt_template, pulled_data] = pullInputData();
let FINISHED_QUERY = false;
const rejected = (err) => {
@@ -208,27 +258,12 @@ const PromptNode = ({ data, id }) => {
})}, rejected);
};
// Query the backend to ask how many responses it needs to collect, given the input data:
const fetch_resp_count = () => {
return fetch(BASE_URL + 'api/countQueriesRequired', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
prompt: py_prompt_template,
vars: pulled_data,
llms: llmItemsCurrState.map(item => item.model),
})}, rejected).then(function(response) {
return response.json();
}, rejected).then(function(json) {
if (!json || !json.count) {
throw new Error('Request was sent and received by backend server, but there was no response.');
}
return json.count;
}, rejected);
};
// Fetch info about the number of queries we'll need to make
const fetch_resp_count = () => fetchResponseCounts(
py_prompt_template, pulled_data, llmItemsCurrState.map(item => item.model), rejected);
// Open a socket to listen for progress
const open_progress_listener_socket = (max_responses) => {
const open_progress_listener_socket = (response_counts) => {
// With the counts information we can create progress bars. Now we load a socket connection to
// the socketio server that will stream to us the current progress:
const socket = io('http://localhost:8001/' + 'queryllm', {
@@ -236,9 +271,19 @@ const PromptNode = ({ data, id }) => {
cors: {origin: "http://localhost:3000/"},
});
const max_responses = Object.keys(response_counts).reduce((acc, llm) => acc + response_counts[llm], 0);
// On connect to the server, ask it to give us the current progress
// for task 'queryllm' with id 'id', and stop when it reads progress >= 'max'.
socket.on("connect", (msg) => {
// Initialize progress bars to small amounts
setProgress(5);
setLLMItems(llmItemsCurrState.map(item => {
item.progress = 0;
return item;
}));
// Request progress bar updates
socket.emit("queryllm", {'id': id, 'max': max_responses});
});
socket.on("disconnect", (msg) => {
@@ -380,6 +425,8 @@ const PromptNode = ({ data, id }) => {
status={status}
alertModal={alertModal}
handleRunClick={handleRunClick}
handleRunHover={handleRunHover}
runButtonTooltip={runTooltip}
/>
<div className="input-field">
<textarea
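
To make the hover flow concrete, here is a hedged walk-through of how the tooltip text is derived, assuming the backend reports 3 prompt permutations for each of 2 selected LLMs (model names and numbers are illustrative):

// Assumed response from api/countQueriesRequired (see the Python change further below):
const counts = { 'gpt-3.5-turbo': 3, 'alpaca.7B': 3 };
const num_llms = 2;

// Same derivation as in handleRunHover above:
const n = counts[Object.keys(counts)[0]];   // 3, the count for the first LLM
const req = n > 1 ? 'requests' : 'request';
console.log(`Will send ${n} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
// -> "Will send 3 requests per LLM", which NodeLabel shows as the run button's Mantine Tooltip.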

View File

@@ -300,6 +300,10 @@
-moz-transition: all 0.2s ease-out;
transition: all 0.2s ease-out;
}
.AmitSahoo45-button-3:active {
background: #40a829;
color: yellow;
}
.AmitSahoo45-button-3:hover::before {
-moz-animation: sh02 0.5s 0s linear;
@@ -364,6 +368,10 @@
-moz-transition: all 0.2s ease-out;
transition: all 0.2s ease-out;
}
.close-button:active {
background: #660000;
color: yellow;
}
.close-button:hover::before {
-moz-animation: sh02 0.5s 0s linear;

View File

@@ -228,9 +228,11 @@ def countQueries():
return jsonify({'error': str(e)})
# TODO: Send more informative data back, including how many queries per LLM based on cached data
num_queries = len(all_prompt_permutations) * len(data['llms'])
num_queries = {} # len(all_prompt_permutations) * len(data['llms'])
for llm in data['llms']:
num_queries[llm] = len(all_prompt_permutations)
ret = jsonify({'count': num_queries})
ret = jsonify({'counts': num_queries})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
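
A hedged illustration of the payload change (model names and counts are made up): the endpoint previously returned a single total under 'count' and now returns a per-LLM dictionary under 'counts'. The frontend can still recover a total, as open_progress_listener_socket does above:

// Before this commit:  { "count": 8 }
// After this commit:   { "counts": { "gpt-3.5-turbo": 4, "alpaca.7B": 4 } }
const response_counts = { 'gpt-3.5-turbo': 4, 'alpaca.7B': 4 };
const max_responses = Object.keys(response_counts)
                            .reduce((acc, llm) => acc + response_counts[llm], 0);   // 8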

View File

@@ -105,7 +105,7 @@ class PromptPipeline:
# Yield responses as they come in
for task in asyncio.as_completed(tasks):
# Collect the response from the earliest completed task
print(f'awaiting a response from {llm.name}...')
print(f'awaiting a task to call {llm.name}...')
prompt, query, response = await task
print('Completed!')

View File

@@ -1,7 +1,7 @@
from typing import Dict, Tuple, List, Union
from enum import Enum
import openai
import json, os, time
import json, os, time, asyncio
DALAI_MODEL = None
DALAI_RESPONSE = None
@@ -102,7 +102,7 @@ async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperat
# Blocking --wait for request to complete:
while DALAI_RESPONSE is None:
time.sleep(0.01)
await asyncio.sleep(0.01)
response = DALAI_RESPONSE['response']
if response[-5:] == '<end>': # strip ending <end> tag, if present
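
For context on this last change (a hedged explanation, not stated in the diff): a blocking time.sleep inside an async function holds up the entire event loop, so no other coroutine, such as queries to other LLMs or progress updates, can make progress while call_dalai waits; await asyncio.sleep yields control between polls. A rough JavaScript analogue of the pattern:

// Illustrative only: polling for a result should yield to the event loop on each iteration.
const sleep = (ms) => new Promise(resolve => setTimeout(resolve, ms));

async function waitForResponse(getResponse) {
  while (getResponse() === null) {
    // Like `await asyncio.sleep(0.01)`: yields, letting other scheduled work run.
    await sleep(10);
  }
  return getResponse();
}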