Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-14 16:26:45 +00:00)

Commit ccf6880da4 ("WIP live progress wheels"); parent 75718766ac.

chain-forge/package-lock.json (generated; 80 changed lines)
@@ -38,6 +38,7 @@
"react-flow-renderer": "^10.3.17",
"react-plotly.js": "^2.6.0",
"react-scripts": "5.0.1",
"socket.io-client": "^4.6.1",
"styled-components": "^5.3.10",
"uuidv4": "^6.2.13",
"web-vitals": "^2.1.4",
@@ -3913,6 +3914,11 @@
"@sinonjs/commons": "^1.7.0"
}
},
"node_modules/@socket.io/component-emitter": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz",
"integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg=="
},
"node_modules/@surma/rollup-plugin-off-main-thread": {
"version": "2.2.3",
"resolved": "https://registry.npmjs.org/@surma/rollup-plugin-off-main-thread/-/rollup-plugin-off-main-thread-2.2.3.tgz",
@@ -8718,6 +8724,46 @@
"once": "^1.4.0"
}
},
"node_modules/engine.io-client": {
"version": "6.4.0",
"resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.4.0.tgz",
"integrity": "sha512-GyKPDyoEha+XZ7iEqam49vz6auPnNJ9ZBfy89f+rMMas8AuiMWOZ9PVzu8xb9ZC6rafUqiGHSCfu22ih66E+1g==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1",
"engine.io-parser": "~5.0.3",
"ws": "~8.11.0",
"xmlhttprequest-ssl": "~2.0.0"
}
},
"node_modules/engine.io-client/node_modules/ws": {
"version": "8.11.0",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz",
"integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==",
"engines": {
"node": ">=10.0.0"
},
"peerDependencies": {
"bufferutil": "^4.0.1",
"utf-8-validate": "^5.0.2"
},
"peerDependenciesMeta": {
"bufferutil": {
"optional": true
},
"utf-8-validate": {
"optional": true
}
}
},
"node_modules/engine.io-parser": {
"version": "5.0.6",
"resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz",
"integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==",
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/enhanced-resolve": {
"version": "5.12.0",
"resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.12.0.tgz",
@@ -18338,6 +18384,32 @@
"node": ">=8"
}
},
"node_modules/socket.io-client": {
"version": "4.6.1",
"resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.6.1.tgz",
"integrity": "sha512-5UswCV6hpaRsNg5kkEHVcbBIXEYoVbMQaHJBXJCyEQ+CiFPV1NIOY0XOFWG4XR4GZcB8Kn6AsRs/9cy9TbqVMQ==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.2",
"engine.io-client": "~6.4.0",
"socket.io-parser": "~4.2.1"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/socket.io-parser": {
"version": "4.2.2",
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.2.tgz",
"integrity": "sha512-DJtziuKypFkMMHCm2uIshOYC7QaylbtzQwiMYDuCKy3OPkjLzu4B2vAhTlqipRHHzrI0NJeBAizTK7X+6m1jVw==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/sockjs": {
"version": "0.3.24",
"resolved": "https://registry.npmjs.org/sockjs/-/sockjs-0.3.24.tgz",
@@ -21030,6 +21102,14 @@
"resolved": "https://registry.npmjs.org/xmlchars/-/xmlchars-2.2.0.tgz",
"integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw=="
},
"node_modules/xmlhttprequest-ssl": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz",
"integrity": "sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A==",
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/xtend": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz",
chain-forge/package.json

@@ -33,6 +33,7 @@
"react-flow-renderer": "^10.3.17",
"react-plotly.js": "^2.6.0",
"react-scripts": "5.0.1",
"socket.io-client": "^4.6.1",
"styled-components": "^5.3.10",
"uuidv4": "^6.2.13",
"web-vitals": "^2.1.4",
@@ -74,7 +74,7 @@ const EvaluatorNode = ({ data, id }) => {
console.log(script_paths);
// Run evaluator in backend
const codeTextOnRun = codeText + '';
fetch(BASE_URL + 'execute', {
fetch(BASE_URL + 'api/execute', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@@ -32,7 +32,7 @@ const InspectorNode = ({ data, id }) => {
console.log(input_node_ids);

// Grab responses associated with those ids:
fetch(BASE_URL + 'grabResponses', {
fetch(BASE_URL + 'api/grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@@ -1,8 +1,8 @@
import { useDisclosure } from '@mantine/hooks';
import { Modal, Button, Group } from '@mantine/core';
import { Modal, Button, Group, RingProgress } from '@mantine/core';
import { IconSettings, IconTrash } from '@tabler/icons-react';

export default function LLMItemButtonGroup( {onClickTrash, onClickSettings} ) {
export default function LLMItemButtonGroup( {onClickTrash, onClickSettings, ringProgress} ) {
const [opened, { open, close }] = useDisclosure(false);

return (
@@ -12,7 +12,11 @@ export default function LLMItemButtonGroup( {onClickTrash, onClickSettings} ) {
</Modal>

<Group position="right" style={{float: 'right', height:'20px'}}>
<Button onClick={onClickTrash} size="xs" variant="light" compact color="red" style={{padding: '0px'}} ><IconTrash size={"95%"} /></Button>
{(ringProgress) ?
(<RingProgress size={20} thickness={3} sections={[{ value: ringProgress, color: ringProgress < 99 ? 'blue' : 'green' }]} width='16px' />) :
<></>
}
<Button onClick={onClickTrash} size="xs" variant="light" compact color="red" style={{padding: '0px'}} ><IconTrash size={"95%"} /></Button>
<Button onClick={onClickSettings} size="xs" variant="light" compact>Settings <IconSettings size={"110%"} /></Button>
</Group>
</div>
@@ -63,7 +63,7 @@ export default function LLMList({llms, onItemsChange}) {
{items.map((item, index) => (
<Draggable key={item.key} draggableId={item.key} index={index}>
{(provided, snapshot) => (
<LLMListItem provided={provided} snapshot={snapshot} item={item} removeCallback={removeItem} />
<LLMListItem provided={provided} snapshot={snapshot} item={item} removeCallback={removeItem} progress={item.progress} />
)}
</Draggable>
))}
@@ -23,7 +23,7 @@ export const DragItem = styled.div`
flex-direction: column;
`;

const LLMListItem = ({ item, provided, snapshot, removeCallback }) => {
const LLMListItem = ({ item, provided, snapshot, removeCallback, progress }) => {
return (
<DragItem
ref={provided.innerRef}
@@ -33,7 +33,7 @@ const LLMListItem = ({ item, provided, snapshot, removeCallback }) => {
>
<div>
<CardHeader>{item.emoji} {item.name}</CardHeader>
<LLMItemButtonGroup onClickTrash={() => removeCallback(item.key)} />
<LLMItemButtonGroup onClickTrash={() => removeCallback(item.key)} ringProgress={progress} />
</div>

</DragItem>
@@ -1,14 +1,13 @@
import React, { useEffect, useState, useRef, useCallback } from 'react';
import { Handle } from 'react-flow-renderer';
import { Menu, Badge } from '@mantine/core';
import { Menu, Badge, Progress } from '@mantine/core';
import { v4 as uuid } from 'uuid';
import useStore from './store';
import StatusIndicator from './StatusIndicatorComponent'
import NodeLabel from './NodeLabelComponent'
import TemplateHooks from './TemplateHooksComponent'
import LLMList from './LLMListComponent'
import AlertModal from './AlertModal'
import {BASE_URL} from './store';
import io from 'socket.io-client';

// Available LLMs
const allLLMs = [
@@ -66,13 +65,28 @@ const PromptNode = ({ data, id }) => {
// Selecting LLM models to prompt
const [llmItems, setLLMItems] = useState(initLLMs.map((i, idx) => ({key: uuid(), ...i})));
const [llmItemsCurrState, setLLMItemsCurrState] = useState([]);
const resetLLMItemsProgress = useCallback(() => {
setLLMItems(llmItemsCurrState.map(item => {
item.progress = undefined;
return item;
}));
}, [llmItemsCurrState]);

// Progress when querying responses
const [progress, setProgress] = useState(100);

const triggerAlert = (msg) => {
setProgress(100);
resetLLMItemsProgress();
alertModal.current.trigger(msg);
};

const addModel = useCallback((model) => {
// Get the item for that model
let item = allLLMs.find(llm => llm.model === model);

if (!item) { // This should never trigger, but in case it does:
alertModal.current.trigger(`Could not find model named '${model}' in list of available LLMs.`);
triggerAlert(`Could not find model named '${model}' in list of available LLMs.`);
return;
}

@@ -123,63 +137,144 @@ const PromptNode = ({ data, id }) => {
return edges.some(e => (e.target == id && e.targetHandle == varname));
});

// console.log(templateHooks);
if (!is_fully_connected) {
console.log('Not connected! :(');
triggerAlert('Missing inputs to one or more template variables.');
return;
}

if (is_fully_connected) {
console.log('Connected!');
console.log('Connected!');

// Check that there is at least one LLM selected:
if (llmItemsCurrState.length === 0) {
alert('Please select at least one LLM to prompt.')
return;
}
// Check that there is at least one LLM selected:
if (llmItemsCurrState.length === 0) {
alert('Please select at least one LLM to prompt.')
return;
}

// Set status indicator
setStatus('loading');
// Set status indicator
setStatus('loading');
setReponsePreviews([]);

// Pull data from each source, recursively:
const pulled_data = {};
const get_outputs = (varnames, nodeId) => {
console.log(varnames);
varnames.forEach(varname => {
// Find the relevant edge(s):
edges.forEach(e => {
if (e.target == nodeId && e.targetHandle == varname) {
// Get the immediate output:
let out = output(e.source, e.sourceHandle);
// Pull data from each source, recursively:
const pulled_data = {};
const get_outputs = (varnames, nodeId) => {
varnames.forEach(varname => {
// Find the relevant edge(s):
edges.forEach(e => {
if (e.target == nodeId && e.targetHandle == varname) {
// Get the immediate output:
let out = output(e.source, e.sourceHandle);

// Save the var data from the pulled output
if (varname in pulled_data)
pulled_data[varname] = pulled_data[varname].concat(out);
else
pulled_data[varname] = out;
// Save the var data from the pulled output
if (varname in pulled_data)
pulled_data[varname] = pulled_data[varname].concat(out);
else
pulled_data[varname] = out;

// Get any vars that the output depends on, and recursively collect those outputs as well:
const n_vars = getNode(e.source).data.vars;
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
get_outputs(n_vars, e.source);
}
});
// Get any vars that the output depends on, and recursively collect those outputs as well:
const n_vars = getNode(e.source).data.vars;
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
get_outputs(n_vars, e.source);
}
});
};
get_outputs(templateVars, id);
});
};
get_outputs(templateVars, id);

// Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
const to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
const py_prompt_template = to_py_template_format(promptText);
// Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
const to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
const py_prompt_template = to_py_template_format(promptText);

// Do the same for the vars, since vars can themselves be prompt templates:
Object.keys(pulled_data).forEach(varname => {
pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
// Do the same for the vars, since vars can themselves be prompt templates:
Object.keys(pulled_data).forEach(varname => {
pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
});
console.log(pulled_data);

let FINISHED_QUERY = false;
const rejected = (err) => {
setStatus('error');
triggerAlert(err.message);
FINISHED_QUERY = true;
};

// Ask the backend to reset the scratchpad for counting queries:
const create_progress_scratchpad = () => {
return fetch(BASE_URL + 'api/createProgressFile', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
id: id,
})}, rejected);
};

// Query the backend to ask how many responses it needs to collect, given the input data:
const fetch_resp_count = () => {
return fetch(BASE_URL + 'api/countQueriesRequired', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
prompt: py_prompt_template,
vars: pulled_data,
llms: llmItemsCurrState.map(item => item.model),
})}, rejected).then(function(response) {
return response.json();
}, rejected).then(function(json) {
if (!json || !json.count) {
throw new Error('Request was sent and received by backend server, but there was no response.');
}
return json.count;
}, rejected);
};

// Open a socket to listen for progress
const open_progress_listener_socket = (max_responses) => {
// With the counts information we can create progress bars. Now we load a socket connection to
// the socketio server that will stream to us the current progress:
const socket = io('http://localhost:8001/' + 'queryllm', {
transports: ["websocket"],
cors: {origin: "http://localhost:3000/"},
});

const rejected = (err) => {
setStatus('error');
alertModal.current.trigger(err.message);
};
// On connect to the server, ask it to give us the current progress
// for task 'queryllm' with id 'id', and stop when it reads progress >= 'max'.
socket.on("connect", (msg) => {
socket.emit("queryllm", {'id': id, 'max': max_responses});
});
socket.on("disconnect", (msg) => {
console.log(msg);
});

// Run all prompt permutations through the LLM to generate + cache responses:
fetch(BASE_URL + 'queryllm', {
// The current progress, a number specifying how many responses collected so far:
socket.on("response", (counts) => {
console.log(counts);
if (!counts || FINISHED_QUERY) return;

// Update individual progress bars
const num_llms = llmItemsCurrState.length;
setLLMItems(llmItemsCurrState.map(item => {
if (item.model in counts)
item.progress = counts[item.model] / (max_responses / num_llms) * 100;
return item;
}));

// Update total progress bar
const total_num_resps = Object.keys(counts).reduce((acc, llm_name) => {
return acc + counts[llm_name];
}, 0);
setProgress(total_num_resps / max_responses * 100);
});

// The process has finished; close the connection:
socket.on("finish", (msg) => {
console.log("finished:", msg);
socket.disconnect();
});
};

// Run all prompt permutations through the LLM to generate + cache responses:
const query_llms = () => {
return fetch(BASE_URL + 'api/queryllm', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@@ -191,7 +286,7 @@ const PromptNode = ({ data, id }) => {
temperature: 0.5,
n: numGenerations,
},
no_cache: false,
no_cache: true,
}),
}, rejected).then(function(response) {
return response.json();
@@ -205,6 +300,11 @@ const PromptNode = ({ data, id }) => {
// Success! Change status to 'ready':
setStatus('ready');

// Remove progress bars
setProgress(100);
resetLLMItemsProgress();
FINISHED_QUERY = true;

// Save prompt text so we remember what prompt we have responses cache'd for:
setPromptTextOnLastRun(promptText);

@@ -246,14 +346,15 @@ const PromptNode = ({ data, id }) => {
alertModal.current.trigger(json.error || 'Unknown error when querying LLM');
}
}, rejected);
};

console.log(pulled_data);
} else {
console.log('Not connected! :(');
alertModal.current.trigger('Missing inputs to one or more template variables.')

// TODO: Blink the names of unconnected params
}
// Now put it all together!
create_progress_scratchpad()
.then(fetch_resp_count)
.then(open_progress_listener_socket)
.then(query_llms)
.catch(rejected);

}

const handleNumGenChange = (event) => {
@@ -331,6 +432,7 @@ const PromptNode = ({ data, id }) => {
<label htmlFor="alpaca.7B">Alpaca 7B</label> */}
</div>
</div>
{progress < 100 ? (<Progress value={progress} animate />) : <></>}
<div className="response-preview-container nowheel">
{responsePreviews}
</div>
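For reference, the per-model wheel percentage in the socket "response" handler above divides each model's collected count by that model's share of the total expected responses. A small worked example of that arithmetic, with hypothetical counts (not part of the commit):

# Worked example (hypothetical values) of the progress math in PromptNode's "response" handler.
max_responses = 8                          # total count returned by /api/countQueriesRequired
num_llms = 2                               # models currently selected in the node
counts = {'gpt3.5': 3, 'alpaca.7B': 1}     # counts streamed over the socket so far

per_model = {m: c / (max_responses / num_llms) * 100 for m, c in counts.items()}
total = sum(counts.values()) / max_responses * 100
print(per_model)   # {'gpt3.5': 75.0, 'alpaca.7B': 25.0}
print(total)       # 50.0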
@@ -47,7 +47,7 @@ const VisNode = ({ data, id }) => {
// Grab the input node ids
const input_node_ids = [data.input];

fetch(BASE_URL + 'grabResponses', {
fetch(BASE_URL + 'api/grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@@ -3,12 +3,23 @@ from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_socketio import SocketIO
from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists

app = Flask(__name__)
CORS(app)

# Set up CORS for specific routes
cors = CORS(app, resources={r"/api/*": {"origins": "*"}})

# Initialize Socket.IO
# socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")

# import threading
# thread = None
# thread_lock = threading.Lock()


LLM_NAME_MAP = {
'gpt3.5': LLM.ChatGPT,
@@ -138,11 +149,119 @@ def reduce_responses(responses: list, vars: list) -> list:

return ret

@app.route('/test', methods=['GET'])
@app.route('/api/test', methods=['GET'])
def test():
return "Hello, world!"

@app.route('/queryllm', methods=['POST'])
# @socketio.on('queryllm', namespace='/queryllm')
# def handleQueryAsync(data):
# print("reached handleQueryAsync")
# socketio.start_background_task(queryLLM, emitLLMResponse)

# def emitLLMResponse(result):
# socketio.emit('response', result)

"""
Testing sockets. The following function can
communicate to React via with the JS code:

const socket = io(BASE_URL + 'queryllm', {
transports: ["websocket"],
cors: {
origin: "http://localhost:3000/",
},
});

socket.on("connect", (data) => {
socket.emit("queryllm", "hello");
});
socket.on("disconnect", (data) => {
console.log("disconnected");
});
socket.on("response", (data) => {
console.log(data);
});
"""
# def background_thread():
# n = 10
# while n > 0:
# socketio.sleep(0.5)
# socketio.emit('response', n, namespace='/queryllm')
# n -= 1

# @socketio.on('queryllm', namespace='/queryllm')
# def testSocket(data):
# print(data)
# global thread
# with thread_lock:
# if thread is None:
# thread = socketio.start_background_task(target=background_thread)

# @socketio.on('queryllm', namespace='/queryllm')
# def handleQuery(data):
# print(data)

# def handleConnect():
# print('here')
# socketio.emit('response', 'goodbye', namespace='/')

@app.route('/api/countQueriesRequired', methods=['POST'])
def countQueries():
"""
Returns how many queries we need to make, given the passed prompt and vars.

POST'd data should be in the form:
{
'prompt': str # the prompt template, with any {{}} vars
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
'llms': list # the list of LLMs you will query
}
"""
data = request.get_json()
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms'}):
return jsonify({'error': 'POST data is improper format.'})

try:
gen_prompts = PromptPermutationGenerator(PromptTemplate(data['prompt']))
all_prompt_permutations = list(gen_prompts(data['vars']))
except Exception as e:
return jsonify({'error': str(e)})

# TODO: Send more informative data back including how many queries per LLM based on cache'd data
num_queries = len(all_prompt_permutations) * len(data['llms'])

ret = jsonify({'count': num_queries})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret

@app.route('/api/createProgressFile', methods=['POST'])
def createProgressFile():
"""
Creates a temp txt file for storing progress of async LLM queries.

POST'd data should be in the form:
{
'id': str # a unique ID that will be used when calling 'queryllm'
}
"""
data = request.get_json()

if 'id' not in data or not isinstance(data['id'], str) or len(data['id']) == 0:
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})

# Create a scratch file for keeping track of how many responses loaded
try:
with open(f"cache/_temp_{data['id']}.txt", 'w') as f:
json.dump({}, f)
ret = jsonify({'success': True})
except Exception as e:
ret = jsonify({'success': False, 'error': str(e)})

ret.headers.add('Access-Control-Allow-Origin', '*')
return ret

# @socketio.on('connect', namespace='/queryllm')
@app.route('/api/queryllm', methods=['POST'])
async def queryLLM():
"""
Queries LLM(s) given a JSON spec.
@@ -170,11 +289,10 @@ async def queryLLM():
return jsonify({'error': 'POST data llm is improper format (not string or list, or of length 0).'})
if isinstance(data['llm'], str):
data['llm'] = [ data['llm'] ]
llms = []

for llm_str in data['llm']:
if llm_str not in LLM_NAME_MAP:
return jsonify({'error': f"LLM named '{llm_str}' is not supported."})
llms.append(LLM_NAME_MAP[llm_str])

if 'no_cache' in data and data['no_cache'] is True:
remove_cached_responses(data['id'])
@@ -184,9 +302,13 @@ async def queryLLM():

# For each LLM, generate and cache responses:
responses = {}
llms = data['llm']
params = data['params'] if 'params' in data else {}
tempfilepath = f"cache/_temp_{data['id']}.txt"

async def query(llm_str: str) -> list:
llm = LLM_NAME_MAP[llm_str]

async def query(llm: str) -> list:
# Check that storage path is valid:
cache_filepath = os.path.join('cache', f"{data['id']}-{str(llm.name)}.json")
if not is_valid_filepath(cache_filepath):
@@ -202,6 +324,17 @@ async def queryLLM():
print(f'Querying {llm}...')
async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
resps.append(response)
print(f"collected response from {llm.name}:", str(response))

# Save the number of responses collected to a temp file on disk
with open(tempfilepath, 'r') as f:
txt = f.read().strip()

cur_data = json.loads(txt) if len(txt) > 0 else {}
cur_data[llm_str] = len(resps)

with open(tempfilepath, 'w') as f:
json.dump(cur_data, f)
except Exception as e:
print('error generating responses:', e)
raise e
@@ -233,7 +366,7 @@ async def queryLLM():
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret

@app.route('/execute', methods=['POST'])
@app.route('/api/execute', methods=['POST'])
def execute():
"""
Executes a Python lambda function sent from JavaScript,
@@ -348,7 +481,7 @@ def execute():
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret

@app.route('/checkEvalFunc', methods=['POST'])
@app.route('/api/checkEvalFunc', methods=['POST'])
def checkEvalFunc():
"""
Tries to compile a Python lambda function sent from JavaScript.
@@ -378,7 +511,7 @@ def checkEvalFunc():
except Exception as e:
return jsonify({'result': False, 'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})

@app.route('/grabResponses', methods=['POST'])
@app.route('/api/grabResponses', methods=['POST'])
def grabResponses():
"""
Returns all responses with the specified id(s)
@@ -430,10 +563,14 @@ if __name__ == '__main__':
Produces each dummy response at random intervals between 0.1 and 3 seconds.""",
dest='dummy_responses',
action='store_true')
parser.add_argument('--port', help='The port to run the server on. Defaults to 8000.', type=int, default=8000, nargs='?')
args = parser.parse_args()

if args.dummy_responses:
PromptLLM = PromptLLMDummy
extract_responses = lambda r, llm: r['response']

port = args.port if args.port else 8000
print(f"Serving Flask server on port {port}...")
# socketio.run(app, host="localhost", port=port)
app.run(host="localhost", port=8000, debug=True)
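A quick way to sanity-check the new /api/countQueriesRequired endpoint from a script. This is a sketch, not part of the commit; it assumes the Flask backend above is running on http://localhost:8000 and that the third-party requests package is installed. The prompt is written in the Pythonic ${} form that PromptNode.js produces before posting.

# Sketch (not part of the commit): exercise /api/countQueriesRequired.
import requests

payload = {
    'prompt': 'Tell me a joke about ${topic}.',   # hypothetical template with one var
    'vars': {'topic': ['cats', 'compilers']},     # two permutations
    'llms': ['gpt3.5'],                           # one LLM, so expect count == 2
}
resp = requests.post('http://localhost:8000/api/countQueriesRequired', json=payload)
print(resp.json())   # e.g. {'count': 2}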
@@ -79,8 +79,10 @@ class PromptPipeline:
tasks.append(self._prompt_llm(llm, prompt, n, temperature))
else:
# Blocking. Await + yield a single LLM call.
print('reached')
_, query, response = await self._prompt_llm(llm, prompt, n, temperature)
info = prompt.fill_history
print('back')

# Save the response to a JSON file
responses[str(prompt)] = {
@@ -103,8 +105,11 @@ class PromptPipeline:
# Yield responses as they come in
for task in asyncio.as_completed(tasks):
# Collect the response from the earliest completed task
print(f'awaiting a response from {llm.name}...')
prompt, query, response = await task

print('Completed!')

# Each prompt has a history of what was filled in from its base template.
# This data --like, "class", "language", "library" etc --can be useful when parsing responses.
info = prompt.fill_history
@@ -149,8 +154,10 @@ class PromptPipeline:

async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Dict]:
if llm is LLM.ChatGPT or llm is LLM.GPT4:
print('calling chatgpt and awaiting')
query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature)
elif llm is LLM.Alpaca7B:
print('calling dalai alpaca.7b and awaiting')
query, response = await call_dalai(llm_name='alpaca.7B', port=4000, prompt=str(prompt), n=n, temperature=temperature)
else:
raise Exception(f"Language model {llm} is not supported.")
python-backend/requirements.txt

@@ -1,5 +1,7 @@
dalaipy==2.0.2
flask[async]
flask_cors
flask_socketio
openai
python-socketio
python-socketio
dalaipy==2.0.2
gevent-websocket
python-backend/test_sockets.py (new file; 67 lines)

@@ -0,0 +1,67 @@
import json, os, asyncio, sys, argparse
from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_socketio import SocketIO
from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists

app = Flask(__name__)

# Set up CORS for specific routes
# cors = CORS(app, resources={r"/api/*": {"origins": "*"}})

# Initialize Socket.IO
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")

# Wait a max of a full minute (60 seconds) for the response count to update, before exiting.
MAX_WAIT_TIME = 60

# import threading
# thread = None
# thread_lock = threading.Lock()

def countdown():
    n = 10
    while n > 0:
        socketio.sleep(0.5)
        socketio.emit('response', n, namespace='/queryllm')
        n -= 1

def readCounts(id, max_count):
    i = 0
    n = 0
    last_n = 0
    while i < MAX_WAIT_TIME and n < max_count:
        with open(f'cache/_temp_{id}.txt', 'r') as f:
            queries = json.load(f)
        n = sum([int(n) for llm, n in queries.items()])
        print(n)
        socketio.emit('response', queries, namespace='/queryllm')
        socketio.sleep(0.1)
        if last_n != n:
            i = 0
            last_n = n
        else:
            i += 0.1

    if i >= MAX_WAIT_TIME:
        print(f"Error: Waited maximum {MAX_WAIT_TIME} seconds for response count to update. Exited prematurely.")
        socketio.emit('finish', 'max_wait_reached', namespace='/queryllm')
    else:
        print("All responses loaded!")
        socketio.emit('finish', 'success', namespace='/queryllm')

@socketio.on('queryllm', namespace='/queryllm')
def testSocket(data):
    readCounts(data['id'], data['max'])
    # countdown()
    # global thread
    # with thread_lock:
    #     if thread is None:
    #         thread = socketio.start_background_task(target=countdown)

if __name__ == "__main__":
    socketio.run(app, host="localhost", port=8001)
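A minimal Python client for the socket server above, mirroring what PromptNode.js does in the browser. This is a sketch, not part of the commit; it assumes python-socketio is installed with its client extras, that test_sockets.py is running on port 8001, and that the main backend is writing counts to cache/_temp_<id>.txt for the given id. The id and max values below are placeholders.

# Sketch (not part of the commit): consume the progress stream from test_sockets.py.
import socketio

sio = socketio.Client()

@sio.on('response', namespace='/queryllm')
def on_response(counts):
    # counts is the dict read from the temp file, e.g. {'gpt3.5': 3}
    print('progress:', counts)

@sio.on('finish', namespace='/queryllm')
def on_finish(msg):
    print('finished:', msg)   # 'success' or 'max_wait_reached'
    sio.disconnect()

sio.connect('http://localhost:8001', namespaces=['/queryllm'], transports=['websocket'])
# Ask the server to stream counts for a (placeholder) node id until 8 responses arrive.
sio.emit('queryllm', {'id': 'prompt-node-1', 'max': 8}, namespace='/queryllm')
sio.wait()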