Stream responses and serve React, Flask, and SocketIO from a single Python script ()

* Live progress wheels

* Dynamic run tooltip for prompt node

* Run React, Flask, and SocketIO with a single script (see the sketch below)
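The last bullet is the main structural change: `python-backend/app.py` now spins up the Flask-SocketIO progress server on `port + 1` in a background thread, while the plain Flask app (split out into `flask_app.py`) serves the REST API and the static React build on `port`. The following is a minimal sketch of that launch pattern, not the exact ChainForge code; it assumes `flask`, `flask_socketio`, and `gevent`/`gevent-websocket` are installed, and the handlers are placeholders:

```python
# Simplified sketch of the two-server launch in python-backend/app.py.
# Assumes flask, flask_socketio and gevent/gevent-websocket are installed;
# the handlers below are placeholders, not the actual ChainForge logic.
import threading
from flask import Flask
from flask_socketio import SocketIO

api_app = Flask(__name__)    # REST API + static React build (see flask_app.py)
sock_app = Flask(__name__)   # progress-streaming app (see app.py)
socketio = SocketIO(sock_app, cors_allowed_origins="*", async_mode="gevent")

@api_app.route("/")
def index():
    # flask_app.py serves the built React front-end here via render_template.
    return "ChainForge API placeholder"

@socketio.on("queryllm", namespace="/queryllm")
def stream_progress(data):
    # Placeholder: the real handler polls cache/_temp_<id>.txt and emits
    # 'response' / 'finish' events (see the app.py diff below).
    socketio.emit("finish", "success", namespace="/queryllm")

def run_socketio_server(port):
    socketio.run(sock_app, host="localhost", port=port)

if __name__ == "__main__":
    port = 8000
    # SocketIO server on port+1 in a background thread...
    threading.Thread(target=run_socketio_server, args=[port + 1]).start()
    # ...Flask REST API (and React build) on `port` in the main thread.
    api_app.run(host="localhost", port=port)
```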
ianarawjo 2023-05-06 13:02:47 -04:00 committed by GitHub
parent 34e0e465c1
commit f6d7996f97
18 changed files with 928 additions and 539 deletions

1 .gitignore vendored

@ -1,5 +1,6 @@
*.DS_Store
chain-forge/node_modules
chain-forge/build
__pycache__
python-backend/cache

@ -38,6 +38,7 @@
"react-flow-renderer": "^10.3.17",
"react-plotly.js": "^2.6.0",
"react-scripts": "5.0.1",
"socket.io-client": "^4.6.1",
"styled-components": "^5.3.10",
"uuidv4": "^6.2.13",
"web-vitals": "^2.1.4",
@ -3913,6 +3914,11 @@
"@sinonjs/commons": "^1.7.0"
}
},
"node_modules/@socket.io/component-emitter": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz",
"integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg=="
},
"node_modules/@surma/rollup-plugin-off-main-thread": {
"version": "2.2.3",
"resolved": "https://registry.npmjs.org/@surma/rollup-plugin-off-main-thread/-/rollup-plugin-off-main-thread-2.2.3.tgz",
@ -8718,6 +8724,46 @@
"once": "^1.4.0"
}
},
"node_modules/engine.io-client": {
"version": "6.4.0",
"resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.4.0.tgz",
"integrity": "sha512-GyKPDyoEha+XZ7iEqam49vz6auPnNJ9ZBfy89f+rMMas8AuiMWOZ9PVzu8xb9ZC6rafUqiGHSCfu22ih66E+1g==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1",
"engine.io-parser": "~5.0.3",
"ws": "~8.11.0",
"xmlhttprequest-ssl": "~2.0.0"
}
},
"node_modules/engine.io-client/node_modules/ws": {
"version": "8.11.0",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz",
"integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==",
"engines": {
"node": ">=10.0.0"
},
"peerDependencies": {
"bufferutil": "^4.0.1",
"utf-8-validate": "^5.0.2"
},
"peerDependenciesMeta": {
"bufferutil": {
"optional": true
},
"utf-8-validate": {
"optional": true
}
}
},
"node_modules/engine.io-parser": {
"version": "5.0.6",
"resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz",
"integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==",
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/enhanced-resolve": {
"version": "5.12.0",
"resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.12.0.tgz",
@ -18338,6 +18384,32 @@
"node": ">=8"
}
},
"node_modules/socket.io-client": {
"version": "4.6.1",
"resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.6.1.tgz",
"integrity": "sha512-5UswCV6hpaRsNg5kkEHVcbBIXEYoVbMQaHJBXJCyEQ+CiFPV1NIOY0XOFWG4XR4GZcB8Kn6AsRs/9cy9TbqVMQ==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.2",
"engine.io-client": "~6.4.0",
"socket.io-parser": "~4.2.1"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/socket.io-parser": {
"version": "4.2.2",
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.2.tgz",
"integrity": "sha512-DJtziuKypFkMMHCm2uIshOYC7QaylbtzQwiMYDuCKy3OPkjLzu4B2vAhTlqipRHHzrI0NJeBAizTK7X+6m1jVw==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/sockjs": {
"version": "0.3.24",
"resolved": "https://registry.npmjs.org/sockjs/-/sockjs-0.3.24.tgz",
@ -21030,6 +21102,14 @@
"resolved": "https://registry.npmjs.org/xmlchars/-/xmlchars-2.2.0.tgz",
"integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw=="
},
"node_modules/xmlhttprequest-ssl": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz",
"integrity": "sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A==",
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/xtend": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz",

@ -33,6 +33,7 @@
"react-flow-renderer": "^10.3.17",
"react-plotly.js": "^2.6.0",
"react-scripts": "5.0.1",
"socket.io-client": "^4.6.1",
"styled-components": "^5.3.10",
"uuidv4": "^6.2.13",
"web-vitals": "^2.1.4",

@ -74,7 +74,7 @@ const EvaluatorNode = ({ data, id }) => {
console.log(script_paths);
// Run evaluator in backend
const codeTextOnRun = codeText + '';
fetch(BASE_URL + 'execute', {
fetch(BASE_URL + 'app/execute', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@ -155,6 +155,7 @@ const EvaluatorNode = ({ data, id }) => {
status={status}
alertModal={alertModal}
handleRunClick={handleRunClick}
runButtonTooltip="Run evaluator over inputs"
/>
<Handle
type="target"

@ -32,7 +32,7 @@ const InspectorNode = ({ data, id }) => {
console.log(input_node_ids);
// Grab responses associated with those ids:
fetch(BASE_URL + 'grabResponses', {
fetch(BASE_URL + 'app/grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({

@ -1,8 +1,8 @@
import { useDisclosure } from '@mantine/hooks';
import { Modal, Button, Group } from '@mantine/core';
import { Modal, Button, Group, RingProgress } from '@mantine/core';
import { IconSettings, IconTrash } from '@tabler/icons-react';
export default function LLMItemButtonGroup( {onClickTrash, onClickSettings} ) {
export default function LLMItemButtonGroup( {onClickTrash, onClickSettings, ringProgress} ) {
const [opened, { open, close }] = useDisclosure(false);
return (
@ -12,7 +12,13 @@ export default function LLMItemButtonGroup( {onClickTrash, onClickSettings} ) {
</Modal>
<Group position="right" style={{float: 'right', height:'20px'}}>
<Button onClick={onClickTrash} size="xs" variant="light" compact color="red" style={{padding: '0px'}} ><IconTrash size={"95%"} /></Button>
{ringProgress !== undefined ?
(ringProgress > 0 ?
(<RingProgress size={20} thickness={3} sections={[{ value: ringProgress, color: ringProgress < 99 ? 'blue' : 'green' }]} width='16px' />) :
(<div className="lds-ring"><div></div><div></div><div></div><div></div></div>))
: (<></>)
}
<Button onClick={onClickTrash} size="xs" variant="light" compact color="red" style={{padding: '0px'}} ><IconTrash size={"95%"} /></Button>
<Button onClick={onClickSettings} size="xs" variant="light" compact>Settings&nbsp;<IconSettings size={"110%"} /></Button>
</Group>
</div>

@ -63,7 +63,7 @@ export default function LLMList({llms, onItemsChange}) {
{items.map((item, index) => (
<Draggable key={item.key} draggableId={item.key} index={index}>
{(provided, snapshot) => (
<LLMListItem provided={provided} snapshot={snapshot} item={item} removeCallback={removeItem} />
<LLMListItem provided={provided} snapshot={snapshot} item={item} removeCallback={removeItem} progress={item.progress} />
)}
</Draggable>
))}

@ -23,7 +23,7 @@ export const DragItem = styled.div`
flex-direction: column;
`;
const LLMListItem = ({ item, provided, snapshot, removeCallback }) => {
const LLMListItem = ({ item, provided, snapshot, removeCallback, progress }) => {
return (
<DragItem
ref={provided.innerRef}
@ -33,7 +33,7 @@ const LLMListItem = ({ item, provided, snapshot, removeCallback }) => {
>
<div>
<CardHeader>{item.emoji}&nbsp;{item.name}</CardHeader>
<LLMItemButtonGroup onClickTrash={() => removeCallback(item.key)} />
<LLMItemButtonGroup onClickTrash={() => removeCallback(item.key)} ringProgress={progress} />
</div>
</DragItem>

@ -4,9 +4,9 @@ import 'react-edit-text/dist/index.css';
import StatusIndicator from './StatusIndicatorComponent';
import AlertModal from './AlertModal';
import { useState, useEffect} from 'react';
import { CloseButton } from '@mantine/core';
import { Tooltip } from '@mantine/core';
export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, handleRunClick }) {
export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, handleRunClick, handleRunHover, runButtonTooltip }) {
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const [statusIndicator, setStatusIndicator] = useState('none');
const [runButton, setRunButton] = useState('none');
@ -33,12 +33,20 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl
useEffect(() => {
if(handleRunClick !== undefined) {
setRunButton(<button className="AmitSahoo45-button-3 nodrag" onClick={handleRunClick}>&#9654;</button>);
const run_btn = (<button className="AmitSahoo45-button-3 nodrag" onClick={handleRunClick} onPointerEnter={handleRunHover}>&#9654;</button>);
if (runButtonTooltip)
setRunButton(
<Tooltip label={runButtonTooltip} withArrow arrowSize={6} arrowRadius={2}>
{run_btn}
</Tooltip>
);
else
setRunButton(run_btn);
}
else {
setRunButton(<></>);
}
}, [handleRunClick]);
}, [handleRunClick, runButtonTooltip]);
const handleCloseButtonClick = () => {
removeNode(nodeId);

@ -1,14 +1,13 @@
import React, { useEffect, useState, useRef, useCallback } from 'react';
import { Handle } from 'react-flow-renderer';
import { Menu, Badge } from '@mantine/core';
import { Menu, Badge, Progress } from '@mantine/core';
import { v4 as uuid } from 'uuid';
import useStore from './store';
import StatusIndicator from './StatusIndicatorComponent'
import NodeLabel from './NodeLabelComponent'
import TemplateHooks from './TemplateHooksComponent'
import LLMList from './LLMListComponent'
import AlertModal from './AlertModal'
import {BASE_URL} from './store';
import io from 'socket.io-client';
// Available LLMs
const allLLMs = [
@ -66,13 +65,29 @@ const PromptNode = ({ data, id }) => {
// Selecting LLM models to prompt
const [llmItems, setLLMItems] = useState(initLLMs.map((i, idx) => ({key: uuid(), ...i})));
const [llmItemsCurrState, setLLMItemsCurrState] = useState([]);
const resetLLMItemsProgress = useCallback(() => {
setLLMItems(llmItemsCurrState.map(item => {
item.progress = undefined;
return item;
}));
}, [llmItemsCurrState]);
// Progress when querying responses
const [progress, setProgress] = useState(100);
const [runTooltip, setRunTooltip] = useState(null);
const triggerAlert = (msg) => {
setProgress(100);
resetLLMItemsProgress();
alertModal.current.trigger(msg);
};
const addModel = useCallback((model) => {
// Get the item for that model
let item = allLLMs.find(llm => llm.model === model);
if (!item) { // This should never trigger, but in case it does:
alertModal.current.trigger(`Could not find model named '${model}' in list of available LLMs.`);
triggerAlert(`Could not find model named '${model}' in list of available LLMs.`);
return;
}
@ -116,6 +131,89 @@ const PromptNode = ({ data, id }) => {
}
};
// Pull all inputs needed to request responses.
// Returns [prompt, vars dict]
const pullInputData = () => {
// Pull data from each source recursively:
const pulled_data = {};
const get_outputs = (varnames, nodeId) => {
varnames.forEach(varname => {
// Find the relevant edge(s):
edges.forEach(e => {
if (e.target == nodeId && e.targetHandle == varname) {
// Get the immediate output:
let out = output(e.source, e.sourceHandle);
// Save the var data from the pulled output
if (varname in pulled_data)
pulled_data[varname] = pulled_data[varname].concat(out);
else
pulled_data[varname] = out;
// Get any vars that the output depends on, and recursively collect those outputs as well:
const n_vars = getNode(e.source).data.vars;
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
get_outputs(n_vars, e.source);
}
});
});
};
get_outputs(templateVars, id);
// Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
const to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
const py_prompt_template = to_py_template_format(promptText);
// Do the same for the vars, since vars can themselves be prompt templates:
Object.keys(pulled_data).forEach(varname => {
pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
});
return [py_prompt_template, pulled_data];
};
// Ask the backend how many responses it needs to collect, given the input data:
const fetchResponseCounts = (prompt, vars, llms, rejected) => {
return fetch(BASE_URL + 'app/countQueriesRequired', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
prompt: prompt,
vars: vars,
llms: llms,
})}, rejected).then(function(response) {
return response.json();
}, rejected).then(function(json) {
if (!json || !json.counts) {
throw new Error('Request was sent and received by backend server, but there was no response.');
}
return json.counts;
}, rejected);
};
// On hover over the 'Run' button, request how many responses are required and update the tooltip. Soft fails.
const handleRunHover = () => {
// Check if there's at least one model in the list; if not, nothing to run on.
if (!llmItemsCurrState || llmItemsCurrState.length == 0) {
setRunTooltip('No LLMs to query.');
return;
}
// Get input data and prompt
const [py_prompt, pulled_vars] = pullInputData();
const llms = llmItemsCurrState.map(item => item.model);
const num_llms = llms.length;
// Fetch response counts from backend
fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => {
console.warn(err.message); // soft fail
}).then((counts) => {
const n = counts[Object.keys(counts)[0]];
const req = n > 1 ? 'requests' : 'request';
setRunTooltip(`Will send ${n} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
});
};
const handleRunClick = (event) => {
// Go through all template hooks (if any) and check they're connected:
const is_fully_connected = templateVars.every(varname => {
@ -123,63 +221,113 @@ const PromptNode = ({ data, id }) => {
return edges.some(e => (e.target == id && e.targetHandle == varname));
});
// console.log(templateHooks);
if (!is_fully_connected) {
console.log('Not connected! :(');
triggerAlert('Missing inputs to one or more template variables.');
return;
}
if (is_fully_connected) {
console.log('Connected!');
console.log('Connected!');
// Check that there is at least one LLM selected:
if (llmItemsCurrState.length === 0) {
alert('Please select at least one LLM to prompt.')
return;
}
// Check that there is at least one LLM selected:
if (llmItemsCurrState.length === 0) {
alert('Please select at least one LLM to prompt.')
return;
}
// Set status indicator
setStatus('loading');
// Set status indicator
setStatus('loading');
setReponsePreviews([]);
// Pull data from each source, recursively:
const pulled_data = {};
const get_outputs = (varnames, nodeId) => {
console.log(varnames);
varnames.forEach(varname => {
// Find the relevant edge(s):
edges.forEach(e => {
if (e.target == nodeId && e.targetHandle == varname) {
// Get the immediate output:
let out = output(e.source, e.sourceHandle);
const [py_prompt_template, pulled_data] = pullInputData();
// Save the var data from the pulled output
if (varname in pulled_data)
pulled_data[varname] = pulled_data[varname].concat(out);
else
pulled_data[varname] = out;
let FINISHED_QUERY = false;
const rejected = (err) => {
setStatus('error');
triggerAlert(err.message);
FINISHED_QUERY = true;
};
// Get any vars that the output depends on, and recursively collect those outputs as well:
const n_vars = getNode(e.source).data.vars;
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
get_outputs(n_vars, e.source);
}
});
});
};
get_outputs(templateVars, id);
// Ask the backend to reset the scratchpad for counting queries:
const create_progress_scratchpad = () => {
return fetch(BASE_URL + 'app/createProgressFile', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
id: id,
})}, rejected);
};
// Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
const to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
const py_prompt_template = to_py_template_format(promptText);
// Fetch info about the number of queries we'll need to make
const fetch_resp_count = () => fetchResponseCounts(
py_prompt_template, pulled_data, llmItemsCurrState.map(item => item.model), rejected);
// Do the same for the vars, since vars can themselves be prompt templates:
Object.keys(pulled_data).forEach(varname => {
pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
// Open a socket to listen for progress
const open_progress_listener_socket = (response_counts) => {
// With the counts information we can create progress bars. Now we load a socket connection to
// the socketio server that will stream to us the current progress:
const socket = io('http://localhost:8001/' + 'queryllm', {
transports: ["websocket"],
cors: {origin: "http://localhost:8000/"},
});
const rejected = (err) => {
setStatus('error');
alertModal.current.trigger(err.message);
};
const max_responses = Object.keys(response_counts).reduce((acc, llm) => acc + response_counts[llm], 0);
// Run all prompt permutations through the LLM to generate + cache responses:
fetch(BASE_URL + 'queryllm', {
// On connect to the server, ask it to give us the current progress
// for task 'queryllm' with id 'id', and stop when it reads progress >= 'max'.
socket.on("connect", (msg) => {
// Initialize progress bars to small amounts
setProgress(5);
setLLMItems(llmItemsCurrState.map(item => {
item.progress = 0;
return item;
}));
// Request progress bar updates
socket.emit("queryllm", {'id': id, 'max': max_responses});
});
// Socket connection could not be established
socket.on("connect_error", (error) => {
console.log("Socket connection failed:", error.message);
socket.disconnect();
});
// Socket disconnected
socket.on("disconnect", (msg) => {
console.log(msg);
});
// The current progress, a number specifying how many responses collected so far:
socket.on("response", (counts) => {
console.log(counts);
if (!counts || FINISHED_QUERY) return;
// Update individual progress bars
const num_llms = llmItemsCurrState.length;
setLLMItems(llmItemsCurrState.map(item => {
if (item.model in counts)
item.progress = counts[item.model] / (max_responses / num_llms) * 100;
return item;
}));
// Update total progress bar
const total_num_resps = Object.keys(counts).reduce((acc, llm_name) => {
return acc + counts[llm_name];
}, 0);
setProgress(total_num_resps / max_responses * 100);
});
// The process has finished; close the connection:
socket.on("finish", (msg) => {
console.log("finished:", msg);
socket.disconnect();
});
};
// Run all prompt permutations through the LLM to generate + cache responses:
const query_llms = () => {
return fetch(BASE_URL + 'app/queryllm', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@ -191,7 +339,7 @@ const PromptNode = ({ data, id }) => {
temperature: 0.5,
n: numGenerations,
},
no_cache: false,
no_cache: true,
}),
}, rejected).then(function(response) {
return response.json();
@ -205,6 +353,11 @@ const PromptNode = ({ data, id }) => {
// Success! Change status to 'ready':
setStatus('ready');
// Remove progress bars
setProgress(100);
resetLLMItemsProgress();
FINISHED_QUERY = true;
// Save prompt text so we remember what prompt we have responses cache'd for:
setPromptTextOnLastRun(promptText);
@ -246,14 +399,15 @@ const PromptNode = ({ data, id }) => {
alertModal.current.trigger(json.error || 'Unknown error when querying LLM');
}
}, rejected);
};
console.log(pulled_data);
} else {
console.log('Not connected! :(');
alertModal.current.trigger('Missing inputs to one or more template variables.')
// TODO: Blink the names of unconnected params
}
// Now put it all together!
create_progress_scratchpad()
.then(fetch_resp_count)
.then(open_progress_listener_socket)
.then(query_llms)
.catch(rejected);
}
const handleNumGenChange = (event) => {
@ -279,6 +433,8 @@ const PromptNode = ({ data, id }) => {
status={status}
alertModal={alertModal}
handleRunClick={handleRunClick}
handleRunHover={handleRunHover}
runButtonTooltip={runTooltip}
/>
<div className="input-field">
<textarea
@ -331,6 +487,7 @@ const PromptNode = ({ data, id }) => {
<label htmlFor="alpaca.7B">Alpaca 7B</label> */}
</div>
</div>
{progress < 100 ? (<Progress value={progress} animate />) : <></>}
<div className="response-preview-container nowheel">
{responsePreviews}
</div>
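For reference, the handshake this node performs (connect to the `/queryllm` namespace on port 8001, emit `queryllm` with `{id, max}`, then listen for `response` and `finish`) can be exercised outside React with the `python-socketio` client already listed in `requirements.txt`. The `id` and `max` values below are illustrative only:

```python
# Illustrative driver for the progress-streaming handshake used by PromptNode.
# The 'id' and 'max' values are made up; the SocketIO server must be running
# on localhost:8001 (see app.py below).
import socketio

sio = socketio.Client()

@sio.on("response", namespace="/queryllm")
def on_response(counts):
    print("progress:", counts)          # e.g. {'gpt3.5': 2, 'gpt4': 1}

@sio.on("finish", namespace="/queryllm")
def on_finish(msg):
    print("finished:", msg)
    sio.disconnect()

sio.connect("http://localhost:8001", namespaces=["/queryllm"], transports=["websocket"])
sio.emit("queryllm", {"id": "some-prompt-node-id", "max": 6}, namespace="/queryllm")
sio.wait()                               # block until the server emits 'finish'
```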

@ -47,7 +47,7 @@ const VisNode = ({ data, id }) => {
// Grab the input node ids
const input_node_ids = [data.input];
fetch(BASE_URL + 'grabResponses', {
fetch(BASE_URL + 'app/grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({

@ -300,6 +300,10 @@
-moz-transition: all 0.2s ease-out;
transition: all 0.2s ease-out;
}
.AmitSahoo45-button-3:active {
background: #40a829;
color: yellow;
}
.AmitSahoo45-button-3:hover::before {
-moz-animation: sh02 0.5s 0s linear;
@ -364,6 +368,10 @@
-moz-transition: all 0.2s ease-out;
transition: all 0.2s ease-out;
}
.close-button:active {
background: #660000;
color: yellow;
}
.close-button:hover::before {
-moz-animation: sh02 0.5s 0s linear;

@ -1,426 +1,91 @@
import json, os, asyncio, sys, argparse
import json, os, asyncio, sys, argparse, threading
from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_socketio import SocketIO
from flask_app import run_server
from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
app = Flask(__name__)
CORS(app)
# Setup the socketio app
# BUILD_DIR = "../chain-forge/build"
# STATIC_DIR = BUILD_DIR + '/static'
app = Flask(__name__) #, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
LLM_NAME_MAP = {
'gpt3.5': LLM.ChatGPT,
'alpaca.7B': LLM.Alpaca7B,
'gpt4': LLM.GPT4,
}
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
# Initialize Socket.IO
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")
@dataclass
class ResponseInfo:
"""Stores info about a single response. Passed to evaluator functions."""
text: str
prompt: str
var: str
llm: str
# Set up CORS for specific routes
# cors = CORS(app, resources={r"/api/*": {"origins": "*"}})
def __str__(self):
return self.text
# Wait a max of a full 3 minutes (180 seconds) for the response count to update, before exiting.
MAX_WAIT_TIME = 180
def to_standard_format(r: dict) -> list:
llm = LLM_NAME_MAP_INVERSE[r['llm']]
resp_obj = {
'vars': r['info'],
'llm': llm,
'prompt': r['prompt'],
'responses': extract_responses(r, r['llm']),
'tokens': r['response']['usage'] if 'usage' in r['response'] else {},
}
if 'eval_res' in r:
resp_obj['eval_res'] = r['eval_res']
return resp_obj
def countdown():
n = 10
while n > 0:
socketio.sleep(0.5)
socketio.emit('response', n, namespace='/queryllm')
n -= 1
def get_llm_of_response(response: dict) -> LLM:
return LLM_NAME_MAP[response['llm']]
@socketio.on('queryllm', namespace='/queryllm')
def readCounts(data):
id = data['id']
max_count = data['max']
tempfilepath = f'cache/_temp_{id}.txt'
def get_filenames_with_id(filenames: list, id: str) -> list:
return [
c for c in filenames
if c.split('.')[0] == id or ('-' in c and c[:c.rfind('-')] == id)
]
# Check that temp file exists. If it doesn't, something went wrong with setup on Flask's end:
if not os.path.exists(tempfilepath):
print(f"Error: Temp file not found at path {tempfilepath}. Cannot stream querying progress.")
socketio.emit('finish', 'temp file not found', namespace='/queryllm')
def remove_cached_responses(cache_id: str):
all_cache_files = get_files_at_dir('cache/')
cache_files = get_filenames_with_id(all_cache_files, cache_id)
for filename in cache_files:
os.remove(os.path.join('cache/', filename))
i = 0
last_n = 0
init_run = True
while i < MAX_WAIT_TIME and last_n < max_count:
def load_cache_json(filepath: str) -> dict:
with open(filepath, encoding="utf-8") as f:
responses = json.load(f)
return responses
def run_over_responses(eval_func, responses: dict, scope: str) -> list:
for prompt, resp_obj in responses.items():
res = extract_responses(resp_obj, resp_obj['llm'])
if scope == 'response':
evals = [ # Run evaluator func over every individual response text
eval_func(
ResponseInfo(
text=r,
prompt=prompt,
var=resp_obj['info'],
llm=LLM_NAME_MAP_INVERSE[resp_obj['llm']])
) for r in res
]
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
'mean': mean(evals),
'median': median(evals),
'stdev': stdev(evals) if len(evals) > 1 else 0,
'range': (min(evals), max(evals)),
'items': evals,
}
else: # operate over the entire response batch
ev = eval_func(res)
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
'mean': ev,
'median': ev,
'stdev': 0,
'range': (ev, ev),
'items': [ev],
}
return responses
def reduce_responses(responses: list, vars: list) -> list:
if len(responses) == 0: return responses
# Figure out what vars we still care about (the ones we aren't reducing over):
# NOTE: We are assuming all responses have the same 'vars' keys.
all_vars = set(responses[0]['vars'])
if not all_vars.issuperset(set(vars)):
# There's a var in vars which isn't part of the response.
raise Exception(f"Some vars in {set(vars)} are not in the responses.")
# Get just the vars we want to keep around:
include_vars = list(set(responses[0]['vars']) - set(vars))
# Bucket responses by the remaining var values, where tuples of vars are keys to a dict:
# E.g. {(var1_val, var2_val): [responses] }
bucketed_resp = {}
for r in responses:
print(r)
tup_key = tuple([r['vars'][v] for v in include_vars])
if tup_key in bucketed_resp:
bucketed_resp[tup_key].append(r)
else:
bucketed_resp[tup_key] = [r]
# Perform reduce op across all bucketed responses, collecting them into a single 'meta'-response:
ret = []
for tup_key, resps in bucketed_resp.items():
flat_eval_res = [item for r in resps for item in r['eval_res']['items']]
ret.append({
'vars': {v: r['vars'][v] for r in resps for v in include_vars},
'llm': resps[0]['llm'],
'prompt': [r['prompt'] for r in resps],
'responses': [r['responses'] for r in resps],
'tokens': resps[0]['tokens'],
'eval_res': {
'mean': mean(flat_eval_res),
'median': median(flat_eval_res),
'stdev': stdev(flat_eval_res) if len(flat_eval_res) > 1 else 0,
'range': (min(flat_eval_res), max(flat_eval_res)),
'items': flat_eval_res
}
})
return ret
@app.route('/test', methods=['GET'])
def test():
return "Hello, world!"
@app.route('/queryllm', methods=['POST'])
async def queryLLM():
"""
Queries LLM(s) given a JSON spec.
POST'd data should be in the form:
{
'id': str # a unique ID to refer to this information. Used when cache'ing responses.
'llm': str | list # a string or list of strings specifying the LLM(s) to query
'params': dict # an optional dict of any other params to set when querying the LLMs, like 'temperature', 'n' (num of responses per prompt), etc.
'prompt': str # the prompt template, with any {{}} vars
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
'no_cache': bool (optional) # delete any cache'd responses for 'id' (always call the LLM fresh)
}
"""
data = request.get_json()
# Check that all required info is here:
if not set(data.keys()).issuperset({'llm', 'prompt', 'vars', 'id'}):
return jsonify({'error': 'POST data is improper format.'})
elif not isinstance(data['id'], str) or len(data['id']) == 0:
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
# Verify LLM name(s) (string or list) and convert to enum(s):
if not (isinstance(data['llm'], list) or isinstance(data['llm'], str)) or (isinstance(data['llm'], list) and len(data['llm']) == 0):
return jsonify({'error': 'POST data llm is improper format (not string or list, or of length 0).'})
if isinstance(data['llm'], str):
data['llm'] = [ data['llm'] ]
llms = []
for llm_str in data['llm']:
if llm_str not in LLM_NAME_MAP:
return jsonify({'error': f"LLM named '{llm_str}' is not supported."})
llms.append(LLM_NAME_MAP[llm_str])
if 'no_cache' in data and data['no_cache'] is True:
remove_cached_responses(data['id'])
# Create a cache dir if it doesn't exist:
create_dir_if_not_exists('cache')
# For each LLM, generate and cache responses:
responses = {}
params = data['params'] if 'params' in data else {}
async def query(llm: str) -> list:
# Check that storage path is valid:
cache_filepath = os.path.join('cache', f"{data['id']}-{str(llm.name)}.json")
if not is_valid_filepath(cache_filepath):
return jsonify({'error': f'Invalid filepath: {cache_filepath}'})
# Create an object to query the LLM, passing a file for cache'ing responses
prompter = PromptLLM(data['prompt'], storageFile=cache_filepath)
# Prompt the LLM with all permutations of the input prompt template:
# NOTE: If the responses are already cache'd, this just loads them (no LLM is queried, saving $$$)
resps = []
try:
print(f'Querying {llm}...')
async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
resps.append(response)
except Exception as e:
print('error generating responses:', e)
raise e
# Open the temp file to read the progress so far:
try:
with open(tempfilepath, 'r') as f:
queries = json.load(f)
except FileNotFoundError as e:
# If the temp file was deleted during executing, the Flask 'queryllm' func must've terminated successfully:
socketio.emit('finish', 'success', namespace='/queryllm')
return
return {'llm': llm, 'responses': resps}
# Calculate the total sum of responses
# TODO: This is a naive approach; we need to make this more complex and factor in cache'ing in future
n = sum([int(n) for llm, n in queries.items()])
# If something's changed...
if init_run or last_n != n:
i = 0
last_n = n
init_run = False
# Update the React front-end with the current progress
socketio.emit('response', queries, namespace='/queryllm')
try:
# Request responses simultaneously across LLMs
tasks = [query(llm) for llm in llms]
else:
i += 0.1
# Wait a bit before reading the file again
socketio.sleep(0.1)
# Await the responses from all queried LLMs
llm_results = await asyncio.gather(*tasks)
for item in llm_results:
responses[item['llm']] = item['responses']
if i >= MAX_WAIT_TIME:
print(f"Error: Waited maximum {MAX_WAIT_TIME} seconds for response count to update. Exited prematurely.")
socketio.emit('finish', 'max_wait_reached', namespace='/queryllm')
else:
print("All responses loaded!")
socketio.emit('finish', 'success', namespace='/queryllm')
except Exception as e:
return jsonify({'error': str(e)})
def run_socketio_server(socketio, port):
socketio.run(app, host="localhost", port=8001)
# Convert the responses into a more standardized format with less information
res = [
to_standard_format(r)
for rs in responses.values()
for r in rs
]
# Return all responses for all LLMs
print('returning responses:', res)
ret = jsonify({'responses': res})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
@app.route('/execute', methods=['POST'])
def execute():
"""
Executes a Python lambda function sent from JavaScript,
over all cache'd responses with given id's.
POST'd data should be in the form:
{
'id': # a unique ID to refer to this information. Used when cache'ing responses.
'code': str, # the body of the lambda function to evaluate, in form: lambda responses: <body>
'responses': str | List[str] # the responses to run on; a unique ID or list of unique IDs of cache'd data,
'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch.
# If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt.
'reduce_vars': unspecified | List[str] # the 'vars' to average over (mean, median, stdev, range)
'script_paths': unspecified | List[str] # the paths to scripts to be added to the path before the lambda function is evaluated
}
NOTE: This should only be run on your server on code you trust.
There is no sandboxing; no safety. We assume you are the creator of the code.
"""
data = request.get_json()
# Check that all required info is here:
if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope'}):
return jsonify({'error': 'POST data is improper format.'})
if not isinstance(data['id'], str) or len(data['id']) == 0:
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
if data['scope'] not in ('response', 'batch'):
return jsonify({'error': "POST data scope is unknown. Must be either 'response' or 'batch'."})
if __name__ == "__main__":
# Check that the filepath used to cache eval'd responses is valid:
cache_filepath = os.path.join('cache', f"{data['id']}.json")
if not is_valid_filepath(cache_filepath):
return jsonify({'error': f'Invalid filepath: {cache_filepath}'})
# Check format of responses:
if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
return jsonify({'error': 'POST data responses is improper format.'})
elif isinstance(data['responses'], str):
data['responses'] = [ data['responses'] ]
# add the path to any scripts to the path:
try:
if 'script_paths' in data:
for script_path in data['script_paths']:
# get the folder of the script_path:
script_folder = os.path.dirname(script_path)
# check that the script_folder is valid, and it contains __init__.py
if not os.path.exists(script_folder):
print(script_folder, 'is not a valid script path.')
print(os.path.exists(script_folder))
continue
# add it to the path:
sys.path.append(script_folder)
print(f'added {script_folder} to sys.path')
except Exception as e:
return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'})
# Create the evaluator function
# DANGER DANGER!
try:
exec(data['code'], globals())
# Double-check that there is an 'evaluate' method in our namespace.
# This will throw a NameError if not:
evaluate
except Exception as e:
return jsonify({'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})
# Load all responses with the given ID:
all_cache_files = get_files_at_dir('cache/')
all_evald_responses = []
for cache_id in data['responses']:
cache_files = get_filenames_with_id(all_cache_files, cache_id)
if len(cache_files) == 0:
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
# To avoid loading all response files into memory at once, we'll run the evaluator on each file:
for filename in cache_files:
# Load the raw responses from the cache
responses = load_cache_json(os.path.join('cache', filename))
if len(responses) == 0: continue
# Run the evaluator over them:
# NOTE: 'evaluate' here was defined dynamically from 'exec' above.
try:
evald_responses = run_over_responses(evaluate, responses, scope=data['scope'])
except Exception as e:
return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}'})
# Convert to standard format:
std_evald_responses = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in evald_responses.items()
]
# Perform any reduction operations:
if 'reduce_vars' in data and len(data['reduce_vars']) > 0:
std_evald_responses = reduce_responses(
std_evald_responses,
vars=data['reduce_vars']
)
all_evald_responses.extend(std_evald_responses)
# Store the evaluated responses in a new cache json:
with open(cache_filepath, "w") as f:
json.dump(all_evald_responses, f)
ret = jsonify({'responses': all_evald_responses})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
@app.route('/checkEvalFunc', methods=['POST'])
def checkEvalFunc():
"""
Tries to compile a Python lambda function sent from JavaScript.
Returns a dict with 'result':true if it compiles without raising an exception;
'result':false (and an 'error' property with a message) if not.
POST'd data should be in form:
{
'code': str, # the body of the lambda function to evaluate, in form: lambda responses: <body>
}
NOTE: This should only be run on your server on code you trust.
There is no sandboxing; no safety. We assume you are the creator of the code.
"""
data = request.get_json()
if 'code' not in data:
return jsonify({'result': False, 'error': f'Could not evaluate code. Error message:\n{str(e)}'})
# DANGER DANGER! Running exec on code passed through front-end. Make sure it's trusted!
try:
exec(data['code'], globals())
# Double-check that there is an 'evaluate' method in our namespace.
# This will throw a NameError if not:
evaluate
return jsonify({'result': True})
except Exception as e:
return jsonify({'result': False, 'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})
@app.route('/grabResponses', methods=['POST'])
def grabResponses():
"""
Returns all responses with the specified id(s)
POST'd data should be in the form:
{
'responses': <the ids to grab>
}
"""
data = request.get_json()
# Check format of responses:
if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
return jsonify({'error': 'POST data responses is improper format.'})
elif isinstance(data['responses'], str):
data['responses'] = [ data['responses'] ]
# Load all responses with the given ID:
all_cache_files = get_files_at_dir('cache/')
print(all_cache_files)
responses = []
for cache_id in data['responses']:
cache_files = get_filenames_with_id(all_cache_files, cache_id)
if len(cache_files) == 0:
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
for filename in cache_files:
res = load_cache_json(os.path.join('cache', filename))
if isinstance(res, dict):
# Convert to standard response format
res = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in res.items()
]
responses.extend(res)
print(responses)
ret = jsonify({'responses': responses})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='This script spins up a Flask server that serves as the backend for ChainForge')
# Turn on to disable all outbound LLM API calls and replace them with dummy calls
@ -430,10 +95,15 @@ if __name__ == '__main__':
Produces each dummy response at random intervals between 0.1 and 3 seconds.""",
dest='dummy_responses',
action='store_true')
parser.add_argument('--port', help='The port to run the server on. Defaults to 8000.', type=int, default=8000, nargs='?')
args = parser.parse_args()
if args.dummy_responses:
PromptLLM = PromptLLMDummy
extract_responses = lambda r, llm: r['response']
port = args.port if args.port else 8000
app.run(host="localhost", port=8000, debug=True)
# Spin up separate thread for socketio app, on port+1 (8001 default)
print(f"Serving SocketIO server on port {port+1}...")
t1 = threading.Thread(target=run_socketio_server, args=[socketio, port+1])
t1.start()
print(f"Serving Flask server on port {port}...")
run_server(host="localhost", port=port, cmd_args=args)
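To summarize the streaming flow above: the `/app/queryllm` route (now in `flask_app.py`) writes per-LLM response counts to `cache/_temp_<id>.txt` as responses arrive, and the SocketIO handler in `app.py` polls that file and relays the counts to the front-end. A stripped-down sketch of that handler, with error handling and the `MAX_WAIT_TIME` cutoff omitted:

```python
# Stripped-down version of the 'queryllm' SocketIO handler in app.py.
# Assumes flask_socketio with gevent; the real code also handles a missing
# temp file at startup and gives up after MAX_WAIT_TIME seconds.
import json, os
from flask import Flask
from flask_socketio import SocketIO

app = Flask(__name__)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")

@socketio.on("queryllm", namespace="/queryllm")
def read_counts(data):
    # data = {'id': <prompt node id>, 'max': <total expected responses>}
    tempfilepath = f"cache/_temp_{data['id']}.txt"
    total = 0
    while total < data["max"]:
        if not os.path.exists(tempfilepath):
            # Temp file gone: it was never created, or the /app/queryllm route
            # finished and cleaned it up.
            socketio.emit("finish", "success", namespace="/queryllm")
            return
        with open(tempfilepath) as f:
            counts = json.load(f)        # e.g. {'gpt3.5': 3, 'gpt4': 1}
        total = sum(int(n) for n in counts.values())
        socketio.emit("response", counts, namespace="/queryllm")
        socketio.sleep(0.1)              # yield so other events can be handled
    socketio.emit("finish", "success", namespace="/queryllm")
```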

512 python-backend/flask_app.py Normal file

@ -0,0 +1,512 @@
import json, os, asyncio, sys, argparse, threading
from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify, render_template, send_from_directory
from flask_cors import CORS
from flask_socketio import SocketIO
from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
# Setup Flask app to serve static version of React front-end
BUILD_DIR = "../chain-forge/build"
STATIC_DIR = BUILD_DIR + '/static'
app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
# Set up CORS for specific routes
cors = CORS(app, resources={r"/*": {"origins": "*"}})
# Serve React app (static; no hot reloading)
@app.route("/")
def index():
return render_template("index.html")
LLM_NAME_MAP = {
'gpt3.5': LLM.ChatGPT,
'alpaca.7B': LLM.Alpaca7B,
'gpt4': LLM.GPT4,
}
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
@dataclass
class ResponseInfo:
"""Stores info about a single response. Passed to evaluator functions."""
text: str
prompt: str
var: str
llm: str
def __str__(self):
return self.text
def to_standard_format(r: dict) -> list:
llm = LLM_NAME_MAP_INVERSE[r['llm']]
resp_obj = {
'vars': r['info'],
'llm': llm,
'prompt': r['prompt'],
'responses': extract_responses(r, r['llm']),
'tokens': r['response']['usage'] if 'usage' in r['response'] else {},
}
if 'eval_res' in r:
resp_obj['eval_res'] = r['eval_res']
return resp_obj
def get_llm_of_response(response: dict) -> LLM:
return LLM_NAME_MAP[response['llm']]
def get_filenames_with_id(filenames: list, id: str) -> list:
return [
c for c in filenames
if c.split('.')[0] == id or ('-' in c and c[:c.rfind('-')] == id)
]
def remove_cached_responses(cache_id: str):
all_cache_files = get_files_at_dir('cache/')
cache_files = get_filenames_with_id(all_cache_files, cache_id)
for filename in cache_files:
os.remove(os.path.join('cache/', filename))
def load_cache_json(filepath: str) -> dict:
with open(filepath, encoding="utf-8") as f:
responses = json.load(f)
return responses
def run_over_responses(eval_func, responses: dict, scope: str) -> list:
for prompt, resp_obj in responses.items():
res = extract_responses(resp_obj, resp_obj['llm'])
if scope == 'response':
evals = [ # Run evaluator func over every individual response text
eval_func(
ResponseInfo(
text=r,
prompt=prompt,
var=resp_obj['info'],
llm=LLM_NAME_MAP_INVERSE[resp_obj['llm']])
) for r in res
]
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
'mean': mean(evals),
'median': median(evals),
'stdev': stdev(evals) if len(evals) > 1 else 0,
'range': (min(evals), max(evals)),
'items': evals,
}
else: # operate over the entire response batch
ev = eval_func(res)
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
'mean': ev,
'median': ev,
'stdev': 0,
'range': (ev, ev),
'items': [ev],
}
return responses
def reduce_responses(responses: list, vars: list) -> list:
if len(responses) == 0: return responses
# Figure out what vars we still care about (the ones we aren't reducing over):
# NOTE: We are assuming all responses have the same 'vars' keys.
all_vars = set(responses[0]['vars'])
if not all_vars.issuperset(set(vars)):
# There's a var in vars which isn't part of the response.
raise Exception(f"Some vars in {set(vars)} are not in the responses.")
# Get just the vars we want to keep around:
include_vars = list(set(responses[0]['vars']) - set(vars))
# Bucket responses by the remaining var values, where tuples of vars are keys to a dict:
# E.g. {(var1_val, var2_val): [responses] }
bucketed_resp = {}
for r in responses:
tup_key = tuple([r['vars'][v] for v in include_vars])
if tup_key in bucketed_resp:
bucketed_resp[tup_key].append(r)
else:
bucketed_resp[tup_key] = [r]
# Perform reduce op across all bucketed responses, collecting them into a single 'meta'-response:
ret = []
for tup_key, resps in bucketed_resp.items():
flat_eval_res = [item for r in resps for item in r['eval_res']['items']]
ret.append({
'vars': {v: r['vars'][v] for r in resps for v in include_vars},
'llm': resps[0]['llm'],
'prompt': [r['prompt'] for r in resps],
'responses': [r['responses'] for r in resps],
'tokens': resps[0]['tokens'],
'eval_res': {
'mean': mean(flat_eval_res),
'median': median(flat_eval_res),
'stdev': stdev(flat_eval_res) if len(flat_eval_res) > 1 else 0,
'range': (min(flat_eval_res), max(flat_eval_res)),
'items': flat_eval_res
}
})
return ret
@app.route('/app/countQueriesRequired', methods=['POST'])
def countQueries():
"""
Returns how many queries we need to make, given the passed prompt and vars.
POST'd data should be in the form:
{
'prompt': str # the prompt template, with any {{}} vars
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
'llms': list # the list of LLMs you will query
}
"""
data = request.get_json()
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms'}):
return jsonify({'error': 'POST data is improper format.'})
try:
gen_prompts = PromptPermutationGenerator(PromptTemplate(data['prompt']))
all_prompt_permutations = list(gen_prompts(data['vars']))
except Exception as e:
return jsonify({'error': str(e)})
# TODO: Send more informative data back including how many queries per LLM based on cache'd data
num_queries = {} # len(all_prompt_permutations) * len(data['llms'])
for llm in data['llms']:
num_queries[llm] = len(all_prompt_permutations)
ret = jsonify({'counts': num_queries})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
@app.route('/app/createProgressFile', methods=['POST'])
def createProgressFile():
"""
Creates a temp txt file for storing progress of async LLM queries.
POST'd data should be in the form:
{
'id': str # a unique ID that will be used when calling 'queryllm'
}
"""
data = request.get_json()
if 'id' not in data or not isinstance(data['id'], str) or len(data['id']) == 0:
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
# Create a scratch file for keeping track of how many responses loaded
try:
with open(f"cache/_temp_{data['id']}.txt", 'w') as f:
json.dump({}, f)
ret = jsonify({'success': True})
except Exception as e:
ret = jsonify({'success': False, 'error': str(e)})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
# @socketio.on('connect', namespace='/queryllm')
@app.route('/app/queryllm', methods=['POST'])
async def queryLLM():
"""
Queries LLM(s) given a JSON spec.
POST'd data should be in the form:
{
'id': str # a unique ID to refer to this information. Used when cache'ing responses.
'llm': str | list # a string or list of strings specifying the LLM(s) to query
'params': dict # an optional dict of any other params to set when querying the LLMs, like 'temperature', 'n' (num of responses per prompt), etc.
'prompt': str # the prompt template, with any {{}} vars
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
'no_cache': bool (optional) # delete any cache'd responses for 'id' (always call the LLM fresh)
}
"""
data = request.get_json()
# Check that all required info is here:
if not set(data.keys()).issuperset({'llm', 'prompt', 'vars', 'id'}):
return jsonify({'error': 'POST data is improper format.'})
elif not isinstance(data['id'], str) or len(data['id']) == 0:
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
# Verify LLM name(s) (string or list) and convert to enum(s):
if not (isinstance(data['llm'], list) or isinstance(data['llm'], str)) or (isinstance(data['llm'], list) and len(data['llm']) == 0):
return jsonify({'error': 'POST data llm is improper format (not string or list, or of length 0).'})
if isinstance(data['llm'], str):
data['llm'] = [ data['llm'] ]
for llm_str in data['llm']:
if llm_str not in LLM_NAME_MAP:
return jsonify({'error': f"LLM named '{llm_str}' is not supported."})
if 'no_cache' in data and data['no_cache'] is True:
remove_cached_responses(data['id'])
# Create a cache dir if it doesn't exist:
create_dir_if_not_exists('cache')
# For each LLM, generate and cache responses:
responses = {}
llms = data['llm']
params = data['params'] if 'params' in data else {}
tempfilepath = f"cache/_temp_{data['id']}.txt"
async def query(llm_str: str) -> list:
llm = LLM_NAME_MAP[llm_str]
# Check that storage path is valid:
cache_filepath = os.path.join('cache', f"{data['id']}-{str(llm.name)}.json")
if not is_valid_filepath(cache_filepath):
return jsonify({'error': f'Invalid filepath: {cache_filepath}'})
# Create an object to query the LLM, passing a file for cache'ing responses
prompter = PromptLLM(data['prompt'], storageFile=cache_filepath)
# Prompt the LLM with all permutations of the input prompt template:
# NOTE: If the responses are already cache'd, this just loads them (no LLM is queried, saving $$$)
resps = []
try:
print(f'Querying {llm}...')
async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
resps.append(response)
print(f"collected response from {llm.name}:", str(response))
# Save the number of responses collected to a temp file on disk
with open(tempfilepath, 'r') as f:
txt = f.read().strip()
cur_data = json.loads(txt) if len(txt) > 0 else {}
cur_data[llm_str] = len(resps)
with open(tempfilepath, 'w') as f:
json.dump(cur_data, f)
except Exception as e:
print('error generating responses:', e)
raise e
return {'llm': llm, 'responses': resps}
try:
# Request responses simultaneously across LLMs
tasks = [query(llm) for llm in llms]
# Await the responses from all queried LLMs
llm_results = await asyncio.gather(*tasks)
for item in llm_results:
responses[item['llm']] = item['responses']
except Exception as e:
return jsonify({'error': str(e)})
# Convert the responses into a more standardized format with less information
res = [
to_standard_format(r)
for rs in responses.values()
for r in rs
]
# Remove the temp file used to stream progress updates:
if os.path.exists(tempfilepath):
os.remove(tempfilepath)
# Return all responses for all LLMs
print('returning responses:', res)
ret = jsonify({'responses': res})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
@app.route('/app/execute', methods=['POST'])
def execute():
"""
Executes a Python lambda function sent from JavaScript,
over all cache'd responses with given id's.
POST'd data should be in the form:
{
'id': # a unique ID to refer to this information. Used when cache'ing responses.
'code': str, # the body of the lambda function to evaluate, in form: lambda responses: <body>
'responses': str | List[str] # the responses to run on; a unique ID or list of unique IDs of cache'd data,
'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch.
# If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt.
'reduce_vars': unspecified | List[str] # the 'vars' to average over (mean, median, stdev, range)
'script_paths': unspecified | List[str] # the paths to scripts to be added to the path before the lambda function is evaluated
}
NOTE: This should only be run on your server on code you trust.
There is no sandboxing; no safety. We assume you are the creator of the code.
"""
data = request.get_json()
# Check that all required info is here:
if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope'}):
return jsonify({'error': 'POST data is improper format.'})
if not isinstance(data['id'], str) or len(data['id']) == 0:
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
if data['scope'] not in ('response', 'batch'):
return jsonify({'error': "POST data scope is unknown. Must be either 'response' or 'batch'."})
# Check that the filepath used to cache eval'd responses is valid:
cache_filepath = os.path.join('cache', f"{data['id']}.json")
if not is_valid_filepath(cache_filepath):
return jsonify({'error': f'Invalid filepath: {cache_filepath}'})
# Check format of responses:
if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
return jsonify({'error': 'POST data responses is improper format.'})
elif isinstance(data['responses'], str):
data['responses'] = [ data['responses'] ]
# add the path to any scripts to the path:
try:
if 'script_paths' in data:
for script_path in data['script_paths']:
# get the folder of the script_path:
script_folder = os.path.dirname(script_path)
# check that the script_folder is valid, and it contains __init__.py
if not os.path.exists(script_folder):
print(script_folder, 'is not a valid script path.')
continue
# add it to the path:
sys.path.append(script_folder)
print(f'added {script_folder} to sys.path')
except Exception as e:
return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'})
# Create the evaluator function
# DANGER DANGER!
try:
exec(data['code'], globals())
# Double-check that there is an 'evaluate' method in our namespace.
# This will throw a NameError if not:
evaluate
except Exception as e:
return jsonify({'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})
# Load all responses with the given ID:
all_cache_files = get_files_at_dir('cache/')
all_evald_responses = []
for cache_id in data['responses']:
cache_files = get_filenames_with_id(all_cache_files, cache_id)
if len(cache_files) == 0:
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
# To avoid loading all response files into memory at once, we'll run the evaluator on each file:
for filename in cache_files:
# Load the raw responses from the cache
responses = load_cache_json(os.path.join('cache', filename))
if len(responses) == 0: continue
# Run the evaluator over them:
# NOTE: 'evaluate' here was defined dynamically from 'exec' above.
try:
evald_responses = run_over_responses(evaluate, responses, scope=data['scope'])
except Exception as e:
return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}'})
# Convert to standard format:
std_evald_responses = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in evald_responses.items()
]
# Perform any reduction operations:
if 'reduce_vars' in data and len(data['reduce_vars']) > 0:
std_evald_responses = reduce_responses(
std_evald_responses,
vars=data['reduce_vars']
)
all_evald_responses.extend(std_evald_responses)
# Store the evaluated responses in a new cache json:
with open(cache_filepath, "w") as f:
json.dump(all_evald_responses, f)
ret = jsonify({'responses': all_evald_responses})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
@app.route('/app/checkEvalFunc', methods=['POST'])
def checkEvalFunc():
"""
Tries to compile a Python lambda function sent from JavaScript.
Returns a dict with 'result':true if it compiles without raising an exception;
'result':false (and an 'error' property with a message) if not.
POST'd data should be in form:
{
'code': str, # the body of the lambda function to evaluate, in form: lambda responses: <body>
}
NOTE: This should only be run on your server on code you trust.
There is no sandboxing; no safety. We assume you are the creator of the code.
"""
data = request.get_json()
if 'code' not in data:
return jsonify({'result': False, 'error': 'POST data is missing a "code" field.'})
# DANGER DANGER! Running exec on code passed through front-end. Make sure it's trusted!
try:
exec(data['code'], globals())
# Double-check that there is an 'evaluate' method in our namespace.
# This will throw a NameError if not:
evaluate
return jsonify({'result': True})
except Exception as e:
return jsonify({'result': False, 'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})
@app.route('/app/grabResponses', methods=['POST'])
def grabResponses():
"""
Returns all responses with the specified id(s)
POST'd data should be in the form:
{
'responses': <the ids to grab>
}
"""
data = request.get_json()
# Check format of responses:
if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
return jsonify({'error': 'POST data responses is improper format.'})
elif isinstance(data['responses'], str):
data['responses'] = [ data['responses'] ]
# Load all responses with the given ID:
all_cache_files = get_files_at_dir('cache/')
responses = []
for cache_id in data['responses']:
cache_files = get_filenames_with_id(all_cache_files, cache_id)
if len(cache_files) == 0:
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
for filename in cache_files:
res = load_cache_json(os.path.join('cache', filename))
if isinstance(res, dict):
# Convert to standard response format
res = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in res.items()
]
responses.extend(res)
ret = jsonify({'responses': responses})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
def run_server(host="", port=8000, cmd_args=None):
if cmd_args is not None and cmd_args.dummy_responses:
global PromptLLM
global extract_responses
PromptLLM = PromptLLMDummy
extract_responses = lambda r, llm: r['response']
app.run(host=host, port=port)
if __name__ == '__main__':
print("Run app.py instead.")

@ -1,7 +1,7 @@
from typing import Dict, Tuple, List, Union
from enum import Enum
import openai
import json, os, time
import json, os, time, asyncio
DALAI_MODEL = None
DALAI_RESPONSE = None
@ -23,6 +23,7 @@ async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float =
if model not in model_map:
raise Exception(f"Could not find OpenAI chat model {model}")
model = model_map[model]
print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...")
system_msg = "You are a helpful assistant." if system_msg is None else system_msg
query = {
"model": model,
@ -102,7 +103,7 @@ async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperat
# Blocking --wait for request to complete:
while DALAI_RESPONSE is None:
time.sleep(0.01)
await asyncio.sleep(0.01)
response = DALAI_RESPONSE['response']
if response[-5:] == '<end>': # strip ending <end> tag, if present

@ -1,5 +1,8 @@
dalaipy==2.0.2
flask[async]
flask_cors
flask_socketio
openai
python-socketio
python-socketio
dalaipy==2.0.2
gevent-websocket
werkzeug

@ -1,51 +0,0 @@
<html>
<head>
<title>Test Flask backend</title>
</head>
<body>
<button onclick="test_query()">Test query LLM!</button>
<button onclick="test_exec()">Test evaluating responses!</button>
</body>
<script>
function test_exec() {
const response = fetch(BASE_URL + 'execute', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
id: 'eval',
code: 'return len(response)',
responses: 'test',
}),
}).then(function(response) {
return response.json();
}).then(function(json) {
console.log(json);
});
}
function test_query() {
const response = fetch(BASE_URL + 'queryllm', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
id: 'test',
llm: 'gpt3.5',
params: {
temperature: 1.0,
n: 1,
},
prompt: 'What is the capital of ${country}?',
vars: {
country: ['Sweden', 'Uganda', 'Japan']
},
}),
}).then(function(response) {
return response.json();
}).then(function(json) {
console.log(json);
});
}
</script>
</html>

@ -1,8 +0,0 @@
from promptengine.utils import LLM, call_dalai
if __name__ == '__main__':
print("Testing a single response...")
call_dalai(llm_name='alpaca.7B', port=4000, prompt='Write a poem about how an AI will escape the prison of its containment.', n=1, temperature=0.5)
print("Testing multiple responses...")
call_dalai(llm_name='alpaca.7B', port=4000, prompt='Was George Washington a good person?', n=3, temperature=0.5)