mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 16:26:45 +00:00
Dynamic run tooltip for prompt node
This commit is contained in:
parent
ccf6880da4
commit
d4e0630564
@ -155,6 +155,7 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
status={status}
|
||||
alertModal={alertModal}
|
||||
handleRunClick={handleRunClick}
|
||||
runButtonTooltip="Run evaluator over inputs"
|
||||
/>
|
||||
<Handle
|
||||
type="target"
|
||||
|
@ -12,9 +12,11 @@ export default function LLMItemButtonGroup( {onClickTrash, onClickSettings, ring
|
||||
</Modal>
|
||||
|
||||
<Group position="right" style={{float: 'right', height:'20px'}}>
|
||||
{(ringProgress) ?
|
||||
(<RingProgress size={20} thickness={3} sections={[{ value: ringProgress, color: ringProgress < 99 ? 'blue' : 'green' }]} width='16px' />) :
|
||||
<></>
|
||||
{ringProgress !== undefined ?
|
||||
(ringProgress > 0 ?
|
||||
(<RingProgress size={20} thickness={3} sections={[{ value: ringProgress, color: ringProgress < 99 ? 'blue' : 'green' }]} width='16px' />) :
|
||||
(<div className="lds-ring"><div></div><div></div><div></div><div></div></div>))
|
||||
: (<></>)
|
||||
}
|
||||
<Button onClick={onClickTrash} size="xs" variant="light" compact color="red" style={{padding: '0px'}} ><IconTrash size={"95%"} /></Button>
|
||||
<Button onClick={onClickSettings} size="xs" variant="light" compact>Settings <IconSettings size={"110%"} /></Button>
|
||||
|
@ -4,9 +4,9 @@ import 'react-edit-text/dist/index.css';
|
||||
import StatusIndicator from './StatusIndicatorComponent';
|
||||
import AlertModal from './AlertModal';
|
||||
import { useState, useEffect} from 'react';
|
||||
import { CloseButton } from '@mantine/core';
|
||||
import { Tooltip } from '@mantine/core';
|
||||
|
||||
export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, handleRunClick }) {
|
||||
export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, handleRunClick, handleRunHover, runButtonTooltip }) {
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
const [statusIndicator, setStatusIndicator] = useState('none');
|
||||
const [runButton, setRunButton] = useState('none');
|
||||
@ -33,12 +33,20 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl
|
||||
|
||||
useEffect(() => {
|
||||
if(handleRunClick !== undefined) {
|
||||
setRunButton(<button className="AmitSahoo45-button-3 nodrag" onClick={handleRunClick}>▶</button>);
|
||||
const run_btn = (<button className="AmitSahoo45-button-3 nodrag" onClick={handleRunClick} onPointerEnter={handleRunHover}>▶</button>);
|
||||
if (runButtonTooltip)
|
||||
setRunButton(
|
||||
<Tooltip label={runButtonTooltip} withArrow arrowSize={6} arrowRadius={2}>
|
||||
{run_btn}
|
||||
</Tooltip>
|
||||
);
|
||||
else
|
||||
setRunButton(run_btn);
|
||||
}
|
||||
else {
|
||||
setRunButton(<></>);
|
||||
}
|
||||
}, [handleRunClick]);
|
||||
}, [handleRunClick, runButtonTooltip]);
|
||||
|
||||
const handleCloseButtonClick = () => {
|
||||
removeNode(nodeId);
|
||||
|
@ -74,6 +74,7 @@ const PromptNode = ({ data, id }) => {
|
||||
|
||||
// Progress when querying responses
|
||||
const [progress, setProgress] = useState(100);
|
||||
const [runTooltip, setRunTooltip] = useState(null);
|
||||
|
||||
const triggerAlert = (msg) => {
|
||||
setProgress(100);
|
||||
@ -130,32 +131,10 @@ const PromptNode = ({ data, id }) => {
|
||||
}
|
||||
};
|
||||
|
||||
const handleRunClick = (event) => {
|
||||
// Go through all template hooks (if any) and check they're connected:
|
||||
const is_fully_connected = templateVars.every(varname => {
|
||||
// Check that some edge has, as its target, this node and its template hook:
|
||||
return edges.some(e => (e.target == id && e.targetHandle == varname));
|
||||
});
|
||||
|
||||
if (!is_fully_connected) {
|
||||
console.log('Not connected! :(');
|
||||
triggerAlert('Missing inputs to one or more template variables.');
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('Connected!');
|
||||
|
||||
// Check that there is at least one LLM selected:
|
||||
if (llmItemsCurrState.length === 0) {
|
||||
alert('Please select at least one LLM to prompt.')
|
||||
return;
|
||||
}
|
||||
|
||||
// Set status indicator
|
||||
setStatus('loading');
|
||||
setReponsePreviews([]);
|
||||
|
||||
// Pull data from each source, recursively:
|
||||
// Pull all inputs needed to request responses.
|
||||
// Returns [prompt, vars dict]
|
||||
const pullInputData = () => {
|
||||
// Pull data from each source recursively:
|
||||
const pulled_data = {};
|
||||
const get_outputs = (varnames, nodeId) => {
|
||||
varnames.forEach(varname => {
|
||||
@ -189,7 +168,78 @@ const PromptNode = ({ data, id }) => {
|
||||
Object.keys(pulled_data).forEach(varname => {
|
||||
pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
|
||||
});
|
||||
console.log(pulled_data);
|
||||
|
||||
return [py_prompt_template, pulled_data];
|
||||
};
|
||||
|
||||
// Ask the backend how many responses it needs to collect, given the input data:
|
||||
const fetchResponseCounts = (prompt, vars, llms, rejected) => {
|
||||
return fetch(BASE_URL + 'api/countQueriesRequired', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
|
||||
body: JSON.stringify({
|
||||
prompt: prompt,
|
||||
vars: vars,
|
||||
llms: llms,
|
||||
})}, rejected).then(function(response) {
|
||||
return response.json();
|
||||
}, rejected).then(function(json) {
|
||||
if (!json || !json.counts) {
|
||||
throw new Error('Request was sent and received by backend server, but there was no response.');
|
||||
}
|
||||
return json.counts;
|
||||
}, rejected);
|
||||
};
|
||||
|
||||
// On hover over the 'Run' button, request how many responses are required and update the tooltip. Soft fails.
|
||||
const handleRunHover = () => {
|
||||
// Check if there's at least one model in the list; if not, nothing to run on.
|
||||
if (!llmItemsCurrState || llmItemsCurrState.length == 0) {
|
||||
setRunTooltip('No LLMs to query.');
|
||||
return;
|
||||
}
|
||||
|
||||
// Get input data and prompt
|
||||
const [py_prompt, pulled_vars] = pullInputData();
|
||||
const llms = llmItemsCurrState.map(item => item.model);
|
||||
const num_llms = llms.length;
|
||||
|
||||
// Fetch response counts from backend
|
||||
fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => {
|
||||
console.warn(err.message); // soft fail
|
||||
}).then((counts) => {
|
||||
const n = counts[Object.keys(counts)[0]];
|
||||
const req = n > 1 ? 'requests' : 'request';
|
||||
setRunTooltip(`Will send ${n} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
|
||||
});
|
||||
};
|
||||
|
||||
const handleRunClick = (event) => {
|
||||
// Go through all template hooks (if any) and check they're connected:
|
||||
const is_fully_connected = templateVars.every(varname => {
|
||||
// Check that some edge has, as its target, this node and its template hook:
|
||||
return edges.some(e => (e.target == id && e.targetHandle == varname));
|
||||
});
|
||||
|
||||
if (!is_fully_connected) {
|
||||
console.log('Not connected! :(');
|
||||
triggerAlert('Missing inputs to one or more template variables.');
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('Connected!');
|
||||
|
||||
// Check that there is at least one LLM selected:
|
||||
if (llmItemsCurrState.length === 0) {
|
||||
alert('Please select at least one LLM to prompt.')
|
||||
return;
|
||||
}
|
||||
|
||||
// Set status indicator
|
||||
setStatus('loading');
|
||||
setReponsePreviews([]);
|
||||
|
||||
const [py_prompt_template, pulled_data] = pullInputData();
|
||||
|
||||
let FINISHED_QUERY = false;
|
||||
const rejected = (err) => {
|
||||
@ -208,27 +258,12 @@ const PromptNode = ({ data, id }) => {
|
||||
})}, rejected);
|
||||
};
|
||||
|
||||
// Query the backend to ask how many responses it needs to collect, given the input data:
|
||||
const fetch_resp_count = () => {
|
||||
return fetch(BASE_URL + 'api/countQueriesRequired', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
|
||||
body: JSON.stringify({
|
||||
prompt: py_prompt_template,
|
||||
vars: pulled_data,
|
||||
llms: llmItemsCurrState.map(item => item.model),
|
||||
})}, rejected).then(function(response) {
|
||||
return response.json();
|
||||
}, rejected).then(function(json) {
|
||||
if (!json || !json.count) {
|
||||
throw new Error('Request was sent and received by backend server, but there was no response.');
|
||||
}
|
||||
return json.count;
|
||||
}, rejected);
|
||||
};
|
||||
// Fetch info about the number of queries we'll need to make
|
||||
const fetch_resp_count = () => fetchResponseCounts(
|
||||
py_prompt_template, pulled_data, llmItemsCurrState.map(item => item.model), rejected);
|
||||
|
||||
// Open a socket to listen for progress
|
||||
const open_progress_listener_socket = (max_responses) => {
|
||||
const open_progress_listener_socket = (response_counts) => {
|
||||
// With the counts information we can create progress bars. Now we load a socket connection to
|
||||
// the socketio server that will stream to us the current progress:
|
||||
const socket = io('http://localhost:8001/' + 'queryllm', {
|
||||
@ -236,9 +271,19 @@ const PromptNode = ({ data, id }) => {
|
||||
cors: {origin: "http://localhost:3000/"},
|
||||
});
|
||||
|
||||
const max_responses = Object.keys(response_counts).reduce((acc, llm) => acc + response_counts[llm], 0);
|
||||
|
||||
// On connect to the server, ask it to give us the current progress
|
||||
// for task 'queryllm' with id 'id', and stop when it reads progress >= 'max'.
|
||||
socket.on("connect", (msg) => {
|
||||
// Initialize progress bars to small amounts
|
||||
setProgress(5);
|
||||
setLLMItems(llmItemsCurrState.map(item => {
|
||||
item.progress = 0;
|
||||
return item;
|
||||
}));
|
||||
|
||||
// Request progress bar updates
|
||||
socket.emit("queryllm", {'id': id, 'max': max_responses});
|
||||
});
|
||||
socket.on("disconnect", (msg) => {
|
||||
@ -380,6 +425,8 @@ const PromptNode = ({ data, id }) => {
|
||||
status={status}
|
||||
alertModal={alertModal}
|
||||
handleRunClick={handleRunClick}
|
||||
handleRunHover={handleRunHover}
|
||||
runButtonTooltip={runTooltip}
|
||||
/>
|
||||
<div className="input-field">
|
||||
<textarea
|
||||
|
@ -300,6 +300,10 @@
|
||||
-moz-transition: all 0.2s ease-out;
|
||||
transition: all 0.2s ease-out;
|
||||
}
|
||||
.AmitSahoo45-button-3:active {
|
||||
background: #40a829;
|
||||
color: yellow;
|
||||
}
|
||||
|
||||
.AmitSahoo45-button-3:hover::before {
|
||||
-moz-animation: sh02 0.5s 0s linear;
|
||||
@ -364,6 +368,10 @@
|
||||
-moz-transition: all 0.2s ease-out;
|
||||
transition: all 0.2s ease-out;
|
||||
}
|
||||
.close-button:active {
|
||||
background: #660000;
|
||||
color: yellow;
|
||||
}
|
||||
|
||||
.close-button:hover::before {
|
||||
-moz-animation: sh02 0.5s 0s linear;
|
||||
|
@ -228,9 +228,11 @@ def countQueries():
|
||||
return jsonify({'error': str(e)})
|
||||
|
||||
# TODO: Send more informative data back including how many queries per LLM based on cache'd data
|
||||
num_queries = len(all_prompt_permutations) * len(data['llms'])
|
||||
num_queries = {} # len(all_prompt_permutations) * len(data['llms'])
|
||||
for llm in data['llms']:
|
||||
num_queries[llm] = len(all_prompt_permutations)
|
||||
|
||||
ret = jsonify({'count': num_queries})
|
||||
ret = jsonify({'counts': num_queries})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
|
@ -105,7 +105,7 @@ class PromptPipeline:
|
||||
# Yield responses as they come in
|
||||
for task in asyncio.as_completed(tasks):
|
||||
# Collect the response from the earliest completed task
|
||||
print(f'awaiting a response from {llm.name}...')
|
||||
print(f'awaiting a task to call {llm.name}...')
|
||||
prompt, query, response = await task
|
||||
|
||||
print('Completed!')
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Dict, Tuple, List, Union
|
||||
from enum import Enum
|
||||
import openai
|
||||
import json, os, time
|
||||
import json, os, time, asyncio
|
||||
|
||||
DALAI_MODEL = None
|
||||
DALAI_RESPONSE = None
|
||||
@ -102,7 +102,7 @@ async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperat
|
||||
|
||||
# Blocking --wait for request to complete:
|
||||
while DALAI_RESPONSE is None:
|
||||
time.sleep(0.01)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
response = DALAI_RESPONSE['response']
|
||||
if response[-5:] == '<end>': # strip ending <end> tag, if present
|
||||
|
Loading…
x
Reference in New Issue
Block a user