WIP live progress wheels

Ian Arawjo 2023-05-04 18:54:31 -04:00
parent 75718766ac
commit ccf6880da4
13 changed files with 477 additions and 77 deletions

View File

@@ -38,6 +38,7 @@
"react-flow-renderer": "^10.3.17",
"react-plotly.js": "^2.6.0",
"react-scripts": "5.0.1",
"socket.io-client": "^4.6.1",
"styled-components": "^5.3.10",
"uuidv4": "^6.2.13",
"web-vitals": "^2.1.4",
@@ -3913,6 +3914,11 @@
"@sinonjs/commons": "^1.7.0"
}
},
"node_modules/@socket.io/component-emitter": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz",
"integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg=="
},
"node_modules/@surma/rollup-plugin-off-main-thread": {
"version": "2.2.3",
"resolved": "https://registry.npmjs.org/@surma/rollup-plugin-off-main-thread/-/rollup-plugin-off-main-thread-2.2.3.tgz",
@@ -8718,6 +8724,46 @@
"once": "^1.4.0"
}
},
"node_modules/engine.io-client": {
"version": "6.4.0",
"resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.4.0.tgz",
"integrity": "sha512-GyKPDyoEha+XZ7iEqam49vz6auPnNJ9ZBfy89f+rMMas8AuiMWOZ9PVzu8xb9ZC6rafUqiGHSCfu22ih66E+1g==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1",
"engine.io-parser": "~5.0.3",
"ws": "~8.11.0",
"xmlhttprequest-ssl": "~2.0.0"
}
},
"node_modules/engine.io-client/node_modules/ws": {
"version": "8.11.0",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz",
"integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==",
"engines": {
"node": ">=10.0.0"
},
"peerDependencies": {
"bufferutil": "^4.0.1",
"utf-8-validate": "^5.0.2"
},
"peerDependenciesMeta": {
"bufferutil": {
"optional": true
},
"utf-8-validate": {
"optional": true
}
}
},
"node_modules/engine.io-parser": {
"version": "5.0.6",
"resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz",
"integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==",
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/enhanced-resolve": {
"version": "5.12.0",
"resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.12.0.tgz",
@@ -18338,6 +18384,32 @@
"node": ">=8"
}
},
"node_modules/socket.io-client": {
"version": "4.6.1",
"resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.6.1.tgz",
"integrity": "sha512-5UswCV6hpaRsNg5kkEHVcbBIXEYoVbMQaHJBXJCyEQ+CiFPV1NIOY0XOFWG4XR4GZcB8Kn6AsRs/9cy9TbqVMQ==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.2",
"engine.io-client": "~6.4.0",
"socket.io-parser": "~4.2.1"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/socket.io-parser": {
"version": "4.2.2",
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.2.tgz",
"integrity": "sha512-DJtziuKypFkMMHCm2uIshOYC7QaylbtzQwiMYDuCKy3OPkjLzu4B2vAhTlqipRHHzrI0NJeBAizTK7X+6m1jVw==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/sockjs": {
"version": "0.3.24",
"resolved": "https://registry.npmjs.org/sockjs/-/sockjs-0.3.24.tgz",
@@ -21030,6 +21102,14 @@
"resolved": "https://registry.npmjs.org/xmlchars/-/xmlchars-2.2.0.tgz",
"integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw=="
},
"node_modules/xmlhttprequest-ssl": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz",
"integrity": "sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A==",
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/xtend": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz",

View File

@@ -33,6 +33,7 @@
"react-flow-renderer": "^10.3.17",
"react-plotly.js": "^2.6.0",
"react-scripts": "5.0.1",
"socket.io-client": "^4.6.1",
"styled-components": "^5.3.10",
"uuidv4": "^6.2.13",
"web-vitals": "^2.1.4",

View File

@@ -74,7 +74,7 @@ const EvaluatorNode = ({ data, id }) => {
console.log(script_paths);
// Run evaluator in backend
const codeTextOnRun = codeText + '';
fetch(BASE_URL + 'execute', {
fetch(BASE_URL + 'api/execute', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({

View File

@@ -32,7 +32,7 @@ const InspectorNode = ({ data, id }) => {
console.log(input_node_ids);
// Grab responses associated with those ids:
fetch(BASE_URL + 'grabResponses', {
fetch(BASE_URL + 'api/grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({

View File

@@ -1,8 +1,8 @@
import { useDisclosure } from '@mantine/hooks';
import { Modal, Button, Group } from '@mantine/core';
import { Modal, Button, Group, RingProgress } from '@mantine/core';
import { IconSettings, IconTrash } from '@tabler/icons-react';
export default function LLMItemButtonGroup( {onClickTrash, onClickSettings} ) {
export default function LLMItemButtonGroup( {onClickTrash, onClickSettings, ringProgress} ) {
const [opened, { open, close }] = useDisclosure(false);
return (
@@ -12,7 +12,11 @@ export default function LLMItemButtonGroup( {onClickTrash, onClickSettings} ) {
</Modal>
<Group position="right" style={{float: 'right', height:'20px'}}>
<Button onClick={onClickTrash} size="xs" variant="light" compact color="red" style={{padding: '0px'}} ><IconTrash size={"95%"} /></Button>
{(ringProgress) ?
(<RingProgress size={20} thickness={3} sections={[{ value: ringProgress, color: ringProgress < 99 ? 'blue' : 'green' }]} width='16px' />) :
<></>
}
<Button onClick={onClickTrash} size="xs" variant="light" compact color="red" style={{padding: '0px'}} ><IconTrash size={"95%"} /></Button>
<Button onClick={onClickSettings} size="xs" variant="light" compact>Settings&nbsp;<IconSettings size={"110%"} /></Button>
</Group>
</div>

View File

@@ -63,7 +63,7 @@ export default function LLMList({llms, onItemsChange}) {
{items.map((item, index) => (
<Draggable key={item.key} draggableId={item.key} index={index}>
{(provided, snapshot) => (
<LLMListItem provided={provided} snapshot={snapshot} item={item} removeCallback={removeItem} />
<LLMListItem provided={provided} snapshot={snapshot} item={item} removeCallback={removeItem} progress={item.progress} />
)}
</Draggable>
))}

View File

@@ -23,7 +23,7 @@ export const DragItem = styled.div`
flex-direction: column;
`;
const LLMListItem = ({ item, provided, snapshot, removeCallback }) => {
const LLMListItem = ({ item, provided, snapshot, removeCallback, progress }) => {
return (
<DragItem
ref={provided.innerRef}
@@ -33,7 +33,7 @@ const LLMListItem = ({ item, provided, snapshot, removeCallback }) => {
>
<div>
<CardHeader>{item.emoji}&nbsp;{item.name}</CardHeader>
<LLMItemButtonGroup onClickTrash={() => removeCallback(item.key)} />
<LLMItemButtonGroup onClickTrash={() => removeCallback(item.key)} ringProgress={progress} />
</div>
</DragItem>

View File

@@ -1,14 +1,13 @@
import React, { useEffect, useState, useRef, useCallback } from 'react';
import { Handle } from 'react-flow-renderer';
import { Menu, Badge } from '@mantine/core';
import { Menu, Badge, Progress } from '@mantine/core';
import { v4 as uuid } from 'uuid';
import useStore from './store';
import StatusIndicator from './StatusIndicatorComponent'
import NodeLabel from './NodeLabelComponent'
import TemplateHooks from './TemplateHooksComponent'
import LLMList from './LLMListComponent'
import AlertModal from './AlertModal'
import {BASE_URL} from './store';
import io from 'socket.io-client';
// Available LLMs
const allLLMs = [
@@ -66,13 +65,28 @@ const PromptNode = ({ data, id }) => {
// Selecting LLM models to prompt
const [llmItems, setLLMItems] = useState(initLLMs.map((i, idx) => ({key: uuid(), ...i})));
const [llmItemsCurrState, setLLMItemsCurrState] = useState([]);
const resetLLMItemsProgress = useCallback(() => {
setLLMItems(llmItemsCurrState.map(item => {
item.progress = undefined;
return item;
}));
}, [llmItemsCurrState]);
// Progress when querying responses
const [progress, setProgress] = useState(100);
const triggerAlert = (msg) => {
setProgress(100);
resetLLMItemsProgress();
alertModal.current.trigger(msg);
};
const addModel = useCallback((model) => {
// Get the item for that model
let item = allLLMs.find(llm => llm.model === model);
if (!item) { // This should never trigger, but in case it does:
alertModal.current.trigger(`Could not find model named '${model}' in list of available LLMs.`);
triggerAlert(`Could not find model named '${model}' in list of available LLMs.`);
return;
}
@@ -123,63 +137,144 @@ const PromptNode = ({ data, id }) => {
return edges.some(e => (e.target == id && e.targetHandle == varname));
});
// console.log(templateHooks);
if (!is_fully_connected) {
console.log('Not connected! :(');
triggerAlert('Missing inputs to one or more template variables.');
return;
}
if (is_fully_connected) {
console.log('Connected!');
console.log('Connected!');
// Check that there is at least one LLM selected:
if (llmItemsCurrState.length === 0) {
alert('Please select at least one LLM to prompt.')
return;
}
// Check that there is at least one LLM selected:
if (llmItemsCurrState.length === 0) {
alert('Please select at least one LLM to prompt.')
return;
}
// Set status indicator
setStatus('loading');
// Set status indicator
setStatus('loading');
setReponsePreviews([]);
// Pull data from each source, recursively:
const pulled_data = {};
const get_outputs = (varnames, nodeId) => {
console.log(varnames);
varnames.forEach(varname => {
// Find the relevant edge(s):
edges.forEach(e => {
if (e.target == nodeId && e.targetHandle == varname) {
// Get the immediate output:
let out = output(e.source, e.sourceHandle);
// Pull data from each source, recursively:
const pulled_data = {};
const get_outputs = (varnames, nodeId) => {
varnames.forEach(varname => {
// Find the relevant edge(s):
edges.forEach(e => {
if (e.target == nodeId && e.targetHandle == varname) {
// Get the immediate output:
let out = output(e.source, e.sourceHandle);
// Save the var data from the pulled output
if (varname in pulled_data)
pulled_data[varname] = pulled_data[varname].concat(out);
else
pulled_data[varname] = out;
// Save the var data from the pulled output
if (varname in pulled_data)
pulled_data[varname] = pulled_data[varname].concat(out);
else
pulled_data[varname] = out;
// Get any vars that the output depends on, and recursively collect those outputs as well:
const n_vars = getNode(e.source).data.vars;
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
get_outputs(n_vars, e.source);
}
});
// Get any vars that the output depends on, and recursively collect those outputs as well:
const n_vars = getNode(e.source).data.vars;
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
get_outputs(n_vars, e.source);
}
});
};
get_outputs(templateVars, id);
});
};
get_outputs(templateVars, id);
// Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
const to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
const py_prompt_template = to_py_template_format(promptText);
// Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
const to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
const py_prompt_template = to_py_template_format(promptText);
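// e.g. to_py_template_format("Tell me a {joke_type} joke") returns "Tell me a ${joke_type} joke"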
// Do the same for the vars, since vars can themselves be prompt templates:
Object.keys(pulled_data).forEach(varname => {
pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
// Do the same for the vars, since vars can themselves be prompt templates:
Object.keys(pulled_data).forEach(varname => {
pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
});
console.log(pulled_data);
let FINISHED_QUERY = false;
const rejected = (err) => {
setStatus('error');
triggerAlert(err.message);
FINISHED_QUERY = true;
};
// Ask the backend to reset the scratchpad for counting queries:
const create_progress_scratchpad = () => {
return fetch(BASE_URL + 'api/createProgressFile', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
id: id,
})}, rejected);
};
// Query the backend to ask how many responses it needs to collect, given the input data:
const fetch_resp_count = () => {
return fetch(BASE_URL + 'api/countQueriesRequired', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
prompt: py_prompt_template,
vars: pulled_data,
llms: llmItemsCurrState.map(item => item.model),
})}, rejected).then(function(response) {
return response.json();
}, rejected).then(function(json) {
if (!json || !json.count) {
throw new Error('Request was sent and received by the backend server, but there was no response.');
}
return json.count;
}, rejected);
};
// Open a socket to listen for progress
const open_progress_listener_socket = (max_responses) => {
// With the counts information we can create progress bars. Now we load a socket connection to
// the socketio server that will stream to us the current progress:
const socket = io('http://localhost:8001/' + 'queryllm', {
transports: ["websocket"],
cors: {origin: "http://localhost:3000/"},
});
const rejected = (err) => {
setStatus('error');
alertModal.current.trigger(err.message);
};
// On connect to the server, ask it to give us the current progress
// for task 'queryllm' with id 'id', and stop when it reads progress >= 'max'.
socket.on("connect", (msg) => {
socket.emit("queryllm", {'id': id, 'max': max_responses});
});
socket.on("disconnect", (msg) => {
console.log(msg);
});
// Run all prompt permutations through the LLM to generate + cache responses:
fetch(BASE_URL + 'queryllm', {
// The current progress, a number specifying how many responses collected so far:
socket.on("response", (counts) => {
console.log(counts);
if (!counts || FINISHED_QUERY) return;
// Update individual progress bars
const num_llms = llmItemsCurrState.length;
setLLMItems(llmItemsCurrState.map(item => {
if (item.model in counts)
item.progress = counts[item.model] / (max_responses / num_llms) * 100;
return item;
}));
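// (each LLM is expected to contribute max_responses / num_llms responses, so a model's ring reaches 100% once its count hits that share)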
// Update total progress bar
const total_num_resps = Object.keys(counts).reduce((acc, llm_name) => {
return acc + counts[llm_name];
}, 0);
setProgress(total_num_resps / max_responses * 100);
});
// The process has finished; close the connection:
socket.on("finish", (msg) => {
console.log("finished:", msg);
socket.disconnect();
});
};
// Run all prompt permutations through the LLM to generate + cache responses:
const query_llms = () => {
return fetch(BASE_URL + 'api/queryllm', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@@ -191,7 +286,7 @@ const PromptNode = ({ data, id }) => {
temperature: 0.5,
n: numGenerations,
},
no_cache: false,
no_cache: true,
}),
}, rejected).then(function(response) {
return response.json();
@@ -205,6 +300,11 @@ const PromptNode = ({ data, id }) => {
// Success! Change status to 'ready':
setStatus('ready');
// Remove progress bars
setProgress(100);
resetLLMItemsProgress();
FINISHED_QUERY = true;
// Save prompt text so we remember what prompt we have responses cache'd for:
setPromptTextOnLastRun(promptText);
@@ -246,14 +346,15 @@ const PromptNode = ({ data, id }) => {
alertModal.current.trigger(json.error || 'Unknown error when querying LLM');
}
}, rejected);
};
console.log(pulled_data);
} else {
console.log('Not connected! :(');
alertModal.current.trigger('Missing inputs to one or more template variables.')
// TODO: Blink the names of unconnected params
}
// Now put it all together!
create_progress_scratchpad()
.then(fetch_resp_count)
.then(open_progress_listener_socket)
.then(query_llms)
.catch(rejected);
}
const handleNumGenChange = (event) => {
@@ -331,6 +432,7 @@ const PromptNode = ({ data, id }) => {
<label htmlFor="alpaca.7B">Alpaca 7B</label> */}
</div>
</div>
{progress < 100 ? (<Progress value={progress} animate />) : <></>}
<div className="response-preview-container nowheel">
{responsePreviews}
</div>

View File

@@ -47,7 +47,7 @@ const VisNode = ({ data, id }) => {
// Grab the input node ids
const input_node_ids = [data.input];
fetch(BASE_URL + 'grabResponses', {
fetch(BASE_URL + 'api/grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({

View File

@@ -3,12 +3,23 @@ from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_socketio import SocketIO
from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
app = Flask(__name__)
CORS(app)
# Set up CORS for specific routes
cors = CORS(app, resources={r"/api/*": {"origins": "*"}})
# Initialize Socket.IO
# socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")
# import threading
# thread = None
# thread_lock = threading.Lock()
LLM_NAME_MAP = {
'gpt3.5': LLM.ChatGPT,
@@ -138,11 +149,119 @@ def reduce_responses(responses: list, vars: list) -> list:
return ret
@app.route('/test', methods=['GET'])
@app.route('/api/test', methods=['GET'])
def test():
return "Hello, world!"
@app.route('/queryllm', methods=['POST'])
# @socketio.on('queryllm', namespace='/queryllm')
# def handleQueryAsync(data):
# print("reached handleQueryAsync")
# socketio.start_background_task(queryLLM, emitLLMResponse)
# def emitLLMResponse(result):
# socketio.emit('response', result)
"""
Testing sockets. The following function can
communicate with React via the JS code below:
const socket = io(BASE_URL + 'queryllm', {
transports: ["websocket"],
cors: {
origin: "http://localhost:3000/",
},
});
socket.on("connect", (data) => {
socket.emit("queryllm", "hello");
});
socket.on("disconnect", (data) => {
console.log("disconnected");
});
socket.on("response", (data) => {
console.log(data);
});
"""
# def background_thread():
# n = 10
# while n > 0:
# socketio.sleep(0.5)
# socketio.emit('response', n, namespace='/queryllm')
# n -= 1
# @socketio.on('queryllm', namespace='/queryllm')
# def testSocket(data):
# print(data)
# global thread
# with thread_lock:
# if thread is None:
# thread = socketio.start_background_task(target=background_thread)
# @socketio.on('queryllm', namespace='/queryllm')
# def handleQuery(data):
# print(data)
# def handleConnect():
# print('here')
# socketio.emit('response', 'goodbye', namespace='/')
@app.route('/api/countQueriesRequired', methods=['POST'])
def countQueries():
"""
Returns how many queries we need to make, given the passed prompt and vars.
POST'd data should be in the form:
{
'prompt': str # the prompt template, with any {{}} vars
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
'llms': list # the list of LLMs you will query
}
"""
data = request.get_json()
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms'}):
return jsonify({'error': 'POST data is improper format.'})
try:
gen_prompts = PromptPermutationGenerator(PromptTemplate(data['prompt']))
all_prompt_permutations = list(gen_prompts(data['vars']))
except Exception as e:
return jsonify({'error': str(e)})
# TODO: Send more informative data back including how many queries per LLM based on cache'd data
num_queries = len(all_prompt_permutations) * len(data['llms'])
ret = jsonify({'count': num_queries})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
@app.route('/api/createProgressFile', methods=['POST'])
def createProgressFile():
"""
Creates a temp txt file for storing progress of async LLM queries.
POST'd data should be in the form:
{
'id': str # a unique ID that will be used when calling 'queryllm'
}
"""
data = request.get_json()
if 'id' not in data or not isinstance(data['id'], str) or len(data['id']) == 0:
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
# Create a scratch file for keeping track of how many responses loaded
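# The file starts out empty; queryLLM below fills it with per-LLM response counts, e.g. {"gpt3.5": 3, "alpaca.7B": 1}, which the socketio app polls to stream progress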
try:
with open(f"cache/_temp_{data['id']}.txt", 'w') as f:
json.dump({}, f)
ret = jsonify({'success': True})
except Exception as e:
ret = jsonify({'success': False, 'error': str(e)})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
# @socketio.on('connect', namespace='/queryllm')
@app.route('/api/queryllm', methods=['POST'])
async def queryLLM():
"""
Queries LLM(s) given a JSON spec.
@ -170,11 +289,10 @@ async def queryLLM():
return jsonify({'error': 'POST data llm is improper format (not string or list, or of length 0).'})
if isinstance(data['llm'], str):
data['llm'] = [ data['llm'] ]
llms = []
for llm_str in data['llm']:
if llm_str not in LLM_NAME_MAP:
return jsonify({'error': f"LLM named '{llm_str}' is not supported."})
llms.append(LLM_NAME_MAP[llm_str])
if 'no_cache' in data and data['no_cache'] is True:
remove_cached_responses(data['id'])
@@ -184,9 +302,13 @@ async def queryLLM():
# For each LLM, generate and cache responses:
responses = {}
llms = data['llm']
params = data['params'] if 'params' in data else {}
tempfilepath = f"cache/_temp_{data['id']}.txt"
async def query(llm_str: str) -> list:
llm = LLM_NAME_MAP[llm_str]
async def query(llm: str) -> list:
# Check that storage path is valid:
cache_filepath = os.path.join('cache', f"{data['id']}-{str(llm.name)}.json")
if not is_valid_filepath(cache_filepath):
@@ -202,6 +324,17 @@ async def queryLLM():
print(f'Querying {llm}...')
async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
resps.append(response)
print(f"collected response from {llm.name}:", str(response))
# Save the number of responses collected to a temp file on disk
with open(tempfilepath, 'r') as f:
txt = f.read().strip()
cur_data = json.loads(txt) if len(txt) > 0 else {}
cur_data[llm_str] = len(resps)
with open(tempfilepath, 'w') as f:
json.dump(cur_data, f)
except Exception as e:
print('error generating responses:', e)
raise e
@@ -233,7 +366,7 @@ async def queryLLM():
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
@app.route('/execute', methods=['POST'])
@app.route('/api/execute', methods=['POST'])
def execute():
"""
Executes a Python lambda function sent from JavaScript,
@@ -348,7 +481,7 @@ def execute():
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
@app.route('/checkEvalFunc', methods=['POST'])
@app.route('/api/checkEvalFunc', methods=['POST'])
def checkEvalFunc():
"""
Tries to compile a Python lambda function sent from JavaScript.
@@ -378,7 +511,7 @@ def checkEvalFunc():
except Exception as e:
return jsonify({'result': False, 'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})
@app.route('/grabResponses', methods=['POST'])
@app.route('/api/grabResponses', methods=['POST'])
def grabResponses():
"""
Returns all responses with the specified id(s)
@@ -430,10 +563,14 @@ if __name__ == '__main__':
Produces each dummy response at random intervals between 0.1 and 3 seconds.""",
dest='dummy_responses',
action='store_true')
parser.add_argument('--port', help='The port to run the server on. Defaults to 8000.', type=int, default=8000, nargs='?')
args = parser.parse_args()
if args.dummy_responses:
PromptLLM = PromptLLMDummy
extract_responses = lambda r, llm: r['response']
port = args.port if args.port else 8000
print(f"Serving Flask server on port {port}...")
# socketio.run(app, host="localhost", port=port)
app.run(host="localhost", port=port, debug=True)
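Taken together, the new routes give the frontend a three-step flow: reset the progress scratchpad, count the queries required, then fire off the LLM queries while progress streams over the separate socketio server on port 8001. A minimal sketch of that flow from Python, assuming the server above is running on localhost:8000 and the requests package is available; the node id, prompt, and vars below are illustrative:

import requests

API = 'http://localhost:8000/api/'
node_id = 'prompt-node-123'  # illustrative; the frontend passes the React Flow node id

# 1. Reset the progress scratchpad (cache/_temp_<id>.txt) for this node:
requests.post(API + 'createProgressFile', json={'id': node_id})

# 2. Ask how many responses will be collected in total:
count = requests.post(API + 'countQueriesRequired', json={
    'prompt': 'What is the opposite of ${word}?',  # Pythonic template format
    'vars': {'word': ['hot', 'tall']},
    'llms': ['gpt3.5'],
}).json()['count']  # 2 permutations x 1 LLM = 2

# 3. Kick off the (long-running) queries; per-LLM counts are written to the
#    scratchpad as responses arrive, and the socketio app relays them as progress.
requests.post(API + 'queryllm', json={
    'id': node_id,
    'prompt': 'What is the opposite of ${word}?',
    'vars': {'word': ['hot', 'tall']},
    'llm': ['gpt3.5'],
    'params': {'temperature': 0.5, 'n': 1},
    'no_cache': True,
})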

View File

@@ -79,8 +79,10 @@ class PromptPipeline:
tasks.append(self._prompt_llm(llm, prompt, n, temperature))
else:
# Blocking. Await + yield a single LLM call.
print('reached')
_, query, response = await self._prompt_llm(llm, prompt, n, temperature)
info = prompt.fill_history
print('back')
# Save the response to a JSON file
responses[str(prompt)] = {
@@ -103,8 +105,11 @@ class PromptPipeline:
# Yield responses as they come in
for task in asyncio.as_completed(tasks):
# Collect the response from the earliest completed task
print(f'awaiting a response from {llm.name}...')
prompt, query, response = await task
print('Completed!')
# Each prompt has a history of what was filled in from its base template.
# This data -- like "class", "language", "library", etc. -- can be useful when parsing responses.
info = prompt.fill_history
@@ -149,8 +154,10 @@ async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Dict]:
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Dict]:
if llm is LLM.ChatGPT or llm is LLM.GPT4:
print('calling chatgpt and awaiting')
query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature)
elif llm is LLM.Alpaca7B:
print('calling dalai alpaca.7b and awaiting')
query, response = await call_dalai(llm_name='alpaca.7B', port=4000, prompt=str(prompt), n=n, temperature=temperature)
else:
raise Exception(f"Language model {llm} is not supported.")

View File

@@ -1,5 +1,7 @@
dalaipy==2.0.2
flask[async]
flask_cors
flask_socketio
openai
python-socketio
python-socketio
dalaipy==2.0.2
gevent-websocket

View File

@@ -0,0 +1,67 @@
import json, os, asyncio, sys, argparse
from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_socketio import SocketIO
from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
app = Flask(__name__)
# Set up CORS for specific routes
# cors = CORS(app, resources={r"/api/*": {"origins": "*"}})
# Initialize Socket.IO
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")
# Wait a max of a full minute (60 seconds) for the response count to update, before exiting.
MAX_WAIT_TIME = 60
# import threading
# thread = None
# thread_lock = threading.Lock()
def countdown():
n = 10
while n > 0:
socketio.sleep(0.5)
socketio.emit('response', n, namespace='/queryllm')
n -= 1
def readCounts(id, max_count):
i = 0
n = 0
last_n = 0
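# Poll the scratch file every 0.1s; reset the stall timer (i) whenever the total count advances, and give up after MAX_WAIT_TIME seconds with no progress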
while i < MAX_WAIT_TIME and n < max_count:
with open(f'cache/_temp_{id}.txt', 'r') as f:
queries = json.load(f)
n = sum([int(n) for llm, n in queries.items()])
print(n)
socketio.emit('response', queries, namespace='/queryllm')
socketio.sleep(0.1)
if last_n != n:
i = 0
last_n = n
else:
i += 0.1
if i >= MAX_WAIT_TIME:
print(f"Error: Waited maximum {MAX_WAIT_TIME} seconds for response count to update. Exited prematurely.")
socketio.emit('finish', 'max_wait_reached', namespace='/queryllm')
else:
print("All responses loaded!")
socketio.emit('finish', 'success', namespace='/queryllm')
@socketio.on('queryllm', namespace='/queryllm')
def testSocket(data):
readCounts(data['id'], data['max'])
# countdown()
# global thread
# with thread_lock:
# if thread is None:
# thread = socketio.start_background_task(target=countdown)
if __name__ == "__main__":
socketio.run(app, host="localhost", port=8001)
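For reference, a rough Python counterpart to the socket.io-client code added to PromptNode.js — a minimal sketch assuming this server is running on port 8001 and the python-socketio client (listed in requirements.txt) is installed; the id and max values are illustrative:

import socketio

sio = socketio.Client()

@sio.on('response', namespace='/queryllm')
def on_response(counts):
    # counts is the per-LLM dict read from cache/_temp_<id>.txt, e.g. {'gpt3.5': 3}
    print('progress:', counts)

@sio.on('finish', namespace='/queryllm')
def on_finish(msg):
    # msg is either 'success' or 'max_wait_reached'
    print('finished:', msg)
    sio.disconnect()

sio.connect('http://localhost:8001', namespaces=['/queryllm'], transports=['websocket'])
# Ask the server to stream counts for this node until `max` responses have been collected:
sio.emit('queryllm', {'id': 'prompt-node-123', 'max': 4}, namespace='/queryllm')
sio.wait()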