Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-14 16:26:45 +00:00)

Stream responses and serve React, Flask, and SocketIO from a single Python script (#23)

* Live progress wheels
* Dynamic run tooltip for the prompt node
* Run React, Flask, and SocketIO with a single script

This commit is contained in:
parent 34e0e465c1
commit f6d7996f97
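The structural change behind the last bullet is that the Python backend now launches two servers from one entry point: a Flask-SocketIO server for streaming query progress on port+1, and the plain Flask REST/static server on the base port. Below is a minimal, self-contained sketch of that pattern, not ChainForge's exact code; it only mirrors the shape of the __main__ block and run_socketio_server added in the diff (names here are illustrative, and Flask plus Flask-SocketIO are assumed to be installed).

```python
# Toy sketch of the single-script launch pattern introduced in this commit.
import threading
from flask import Flask
from flask_socketio import SocketIO

rest_app = Flask(__name__)      # would serve REST routes + the static React build
socket_app = Flask(__name__)    # wrapped by SocketIO for progress streaming
socketio = SocketIO(socket_app, cors_allowed_origins="*")

@rest_app.route("/")
def index():
    return "React build would be served here."

def run_socketio_server(port):
    socketio.run(socket_app, host="localhost", port=port)

if __name__ == "__main__":
    port = 8000
    # SocketIO server on port+1 (8001 by default), in a background thread:
    threading.Thread(target=run_socketio_server, args=[port + 1], daemon=True).start()
    # Flask server on the base port:
    rest_app.run(host="localhost", port=port)
```

In the diff itself, the REST routes move into the new python-backend/flask_app.py (started via run_server), while the SocketIO app and the __main__ block stay in the original backend script.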
.gitignore (vendored): 1 line changed
@@ -1,5 +1,6 @@
*.DS_Store
chain-forge/node_modules
chain-forge/build
__pycache__
python-backend/cache
chain-forge/package-lock.json (generated): 80 lines changed
@ -38,6 +38,7 @@
|
||||
"react-flow-renderer": "^10.3.17",
|
||||
"react-plotly.js": "^2.6.0",
|
||||
"react-scripts": "5.0.1",
|
||||
"socket.io-client": "^4.6.1",
|
||||
"styled-components": "^5.3.10",
|
||||
"uuidv4": "^6.2.13",
|
||||
"web-vitals": "^2.1.4",
|
||||
@ -3913,6 +3914,11 @@
|
||||
"@sinonjs/commons": "^1.7.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@socket.io/component-emitter": {
|
||||
"version": "3.1.0",
|
||||
"resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz",
|
||||
"integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg=="
|
||||
},
|
||||
"node_modules/@surma/rollup-plugin-off-main-thread": {
|
||||
"version": "2.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@surma/rollup-plugin-off-main-thread/-/rollup-plugin-off-main-thread-2.2.3.tgz",
|
||||
@ -8718,6 +8724,46 @@
|
||||
"once": "^1.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/engine.io-client": {
|
||||
"version": "6.4.0",
|
||||
"resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.4.0.tgz",
|
||||
"integrity": "sha512-GyKPDyoEha+XZ7iEqam49vz6auPnNJ9ZBfy89f+rMMas8AuiMWOZ9PVzu8xb9ZC6rafUqiGHSCfu22ih66E+1g==",
|
||||
"dependencies": {
|
||||
"@socket.io/component-emitter": "~3.1.0",
|
||||
"debug": "~4.3.1",
|
||||
"engine.io-parser": "~5.0.3",
|
||||
"ws": "~8.11.0",
|
||||
"xmlhttprequest-ssl": "~2.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/engine.io-client/node_modules/ws": {
|
||||
"version": "8.11.0",
|
||||
"resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz",
|
||||
"integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==",
|
||||
"engines": {
|
||||
"node": ">=10.0.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"bufferutil": "^4.0.1",
|
||||
"utf-8-validate": "^5.0.2"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"bufferutil": {
|
||||
"optional": true
|
||||
},
|
||||
"utf-8-validate": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/engine.io-parser": {
|
||||
"version": "5.0.6",
|
||||
"resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz",
|
||||
"integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==",
|
||||
"engines": {
|
||||
"node": ">=10.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/enhanced-resolve": {
|
||||
"version": "5.12.0",
|
||||
"resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.12.0.tgz",
|
||||
@ -18338,6 +18384,32 @@
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/socket.io-client": {
|
||||
"version": "4.6.1",
|
||||
"resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.6.1.tgz",
|
||||
"integrity": "sha512-5UswCV6hpaRsNg5kkEHVcbBIXEYoVbMQaHJBXJCyEQ+CiFPV1NIOY0XOFWG4XR4GZcB8Kn6AsRs/9cy9TbqVMQ==",
|
||||
"dependencies": {
|
||||
"@socket.io/component-emitter": "~3.1.0",
|
||||
"debug": "~4.3.2",
|
||||
"engine.io-client": "~6.4.0",
|
||||
"socket.io-parser": "~4.2.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/socket.io-parser": {
|
||||
"version": "4.2.2",
|
||||
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.2.tgz",
|
||||
"integrity": "sha512-DJtziuKypFkMMHCm2uIshOYC7QaylbtzQwiMYDuCKy3OPkjLzu4B2vAhTlqipRHHzrI0NJeBAizTK7X+6m1jVw==",
|
||||
"dependencies": {
|
||||
"@socket.io/component-emitter": "~3.1.0",
|
||||
"debug": "~4.3.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/sockjs": {
|
||||
"version": "0.3.24",
|
||||
"resolved": "https://registry.npmjs.org/sockjs/-/sockjs-0.3.24.tgz",
|
||||
@ -21030,6 +21102,14 @@
|
||||
"resolved": "https://registry.npmjs.org/xmlchars/-/xmlchars-2.2.0.tgz",
|
||||
"integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw=="
|
||||
},
|
||||
"node_modules/xmlhttprequest-ssl": {
|
||||
"version": "2.0.0",
|
||||
"resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz",
|
||||
"integrity": "sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A==",
|
||||
"engines": {
|
||||
"node": ">=0.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/xtend": {
|
||||
"version": "4.0.2",
|
||||
"resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz",
|
||||
chain-forge/package.json
@@ -33,6 +33,7 @@
"react-flow-renderer": "^10.3.17",
"react-plotly.js": "^2.6.0",
"react-scripts": "5.0.1",
"socket.io-client": "^4.6.1",
"styled-components": "^5.3.10",
"uuidv4": "^6.2.13",
"web-vitals": "^2.1.4",
@@ -74,7 +74,7 @@ const EvaluatorNode = ({ data, id }) => {
console.log(script_paths);
// Run evaluator in backend
const codeTextOnRun = codeText + '';
fetch(BASE_URL + 'execute', {
fetch(BASE_URL + 'app/execute', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@@ -155,6 +155,7 @@ const EvaluatorNode = ({ data, id }) => {
status={status}
alertModal={alertModal}
handleRunClick={handleRunClick}
runButtonTooltip="Run evaluator over inputs"
/>
<Handle
type="target"
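The only change here is that the evaluator node now POSTs to the app/-prefixed route served by the new flask_app.py. For reference, a sketch of calling that route directly with the requests package, assuming the Flask server's default port 8000; the ids and evaluator body are hypothetical, and per the backend implementation further down, the code string is exec'd (no sandboxing) and must define an evaluate function.

```python
# Hypothetical direct call to /app/execute (ids are examples only).
import requests

payload = {
    "id": "eval-node-1",                  # cache id under which evaluated results are stored
    "code": "def evaluate(response):\n    return len(response.text)",
    "responses": "prompt-node-1",          # cache id(s) of responses to evaluate
    "scope": "response",                   # or 'batch'
}
r = requests.post("http://localhost:8000/app/execute", json=payload)
print(r.json())
```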
@@ -32,7 +32,7 @@ const InspectorNode = ({ data, id }) => {
console.log(input_node_ids);

// Grab responses associated with those ids:
fetch(BASE_URL + 'grabResponses', {
fetch(BASE_URL + 'app/grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
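InspectorNode's fetch gets the same app/ prefix. The route simply returns cached responses by id; a sketch of an equivalent direct request (the id is a hypothetical example, default port assumed):

```python
# Hypothetical direct call to /app/grabResponses.
import requests

r = requests.post("http://localhost:8000/app/grabResponses",
                  json={"responses": ["prompt-node-1"]})
print(r.json()["responses"])
```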
@@ -1,8 +1,8 @@
import { useDisclosure } from '@mantine/hooks';
import { Modal, Button, Group } from '@mantine/core';
import { Modal, Button, Group, RingProgress } from '@mantine/core';
import { IconSettings, IconTrash } from '@tabler/icons-react';

export default function LLMItemButtonGroup( {onClickTrash, onClickSettings} ) {
export default function LLMItemButtonGroup( {onClickTrash, onClickSettings, ringProgress} ) {
const [opened, { open, close }] = useDisclosure(false);

return (
@@ -12,7 +12,13 @@ export default function LLMItemButtonGroup( {onClickTrash, onClickSettings} ) {
</Modal>

<Group position="right" style={{float: 'right', height:'20px'}}>
<Button onClick={onClickTrash} size="xs" variant="light" compact color="red" style={{padding: '0px'}} ><IconTrash size={"95%"} /></Button>
{ringProgress !== undefined ?
(ringProgress > 0 ?
(<RingProgress size={20} thickness={3} sections={[{ value: ringProgress, color: ringProgress < 99 ? 'blue' : 'green' }]} width='16px' />) :
(<div className="lds-ring"><div></div><div></div><div></div><div></div></div>))
: (<></>)
}
<Button onClick={onClickTrash} size="xs" variant="light" compact color="red" style={{padding: '0px'}} ><IconTrash size={"95%"} /></Button>
<Button onClick={onClickSettings} size="xs" variant="light" compact>Settings <IconSettings size={"110%"} /></Button>
</Group>
</div>
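The new ringProgress prop drives the live progress wheel: undefined hides it, 0 shows an indeterminate spinner, and values above 0 fill the ring (blue until it nears 100, then green). The value itself is computed in PromptNode.js as each LLM's share of the expected response count; a worked example of that arithmetic (numbers are illustrative):

```python
# Worked example of the per-LLM ring value computed in PromptNode.js.
max_responses = 20                 # total responses expected across all LLMs
num_llms = 2                       # e.g. gpt3.5 and gpt4
counts = {"gpt3.5": 7, "gpt4": 3}  # responses collected so far, per LLM

for llm, n in counts.items():
    ring = n / (max_responses / num_llms) * 100
    print(llm, ring)               # gpt3.5 -> 70.0, gpt4 -> 30.0
```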
@@ -63,7 +63,7 @@ export default function LLMList({llms, onItemsChange}) {
{items.map((item, index) => (
<Draggable key={item.key} draggableId={item.key} index={index}>
{(provided, snapshot) => (
<LLMListItem provided={provided} snapshot={snapshot} item={item} removeCallback={removeItem} />
<LLMListItem provided={provided} snapshot={snapshot} item={item} removeCallback={removeItem} progress={item.progress} />
)}
</Draggable>
))}
@@ -23,7 +23,7 @@ export const DragItem = styled.div`
flex-direction: column;
`;

const LLMListItem = ({ item, provided, snapshot, removeCallback }) => {
const LLMListItem = ({ item, provided, snapshot, removeCallback, progress }) => {
return (
<DragItem
ref={provided.innerRef}
@@ -33,7 +33,7 @@ const LLMListItem = ({ item, provided, snapshot, removeCallback }) => {
>
<div>
<CardHeader>{item.emoji} {item.name}</CardHeader>
<LLMItemButtonGroup onClickTrash={() => removeCallback(item.key)} />
<LLMItemButtonGroup onClickTrash={() => removeCallback(item.key)} ringProgress={progress} />
</div>

</DragItem>
@@ -4,9 +4,9 @@ import 'react-edit-text/dist/index.css';
import StatusIndicator from './StatusIndicatorComponent';
import AlertModal from './AlertModal';
import { useState, useEffect} from 'react';
import { CloseButton } from '@mantine/core';
import { Tooltip } from '@mantine/core';

export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, handleRunClick }) {
export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, handleRunClick, handleRunHover, runButtonTooltip }) {
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const [statusIndicator, setStatusIndicator] = useState('none');
const [runButton, setRunButton] = useState('none');
@@ -33,12 +33,20 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl

useEffect(() => {
if(handleRunClick !== undefined) {
setRunButton(<button className="AmitSahoo45-button-3 nodrag" onClick={handleRunClick}>▶</button>);
const run_btn = (<button className="AmitSahoo45-button-3 nodrag" onClick={handleRunClick} onPointerEnter={handleRunHover}>▶</button>);
if (runButtonTooltip)
setRunButton(
<Tooltip label={runButtonTooltip} withArrow arrowSize={6} arrowRadius={2}>
{run_btn}
</Tooltip>
);
else
setRunButton(run_btn);
}
else {
setRunButton(<></>);
}
}, [handleRunClick]);
}, [handleRunClick, runButtonTooltip]);

const handleCloseButtonClick = () => {
removeNode(nodeId);
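NodeLabel's new handleRunHover and runButtonTooltip props power the dynamic run tooltip: hovering the run button asks the backend how many queries a run would send, and the answer is shown in the tooltip. A sketch of the underlying request made by the PromptNode diff below (prompt, vars, and LLM names here are examples; the backend expands all template permutations and reports a count per LLM):

```python
# Hypothetical call to the route behind the run-button tooltip.
import requests

payload = {
    "prompt": "Tell me a joke about ${topic}.",
    "vars": {"topic": ["cats", "dogs", "compilers"]},
    "llms": ["gpt3.5", "gpt4"],
}
r = requests.post("http://localhost:8000/app/countQueriesRequired", json=payload)
print(r.json()["counts"])   # e.g. {'gpt3.5': 3, 'gpt4': 3}
```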
@ -1,14 +1,13 @@
|
||||
import React, { useEffect, useState, useRef, useCallback } from 'react';
|
||||
import { Handle } from 'react-flow-renderer';
|
||||
import { Menu, Badge } from '@mantine/core';
|
||||
import { Menu, Badge, Progress } from '@mantine/core';
|
||||
import { v4 as uuid } from 'uuid';
|
||||
import useStore from './store';
|
||||
import StatusIndicator from './StatusIndicatorComponent'
|
||||
import NodeLabel from './NodeLabelComponent'
|
||||
import TemplateHooks from './TemplateHooksComponent'
|
||||
import LLMList from './LLMListComponent'
|
||||
import AlertModal from './AlertModal'
|
||||
import {BASE_URL} from './store';
|
||||
import io from 'socket.io-client';
|
||||
|
||||
// Available LLMs
|
||||
const allLLMs = [
|
||||
@ -66,13 +65,29 @@ const PromptNode = ({ data, id }) => {
|
||||
// Selecting LLM models to prompt
|
||||
const [llmItems, setLLMItems] = useState(initLLMs.map((i, idx) => ({key: uuid(), ...i})));
|
||||
const [llmItemsCurrState, setLLMItemsCurrState] = useState([]);
|
||||
const resetLLMItemsProgress = useCallback(() => {
|
||||
setLLMItems(llmItemsCurrState.map(item => {
|
||||
item.progress = undefined;
|
||||
return item;
|
||||
}));
|
||||
}, [llmItemsCurrState]);
|
||||
|
||||
// Progress when querying responses
|
||||
const [progress, setProgress] = useState(100);
|
||||
const [runTooltip, setRunTooltip] = useState(null);
|
||||
|
||||
const triggerAlert = (msg) => {
|
||||
setProgress(100);
|
||||
resetLLMItemsProgress();
|
||||
alertModal.current.trigger(msg);
|
||||
};
|
||||
|
||||
const addModel = useCallback((model) => {
|
||||
// Get the item for that model
|
||||
let item = allLLMs.find(llm => llm.model === model);
|
||||
|
||||
if (!item) { // This should never trigger, but in case it does:
|
||||
alertModal.current.trigger(`Could not find model named '${model}' in list of available LLMs.`);
|
||||
triggerAlert(`Could not find model named '${model}' in list of available LLMs.`);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -116,6 +131,89 @@ const PromptNode = ({ data, id }) => {
|
||||
}
|
||||
};
|
||||
|
||||
// Pull all inputs needed to request responses.
|
||||
// Returns [prompt, vars dict]
|
||||
const pullInputData = () => {
|
||||
// Pull data from each source recursively:
|
||||
const pulled_data = {};
|
||||
const get_outputs = (varnames, nodeId) => {
|
||||
varnames.forEach(varname => {
|
||||
// Find the relevant edge(s):
|
||||
edges.forEach(e => {
|
||||
if (e.target == nodeId && e.targetHandle == varname) {
|
||||
// Get the immediate output:
|
||||
let out = output(e.source, e.sourceHandle);
|
||||
|
||||
// Save the var data from the pulled output
|
||||
if (varname in pulled_data)
|
||||
pulled_data[varname] = pulled_data[varname].concat(out);
|
||||
else
|
||||
pulled_data[varname] = out;
|
||||
|
||||
// Get any vars that the output depends on, and recursively collect those outputs as well:
|
||||
const n_vars = getNode(e.source).data.vars;
|
||||
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
|
||||
get_outputs(n_vars, e.source);
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
get_outputs(templateVars, id);
|
||||
|
||||
// Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
|
||||
const to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
|
||||
const py_prompt_template = to_py_template_format(promptText);
|
||||
|
||||
// Do the same for the vars, since vars can themselves be prompt templates:
|
||||
Object.keys(pulled_data).forEach(varname => {
|
||||
pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
|
||||
});
|
||||
|
||||
return [py_prompt_template, pulled_data];
|
||||
};
|
||||
|
||||
// Ask the backend how many responses it needs to collect, given the input data:
|
||||
const fetchResponseCounts = (prompt, vars, llms, rejected) => {
|
||||
return fetch(BASE_URL + 'app/countQueriesRequired', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
|
||||
body: JSON.stringify({
|
||||
prompt: prompt,
|
||||
vars: vars,
|
||||
llms: llms,
|
||||
})}, rejected).then(function(response) {
|
||||
return response.json();
|
||||
}, rejected).then(function(json) {
|
||||
if (!json || !json.counts) {
|
||||
throw new Error('Request was sent and received by backend server, but there was no response.');
|
||||
}
|
||||
return json.counts;
|
||||
}, rejected);
|
||||
};
|
||||
|
||||
// On hover over the 'Run' button, request how many responses are required and update the tooltip. Soft fails.
|
||||
const handleRunHover = () => {
|
||||
// Check if there's at least one model in the list; if not, nothing to run on.
|
||||
if (!llmItemsCurrState || llmItemsCurrState.length == 0) {
|
||||
setRunTooltip('No LLMs to query.');
|
||||
return;
|
||||
}
|
||||
|
||||
// Get input data and prompt
|
||||
const [py_prompt, pulled_vars] = pullInputData();
|
||||
const llms = llmItemsCurrState.map(item => item.model);
|
||||
const num_llms = llms.length;
|
||||
|
||||
// Fetch response counts from backend
|
||||
fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => {
|
||||
console.warn(err.message); // soft fail
|
||||
}).then((counts) => {
|
||||
const n = counts[Object.keys(counts)[0]];
|
||||
const req = n > 1 ? 'requests' : 'request';
|
||||
setRunTooltip(`Will send ${n} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
|
||||
});
|
||||
};
|
||||
|
||||
const handleRunClick = (event) => {
|
||||
// Go through all template hooks (if any) and check they're connected:
|
||||
const is_fully_connected = templateVars.every(varname => {
|
||||
@ -123,63 +221,113 @@ const PromptNode = ({ data, id }) => {
|
||||
return edges.some(e => (e.target == id && e.targetHandle == varname));
|
||||
});
|
||||
|
||||
// console.log(templateHooks);
|
||||
if (!is_fully_connected) {
|
||||
console.log('Not connected! :(');
|
||||
triggerAlert('Missing inputs to one or more template variables.');
|
||||
return;
|
||||
}
|
||||
|
||||
if (is_fully_connected) {
|
||||
console.log('Connected!');
|
||||
console.log('Connected!');
|
||||
|
||||
// Check that there is at least one LLM selected:
|
||||
if (llmItemsCurrState.length === 0) {
|
||||
alert('Please select at least one LLM to prompt.')
|
||||
return;
|
||||
}
|
||||
// Check that there is at least one LLM selected:
|
||||
if (llmItemsCurrState.length === 0) {
|
||||
alert('Please select at least one LLM to prompt.')
|
||||
return;
|
||||
}
|
||||
|
||||
// Set status indicator
|
||||
setStatus('loading');
|
||||
// Set status indicator
|
||||
setStatus('loading');
|
||||
setReponsePreviews([]);
|
||||
|
||||
// Pull data from each source, recursively:
|
||||
const pulled_data = {};
|
||||
const get_outputs = (varnames, nodeId) => {
|
||||
console.log(varnames);
|
||||
varnames.forEach(varname => {
|
||||
// Find the relevant edge(s):
|
||||
edges.forEach(e => {
|
||||
if (e.target == nodeId && e.targetHandle == varname) {
|
||||
// Get the immediate output:
|
||||
let out = output(e.source, e.sourceHandle);
|
||||
const [py_prompt_template, pulled_data] = pullInputData();
|
||||
|
||||
// Save the var data from the pulled output
|
||||
if (varname in pulled_data)
|
||||
pulled_data[varname] = pulled_data[varname].concat(out);
|
||||
else
|
||||
pulled_data[varname] = out;
|
||||
let FINISHED_QUERY = false;
|
||||
const rejected = (err) => {
|
||||
setStatus('error');
|
||||
triggerAlert(err.message);
|
||||
FINISHED_QUERY = true;
|
||||
};
|
||||
|
||||
// Get any vars that the output depends on, and recursively collect those outputs as well:
|
||||
const n_vars = getNode(e.source).data.vars;
|
||||
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
|
||||
get_outputs(n_vars, e.source);
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
get_outputs(templateVars, id);
|
||||
// Ask the backend to reset the scratchpad for counting queries:
|
||||
const create_progress_scratchpad = () => {
|
||||
return fetch(BASE_URL + 'app/createProgressFile', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
|
||||
body: JSON.stringify({
|
||||
id: id,
|
||||
})}, rejected);
|
||||
};
|
||||
|
||||
// Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
|
||||
const to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
|
||||
const py_prompt_template = to_py_template_format(promptText);
|
||||
// Fetch info about the number of queries we'll need to make
|
||||
const fetch_resp_count = () => fetchResponseCounts(
|
||||
py_prompt_template, pulled_data, llmItemsCurrState.map(item => item.model), rejected);
|
||||
|
||||
// Do the same for the vars, since vars can themselves be prompt templates:
|
||||
Object.keys(pulled_data).forEach(varname => {
|
||||
pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val));
|
||||
// Open a socket to listen for progress
|
||||
const open_progress_listener_socket = (response_counts) => {
|
||||
// With the counts information we can create progress bars. Now we load a socket connection to
|
||||
// the socketio server that will stream to us the current progress:
|
||||
const socket = io('http://localhost:8001/' + 'queryllm', {
|
||||
transports: ["websocket"],
|
||||
cors: {origin: "http://localhost:8000/"},
|
||||
});
|
||||
|
||||
const rejected = (err) => {
|
||||
setStatus('error');
|
||||
alertModal.current.trigger(err.message);
|
||||
};
|
||||
const max_responses = Object.keys(response_counts).reduce((acc, llm) => acc + response_counts[llm], 0);
|
||||
|
||||
// Run all prompt permutations through the LLM to generate + cache responses:
|
||||
fetch(BASE_URL + 'queryllm', {
|
||||
// On connect to the server, ask it to give us the current progress
|
||||
// for task 'queryllm' with id 'id', and stop when it reads progress >= 'max'.
|
||||
socket.on("connect", (msg) => {
|
||||
// Initialize progress bars to small amounts
|
||||
setProgress(5);
|
||||
setLLMItems(llmItemsCurrState.map(item => {
|
||||
item.progress = 0;
|
||||
return item;
|
||||
}));
|
||||
|
||||
// Request progress bar updates
|
||||
socket.emit("queryllm", {'id': id, 'max': max_responses});
|
||||
});
|
||||
|
||||
// Socket connection could not be established
|
||||
socket.on("connect_error", (error) => {
|
||||
console.log("Socket connection failed:", error.message);
|
||||
socket.disconnect();
|
||||
});
|
||||
|
||||
// Socket disconnected
|
||||
socket.on("disconnect", (msg) => {
|
||||
console.log(msg);
|
||||
});
|
||||
|
||||
// The current progress, a number specifying how many responses collected so far:
|
||||
socket.on("response", (counts) => {
|
||||
console.log(counts);
|
||||
if (!counts || FINISHED_QUERY) return;
|
||||
|
||||
// Update individual progress bars
|
||||
const num_llms = llmItemsCurrState.length;
|
||||
setLLMItems(llmItemsCurrState.map(item => {
|
||||
if (item.model in counts)
|
||||
item.progress = counts[item.model] / (max_responses / num_llms) * 100;
|
||||
return item;
|
||||
}));
|
||||
|
||||
// Update total progress bar
|
||||
const total_num_resps = Object.keys(counts).reduce((acc, llm_name) => {
|
||||
return acc + counts[llm_name];
|
||||
}, 0);
|
||||
setProgress(total_num_resps / max_responses * 100);
|
||||
});
|
||||
|
||||
// The process has finished; close the connection:
|
||||
socket.on("finish", (msg) => {
|
||||
console.log("finished:", msg);
|
||||
socket.disconnect();
|
||||
});
|
||||
};
|
||||
|
||||
// Run all prompt permutations through the LLM to generate + cache responses:
|
||||
const query_llms = () => {
|
||||
return fetch(BASE_URL + 'app/queryllm', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
|
||||
body: JSON.stringify({
|
||||
@ -191,7 +339,7 @@ const PromptNode = ({ data, id }) => {
|
||||
temperature: 0.5,
|
||||
n: numGenerations,
|
||||
},
|
||||
no_cache: false,
|
||||
no_cache: true,
|
||||
}),
|
||||
}, rejected).then(function(response) {
|
||||
return response.json();
|
||||
@ -205,6 +353,11 @@ const PromptNode = ({ data, id }) => {
|
||||
// Success! Change status to 'ready':
|
||||
setStatus('ready');
|
||||
|
||||
// Remove progress bars
|
||||
setProgress(100);
|
||||
resetLLMItemsProgress();
|
||||
FINISHED_QUERY = true;
|
||||
|
||||
// Save prompt text so we remember what prompt we have responses cache'd for:
|
||||
setPromptTextOnLastRun(promptText);
|
||||
|
||||
@ -246,14 +399,15 @@ const PromptNode = ({ data, id }) => {
|
||||
alertModal.current.trigger(json.error || 'Unknown error when querying LLM');
|
||||
}
|
||||
}, rejected);
|
||||
};
|
||||
|
||||
console.log(pulled_data);
|
||||
} else {
|
||||
console.log('Not connected! :(');
|
||||
alertModal.current.trigger('Missing inputs to one or more template variables.')
|
||||
|
||||
// TODO: Blink the names of unconnected params
|
||||
}
|
||||
// Now put it all together!
|
||||
create_progress_scratchpad()
|
||||
.then(fetch_resp_count)
|
||||
.then(open_progress_listener_socket)
|
||||
.then(query_llms)
|
||||
.catch(rejected);
|
||||
|
||||
}
|
||||
|
||||
const handleNumGenChange = (event) => {
|
||||
@ -279,6 +433,8 @@ const PromptNode = ({ data, id }) => {
|
||||
status={status}
|
||||
alertModal={alertModal}
|
||||
handleRunClick={handleRunClick}
|
||||
handleRunHover={handleRunHover}
|
||||
runButtonTooltip={runTooltip}
|
||||
/>
|
||||
<div className="input-field">
|
||||
<textarea
|
||||
@ -331,6 +487,7 @@ const PromptNode = ({ data, id }) => {
|
||||
<label htmlFor="alpaca.7B">Alpaca 7B</label> */}
|
||||
</div>
|
||||
</div>
|
||||
{progress < 100 ? (<Progress value={progress} animate />) : <></>}
|
||||
<div className="response-preview-container nowheel">
|
||||
{responsePreviews}
|
||||
</div>
|
||||
|
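While a run is in flight, PromptNode opens a Socket.IO connection on port 8001: on connect it emits 'queryllm' with the node id and the expected total, then receives 'response' events carrying a per-LLM count dict and a final 'finish' event. A sketch of the same handshake from Python using the python-socketio client package (an assumption, not part of this commit; the id and max are examples):

```python
# Hypothetical Python client for the /queryllm progress stream on port 8001.
import socketio

sio = socketio.Client()

@sio.on("connect", namespace="/queryllm")
def on_connect():
    # Ask the server to stream progress for this node id until 'max' responses are seen.
    sio.emit("queryllm", {"id": "prompt-node-1", "max": 20}, namespace="/queryllm")

@sio.on("response", namespace="/queryllm")
def on_response(counts):
    print("responses so far per LLM:", counts)   # e.g. {'gpt3.5': 7, 'gpt4': 3}

@sio.on("finish", namespace="/queryllm")
def on_finish(msg):
    print("finished:", msg)
    sio.disconnect()

sio.connect("http://localhost:8001", namespaces=["/queryllm"], transports=["websocket"])
sio.wait()
```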
@@ -47,7 +47,7 @@ const VisNode = ({ data, id }) => {
// Grab the input node ids
const input_node_ids = [data.input];

fetch(BASE_URL + 'grabResponses', {
fetch(BASE_URL + 'app/grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@@ -300,6 +300,10 @@
-moz-transition: all 0.2s ease-out;
transition: all 0.2s ease-out;
}
.AmitSahoo45-button-3:active {
background: #40a829;
color: yellow;
}

.AmitSahoo45-button-3:hover::before {
-moz-animation: sh02 0.5s 0s linear;
@@ -364,6 +368,10 @@
-moz-transition: all 0.2s ease-out;
transition: all 0.2s ease-out;
}
.close-button:active {
background: #660000;
color: yellow;
}

.close-button:hover::before {
-moz-animation: sh02 0.5s 0s linear;
@ -1,426 +1,91 @@
|
||||
import json, os, asyncio, sys, argparse
|
||||
import json, os, asyncio, sys, argparse, threading
|
||||
from dataclasses import dataclass
|
||||
from statistics import mean, median, stdev
|
||||
from flask import Flask, request, jsonify
|
||||
from flask_cors import CORS
|
||||
from flask_socketio import SocketIO
|
||||
from flask_app import run_server
|
||||
from promptengine.query import PromptLLM, PromptLLMDummy
|
||||
from promptengine.template import PromptTemplate, PromptPermutationGenerator
|
||||
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
# Setup the socketio app
|
||||
# BUILD_DIR = "../chain-forge/build"
|
||||
# STATIC_DIR = BUILD_DIR + '/static'
|
||||
app = Flask(__name__) #, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
|
||||
|
||||
LLM_NAME_MAP = {
|
||||
'gpt3.5': LLM.ChatGPT,
|
||||
'alpaca.7B': LLM.Alpaca7B,
|
||||
'gpt4': LLM.GPT4,
|
||||
}
|
||||
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
|
||||
# Initialize Socket.IO
|
||||
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")
|
||||
|
||||
@dataclass
|
||||
class ResponseInfo:
|
||||
"""Stores info about a single response. Passed to evaluator functions."""
|
||||
text: str
|
||||
prompt: str
|
||||
var: str
|
||||
llm: str
|
||||
# Set up CORS for specific routes
|
||||
# cors = CORS(app, resources={r"/api/*": {"origins": "*"}})
|
||||
|
||||
def __str__(self):
|
||||
return self.text
|
||||
# Wait a max of a full 3 minutes (180 seconds) for the response count to update, before exiting.
|
||||
MAX_WAIT_TIME = 180
|
||||
|
||||
def to_standard_format(r: dict) -> list:
|
||||
llm = LLM_NAME_MAP_INVERSE[r['llm']]
|
||||
resp_obj = {
|
||||
'vars': r['info'],
|
||||
'llm': llm,
|
||||
'prompt': r['prompt'],
|
||||
'responses': extract_responses(r, r['llm']),
|
||||
'tokens': r['response']['usage'] if 'usage' in r['response'] else {},
|
||||
}
|
||||
if 'eval_res' in r:
|
||||
resp_obj['eval_res'] = r['eval_res']
|
||||
return resp_obj
|
||||
def countdown():
|
||||
n = 10
|
||||
while n > 0:
|
||||
socketio.sleep(0.5)
|
||||
socketio.emit('response', n, namespace='/queryllm')
|
||||
n -= 1
|
||||
|
||||
def get_llm_of_response(response: dict) -> LLM:
|
||||
return LLM_NAME_MAP[response['llm']]
|
||||
@socketio.on('queryllm', namespace='/queryllm')
|
||||
def readCounts(data):
|
||||
id = data['id']
|
||||
max_count = data['max']
|
||||
tempfilepath = f'cache/_temp_{id}.txt'
|
||||
|
||||
def get_filenames_with_id(filenames: list, id: str) -> list:
|
||||
return [
|
||||
c for c in filenames
|
||||
if c.split('.')[0] == id or ('-' in c and c[:c.rfind('-')] == id)
|
||||
]
|
||||
# Check that temp file exists. If it doesn't, something went wrong with setup on Flask's end:
|
||||
if not os.path.exists(tempfilepath):
|
||||
print(f"Error: Temp file not found at path {tempfilepath}. Cannot stream querying progress.")
|
||||
socketio.emit('finish', 'temp file not found', namespace='/queryllm')
|
||||
|
||||
def remove_cached_responses(cache_id: str):
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
cache_files = get_filenames_with_id(all_cache_files, cache_id)
|
||||
for filename in cache_files:
|
||||
os.remove(os.path.join('cache/', filename))
|
||||
i = 0
|
||||
last_n = 0
|
||||
init_run = True
|
||||
while i < MAX_WAIT_TIME and last_n < max_count:
|
||||
|
||||
def load_cache_json(filepath: str) -> dict:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
responses = json.load(f)
|
||||
return responses
|
||||
|
||||
def run_over_responses(eval_func, responses: dict, scope: str) -> list:
|
||||
for prompt, resp_obj in responses.items():
|
||||
res = extract_responses(resp_obj, resp_obj['llm'])
|
||||
if scope == 'response':
|
||||
evals = [ # Run evaluator func over every individual response text
|
||||
eval_func(
|
||||
ResponseInfo(
|
||||
text=r,
|
||||
prompt=prompt,
|
||||
var=resp_obj['info'],
|
||||
llm=LLM_NAME_MAP_INVERSE[resp_obj['llm']])
|
||||
) for r in res
|
||||
]
|
||||
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
|
||||
'mean': mean(evals),
|
||||
'median': median(evals),
|
||||
'stdev': stdev(evals) if len(evals) > 1 else 0,
|
||||
'range': (min(evals), max(evals)),
|
||||
'items': evals,
|
||||
}
|
||||
else: # operate over the entire response batch
|
||||
ev = eval_func(res)
|
||||
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
|
||||
'mean': ev,
|
||||
'median': ev,
|
||||
'stdev': 0,
|
||||
'range': (ev, ev),
|
||||
'items': [ev],
|
||||
}
|
||||
return responses
|
||||
|
||||
def reduce_responses(responses: list, vars: list) -> list:
|
||||
if len(responses) == 0: return responses
|
||||
|
||||
# Figure out what vars we still care about (the ones we aren't reducing over):
|
||||
# NOTE: We are assuming all responses have the same 'vars' keys.
|
||||
all_vars = set(responses[0]['vars'])
|
||||
|
||||
if not all_vars.issuperset(set(vars)):
|
||||
# There's a var in vars which isn't part of the response.
|
||||
raise Exception(f"Some vars in {set(vars)} are not in the responses.")
|
||||
|
||||
# Get just the vars we want to keep around:
|
||||
include_vars = list(set(responses[0]['vars']) - set(vars))
|
||||
|
||||
# Bucket responses by the remaining var values, where tuples of vars are keys to a dict:
|
||||
# E.g. {(var1_val, var2_val): [responses] }
|
||||
bucketed_resp = {}
|
||||
for r in responses:
|
||||
print(r)
|
||||
tup_key = tuple([r['vars'][v] for v in include_vars])
|
||||
if tup_key in bucketed_resp:
|
||||
bucketed_resp[tup_key].append(r)
|
||||
else:
|
||||
bucketed_resp[tup_key] = [r]
|
||||
|
||||
# Perform reduce op across all bucketed responses, collecting them into a single 'meta'-response:
|
||||
ret = []
|
||||
for tup_key, resps in bucketed_resp.items():
|
||||
flat_eval_res = [item for r in resps for item in r['eval_res']['items']]
|
||||
ret.append({
|
||||
'vars': {v: r['vars'][v] for r in resps for v in include_vars},
|
||||
'llm': resps[0]['llm'],
|
||||
'prompt': [r['prompt'] for r in resps],
|
||||
'responses': [r['responses'] for r in resps],
|
||||
'tokens': resps[0]['tokens'],
|
||||
'eval_res': {
|
||||
'mean': mean(flat_eval_res),
|
||||
'median': median(flat_eval_res),
|
||||
'stdev': stdev(flat_eval_res) if len(flat_eval_res) > 1 else 0,
|
||||
'range': (min(flat_eval_res), max(flat_eval_res)),
|
||||
'items': flat_eval_res
|
||||
}
|
||||
})
|
||||
|
||||
return ret
|
||||
|
||||
@app.route('/test', methods=['GET'])
|
||||
def test():
|
||||
return "Hello, world!"
|
||||
|
||||
@app.route('/queryllm', methods=['POST'])
|
||||
async def queryLLM():
|
||||
"""
|
||||
Queries LLM(s) given a JSON spec.
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'id': str # a unique ID to refer to this information. Used when cache'ing responses.
|
||||
'llm': str | list # a string or list of strings specifying the LLM(s) to query
|
||||
'params': dict # an optional dict of any other params to set when querying the LLMs, like 'temperature', 'n' (num of responses per prompt), etc.
|
||||
'prompt': str # the prompt template, with any {{}} vars
|
||||
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
|
||||
'no_cache': bool (optional) # delete any cache'd responses for 'id' (always call the LLM fresh)
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
# Check that all required info is here:
|
||||
if not set(data.keys()).issuperset({'llm', 'prompt', 'vars', 'id'}):
|
||||
return jsonify({'error': 'POST data is improper format.'})
|
||||
elif not isinstance(data['id'], str) or len(data['id']) == 0:
|
||||
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
|
||||
|
||||
# Verify LLM name(s) (string or list) and convert to enum(s):
|
||||
if not (isinstance(data['llm'], list) or isinstance(data['llm'], str)) or (isinstance(data['llm'], list) and len(data['llm']) == 0):
|
||||
return jsonify({'error': 'POST data llm is improper format (not string or list, or of length 0).'})
|
||||
if isinstance(data['llm'], str):
|
||||
data['llm'] = [ data['llm'] ]
|
||||
llms = []
|
||||
for llm_str in data['llm']:
|
||||
if llm_str not in LLM_NAME_MAP:
|
||||
return jsonify({'error': f"LLM named '{llm_str}' is not supported."})
|
||||
llms.append(LLM_NAME_MAP[llm_str])
|
||||
|
||||
if 'no_cache' in data and data['no_cache'] is True:
|
||||
remove_cached_responses(data['id'])
|
||||
|
||||
# Create a cache dir if it doesn't exist:
|
||||
create_dir_if_not_exists('cache')
|
||||
|
||||
# For each LLM, generate and cache responses:
|
||||
responses = {}
|
||||
params = data['params'] if 'params' in data else {}
|
||||
|
||||
async def query(llm: str) -> list:
|
||||
# Check that storage path is valid:
|
||||
cache_filepath = os.path.join('cache', f"{data['id']}-{str(llm.name)}.json")
|
||||
if not is_valid_filepath(cache_filepath):
|
||||
return jsonify({'error': f'Invalid filepath: {cache_filepath}'})
|
||||
|
||||
# Create an object to query the LLM, passing a file for cache'ing responses
|
||||
prompter = PromptLLM(data['prompt'], storageFile=cache_filepath)
|
||||
|
||||
# Prompt the LLM with all permutations of the input prompt template:
|
||||
# NOTE: If the responses are already cache'd, this just loads them (no LLM is queried, saving $$$)
|
||||
resps = []
|
||||
try:
|
||||
print(f'Querying {llm}...')
|
||||
async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
|
||||
resps.append(response)
|
||||
except Exception as e:
|
||||
print('error generating responses:', e)
|
||||
raise e
|
||||
# Open the temp file to read the progress so far:
|
||||
try:
|
||||
with open(tempfilepath, 'r') as f:
|
||||
queries = json.load(f)
|
||||
except FileNotFoundError as e:
|
||||
# If the temp file was deleted during executing, the Flask 'queryllm' func must've terminated successfully:
|
||||
socketio.emit('finish', 'success', namespace='/queryllm')
|
||||
return
|
||||
|
||||
return {'llm': llm, 'responses': resps}
|
||||
# Calculate the total sum of responses
|
||||
# TODO: This is a naive approach; we need to make this more complex and factor in cache'ing in future
|
||||
n = sum([int(n) for llm, n in queries.items()])
|
||||
|
||||
# If something's changed...
|
||||
if init_run or last_n != n:
|
||||
i = 0
|
||||
last_n = n
|
||||
init_run = False
|
||||
|
||||
# Update the React front-end with the current progress
|
||||
socketio.emit('response', queries, namespace='/queryllm')
|
||||
|
||||
try:
|
||||
# Request responses simultaneously across LLMs
|
||||
tasks = [query(llm) for llm in llms]
|
||||
else:
|
||||
i += 0.1
|
||||
|
||||
# Wait a bit before reading the file again
|
||||
socketio.sleep(0.1)
|
||||
|
||||
# Await the responses from all queried LLMs
|
||||
llm_results = await asyncio.gather(*tasks)
|
||||
for item in llm_results:
|
||||
responses[item['llm']] = item['responses']
|
||||
if i >= MAX_WAIT_TIME:
|
||||
print(f"Error: Waited maximum {MAX_WAIT_TIME} seconds for response count to update. Exited prematurely.")
|
||||
socketio.emit('finish', 'max_wait_reached', namespace='/queryllm')
|
||||
else:
|
||||
print("All responses loaded!")
|
||||
socketio.emit('finish', 'success', namespace='/queryllm')
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)})
|
||||
def run_socketio_server(socketio, port):
|
||||
socketio.run(app, host="localhost", port=8001)
|
||||
|
||||
# Convert the responses into a more standardized format with less information
|
||||
res = [
|
||||
to_standard_format(r)
|
||||
for rs in responses.values()
|
||||
for r in rs
|
||||
]
|
||||
|
||||
# Return all responses for all LLMs
|
||||
print('returning responses:', res)
|
||||
ret = jsonify({'responses': res})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
@app.route('/execute', methods=['POST'])
|
||||
def execute():
|
||||
"""
|
||||
Executes a Python lambda function sent from JavaScript,
|
||||
over all cache'd responses with given id's.
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'id': # a unique ID to refer to this information. Used when cache'ing responses.
|
||||
'code': str, # the body of the lambda function to evaluate, in form: lambda responses: <body>
|
||||
'responses': str | List[str] # the responses to run on; a unique ID or list of unique IDs of cache'd data,
|
||||
'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch.
|
||||
# If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt.
|
||||
'reduce_vars': unspecified | List[str] # the 'vars' to average over (mean, median, stdev, range)
|
||||
'script_paths': unspecified | List[str] # the paths to scripts to be added to the path before the lambda function is evaluated
|
||||
}
|
||||
|
||||
NOTE: This should only be run on your server on code you trust.
|
||||
There is no sandboxing; no safety. We assume you are the creator of the code.
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
# Check that all required info is here:
|
||||
if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope'}):
|
||||
return jsonify({'error': 'POST data is improper format.'})
|
||||
if not isinstance(data['id'], str) or len(data['id']) == 0:
|
||||
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
|
||||
if data['scope'] not in ('response', 'batch'):
|
||||
return jsonify({'error': "POST data scope is unknown. Must be either 'response' or 'batch'."})
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Check that the filepath used to cache eval'd responses is valid:
|
||||
cache_filepath = os.path.join('cache', f"{data['id']}.json")
|
||||
if not is_valid_filepath(cache_filepath):
|
||||
return jsonify({'error': f'Invalid filepath: {cache_filepath}'})
|
||||
|
||||
# Check format of responses:
|
||||
if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
|
||||
return jsonify({'error': 'POST data responses is improper format.'})
|
||||
elif isinstance(data['responses'], str):
|
||||
data['responses'] = [ data['responses'] ]
|
||||
|
||||
# add the path to any scripts to the path:
|
||||
try:
|
||||
if 'script_paths' in data:
|
||||
for script_path in data['script_paths']:
|
||||
# get the folder of the script_path:
|
||||
script_folder = os.path.dirname(script_path)
|
||||
# check that the script_folder is valid, and it contains __init__.py
|
||||
if not os.path.exists(script_folder):
|
||||
print(script_folder, 'is not a valid script path.')
|
||||
print(os.path.exists(script_folder))
|
||||
continue
|
||||
|
||||
# add it to the path:
|
||||
sys.path.append(script_folder)
|
||||
print(f'added {script_folder} to sys.path')
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'})
|
||||
|
||||
# Create the evaluator function
|
||||
# DANGER DANGER!
|
||||
try:
|
||||
exec(data['code'], globals())
|
||||
|
||||
# Double-check that there is an 'evaluate' method in our namespace.
|
||||
# This will throw a NameError if not:
|
||||
evaluate
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})
|
||||
|
||||
# Load all responses with the given ID:
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
all_evald_responses = []
|
||||
for cache_id in data['responses']:
|
||||
cache_files = get_filenames_with_id(all_cache_files, cache_id)
|
||||
if len(cache_files) == 0:
|
||||
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
|
||||
|
||||
# To avoid loading all response files into memory at once, we'll run the evaluator on each file:
|
||||
for filename in cache_files:
|
||||
|
||||
# Load the raw responses from the cache
|
||||
responses = load_cache_json(os.path.join('cache', filename))
|
||||
if len(responses) == 0: continue
|
||||
|
||||
# Run the evaluator over them:
|
||||
# NOTE: 'evaluate' here was defined dynamically from 'exec' above.
|
||||
try:
|
||||
evald_responses = run_over_responses(evaluate, responses, scope=data['scope'])
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}'})
|
||||
|
||||
# Convert to standard format:
|
||||
std_evald_responses = [
|
||||
to_standard_format({'prompt': prompt, **res_obj})
|
||||
for prompt, res_obj in evald_responses.items()
|
||||
]
|
||||
|
||||
# Perform any reduction operations:
|
||||
if 'reduce_vars' in data and len(data['reduce_vars']) > 0:
|
||||
std_evald_responses = reduce_responses(
|
||||
std_evald_responses,
|
||||
vars=data['reduce_vars']
|
||||
)
|
||||
|
||||
all_evald_responses.extend(std_evald_responses)
|
||||
|
||||
# Store the evaluated responses in a new cache json:
|
||||
with open(cache_filepath, "w") as f:
|
||||
json.dump(all_evald_responses, f)
|
||||
|
||||
ret = jsonify({'responses': all_evald_responses})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
@app.route('/checkEvalFunc', methods=['POST'])
|
||||
def checkEvalFunc():
|
||||
"""
|
||||
Tries to compile a Python lambda function sent from JavaScript.
|
||||
Returns a dict with 'result':true if it compiles without raising an exception;
|
||||
'result':false (and an 'error' property with a message) if not.
|
||||
|
||||
POST'd data should be in form:
|
||||
{
|
||||
'code': str, # the body of the lambda function to evaluate, in form: lambda responses: <body>
|
||||
}
|
||||
|
||||
NOTE: This should only be run on your server on code you trust.
|
||||
There is no sandboxing; no safety. We assume you are the creator of the code.
|
||||
"""
|
||||
data = request.get_json()
|
||||
if 'code' not in data:
|
||||
return jsonify({'result': False, 'error': f'Could not evaluate code. Error message:\n{str(e)}'})
|
||||
|
||||
# DANGER DANGER! Running exec on code passed through front-end. Make sure it's trusted!
|
||||
try:
|
||||
exec(data['code'], globals())
|
||||
|
||||
# Double-check that there is an 'evaluate' method in our namespace.
|
||||
# This will throw a NameError if not:
|
||||
evaluate
|
||||
return jsonify({'result': True})
|
||||
except Exception as e:
|
||||
return jsonify({'result': False, 'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})
|
||||
|
||||
@app.route('/grabResponses', methods=['POST'])
|
||||
def grabResponses():
|
||||
"""
|
||||
Returns all responses with the specified id(s)
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'responses': <the ids to grab>
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
# Check format of responses:
|
||||
if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
|
||||
return jsonify({'error': 'POST data responses is improper format.'})
|
||||
elif isinstance(data['responses'], str):
|
||||
data['responses'] = [ data['responses'] ]
|
||||
|
||||
# Load all responses with the given ID:
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
print(all_cache_files)
|
||||
responses = []
|
||||
for cache_id in data['responses']:
|
||||
cache_files = get_filenames_with_id(all_cache_files, cache_id)
|
||||
if len(cache_files) == 0:
|
||||
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
|
||||
|
||||
for filename in cache_files:
|
||||
res = load_cache_json(os.path.join('cache', filename))
|
||||
if isinstance(res, dict):
|
||||
# Convert to standard response format
|
||||
res = [
|
||||
to_standard_format({'prompt': prompt, **res_obj})
|
||||
for prompt, res_obj in res.items()
|
||||
]
|
||||
responses.extend(res)
|
||||
|
||||
print(responses)
|
||||
ret = jsonify({'responses': responses})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='This script spins up a Flask server that serves as the backend for ChainForge')
|
||||
|
||||
# Turn on to disable all outbound LLM API calls and replace them with dummy calls
|
||||
@ -430,10 +95,15 @@ if __name__ == '__main__':
|
||||
Produces each dummy response at random intervals between 0.1 and 3 seconds.""",
|
||||
dest='dummy_responses',
|
||||
action='store_true')
|
||||
parser.add_argument('--port', help='The port to run the server on. Defaults to 8000.', type=int, default=8000, nargs='?')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dummy_responses:
|
||||
PromptLLM = PromptLLMDummy
|
||||
extract_responses = lambda r, llm: r['response']
|
||||
port = args.port if args.port else 8000
|
||||
|
||||
app.run(host="localhost", port=8000, debug=True)
|
||||
# Spin up separate thread for socketio app, on port+1 (8001 default)
|
||||
print(f"Serving SocketIO server on port {port+1}...")
|
||||
t1 = threading.Thread(target=run_socketio_server, args=[socketio, port+1])
|
||||
t1.start()
|
||||
|
||||
print(f"Serving Flask server on port {port}...")
|
||||
run_server(host="localhost", port=port, cmd_args=args)
|
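The two servers coordinate through a small scratch file rather than shared state: /app/createProgressFile writes cache/_temp_<id>.txt with an empty JSON dict, /app/queryllm rewrites it with per-LLM response counts as results stream in (and deletes it when done), and the SocketIO 'queryllm' handler above polls it, summing the counts it emits. A sketch of that protocol (filename pattern and JSON shape are from the diff; the id and counts are hypothetical examples):

```python
# Sketch of the scratch-file protocol linking the Flask and SocketIO servers.
import json, os

node_id = "prompt-node-1"
tempfilepath = f"cache/_temp_{node_id}.txt"

# /app/createProgressFile: start with an empty dict.
os.makedirs("cache", exist_ok=True)
with open(tempfilepath, "w") as f:
    json.dump({}, f)

# /app/queryllm: rewrite the file with per-LLM counts as responses arrive.
with open(tempfilepath, "w") as f:
    json.dump({"gpt3.5": 3, "gpt4": 1}, f)

# SocketIO 'queryllm' handler: poll the file and emit the dict until the total
# reaches 'max'; a missing file is treated as a successful finish.
with open(tempfilepath) as f:
    queries = json.load(f)
total = sum(int(n) for n in queries.values())
print(queries, total)   # {'gpt3.5': 3, 'gpt4': 1} 4
```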
python-backend/flask_app.py (new file): 512 lines
@@ -0,0 +1,512 @@
|
||||
import json, os, asyncio, sys, argparse, threading
|
||||
from dataclasses import dataclass
|
||||
from statistics import mean, median, stdev
|
||||
from flask import Flask, request, jsonify, render_template, send_from_directory
|
||||
from flask_cors import CORS
|
||||
from flask_socketio import SocketIO
|
||||
from promptengine.query import PromptLLM, PromptLLMDummy
|
||||
from promptengine.template import PromptTemplate, PromptPermutationGenerator
|
||||
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
|
||||
|
||||
# Setup Flask app to serve static version of React front-end
|
||||
BUILD_DIR = "../chain-forge/build"
|
||||
STATIC_DIR = BUILD_DIR + '/static'
|
||||
app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
|
||||
|
||||
# Set up CORS for specific routes
|
||||
cors = CORS(app, resources={r"/*": {"origins": "*"}})
|
||||
|
||||
# Serve React app (static; no hot reloading)
|
||||
@app.route("/")
|
||||
def index():
|
||||
return render_template("index.html")
|
||||
|
||||
LLM_NAME_MAP = {
|
||||
'gpt3.5': LLM.ChatGPT,
|
||||
'alpaca.7B': LLM.Alpaca7B,
|
||||
'gpt4': LLM.GPT4,
|
||||
}
|
||||
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
|
||||
|
||||
@dataclass
|
||||
class ResponseInfo:
|
||||
"""Stores info about a single response. Passed to evaluator functions."""
|
||||
text: str
|
||||
prompt: str
|
||||
var: str
|
||||
llm: str
|
||||
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
def to_standard_format(r: dict) -> list:
|
||||
llm = LLM_NAME_MAP_INVERSE[r['llm']]
|
||||
resp_obj = {
|
||||
'vars': r['info'],
|
||||
'llm': llm,
|
||||
'prompt': r['prompt'],
|
||||
'responses': extract_responses(r, r['llm']),
|
||||
'tokens': r['response']['usage'] if 'usage' in r['response'] else {},
|
||||
}
|
||||
if 'eval_res' in r:
|
||||
resp_obj['eval_res'] = r['eval_res']
|
||||
return resp_obj
|
||||
|
||||
def get_llm_of_response(response: dict) -> LLM:
|
||||
return LLM_NAME_MAP[response['llm']]
|
||||
|
||||
def get_filenames_with_id(filenames: list, id: str) -> list:
|
||||
return [
|
||||
c for c in filenames
|
||||
if c.split('.')[0] == id or ('-' in c and c[:c.rfind('-')] == id)
|
||||
]
|
||||
|
||||
def remove_cached_responses(cache_id: str):
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
cache_files = get_filenames_with_id(all_cache_files, cache_id)
|
||||
for filename in cache_files:
|
||||
os.remove(os.path.join('cache/', filename))
|
||||
|
||||
def load_cache_json(filepath: str) -> dict:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
responses = json.load(f)
|
||||
return responses
|
||||
|
||||
def run_over_responses(eval_func, responses: dict, scope: str) -> list:
|
||||
for prompt, resp_obj in responses.items():
|
||||
res = extract_responses(resp_obj, resp_obj['llm'])
|
||||
if scope == 'response':
|
||||
evals = [ # Run evaluator func over every individual response text
|
||||
eval_func(
|
||||
ResponseInfo(
|
||||
text=r,
|
||||
prompt=prompt,
|
||||
var=resp_obj['info'],
|
||||
llm=LLM_NAME_MAP_INVERSE[resp_obj['llm']])
|
||||
) for r in res
|
||||
]
|
||||
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
|
||||
'mean': mean(evals),
|
||||
'median': median(evals),
|
||||
'stdev': stdev(evals) if len(evals) > 1 else 0,
|
||||
'range': (min(evals), max(evals)),
|
||||
'items': evals,
|
||||
}
|
||||
else: # operate over the entire response batch
|
||||
ev = eval_func(res)
|
||||
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
|
||||
'mean': ev,
|
||||
'median': ev,
|
||||
'stdev': 0,
|
||||
'range': (ev, ev),
|
||||
'items': [ev],
|
||||
}
|
||||
return responses
|
||||
|
||||
def reduce_responses(responses: list, vars: list) -> list:
|
||||
if len(responses) == 0: return responses
|
||||
|
||||
# Figure out what vars we still care about (the ones we aren't reducing over):
|
||||
# NOTE: We are assuming all responses have the same 'vars' keys.
|
||||
all_vars = set(responses[0]['vars'])
|
||||
|
||||
if not all_vars.issuperset(set(vars)):
|
||||
# There's a var in vars which isn't part of the response.
|
||||
raise Exception(f"Some vars in {set(vars)} are not in the responses.")
|
||||
|
||||
# Get just the vars we want to keep around:
|
||||
include_vars = list(set(responses[0]['vars']) - set(vars))
|
||||
|
||||
# Bucket responses by the remaining var values, where tuples of vars are keys to a dict:
|
||||
# E.g. {(var1_val, var2_val): [responses] }
|
||||
bucketed_resp = {}
|
||||
for r in responses:
|
||||
tup_key = tuple([r['vars'][v] for v in include_vars])
|
||||
if tup_key in bucketed_resp:
|
||||
bucketed_resp[tup_key].append(r)
|
||||
else:
|
||||
bucketed_resp[tup_key] = [r]
|
||||
|
||||
# Perform reduce op across all bucketed responses, collecting them into a single 'meta'-response:
|
||||
ret = []
|
||||
for tup_key, resps in bucketed_resp.items():
|
||||
flat_eval_res = [item for r in resps for item in r['eval_res']['items']]
|
||||
ret.append({
|
||||
'vars': {v: r['vars'][v] for r in resps for v in include_vars},
|
||||
'llm': resps[0]['llm'],
|
||||
'prompt': [r['prompt'] for r in resps],
|
||||
'responses': [r['responses'] for r in resps],
|
||||
'tokens': resps[0]['tokens'],
|
||||
'eval_res': {
|
||||
'mean': mean(flat_eval_res),
|
||||
'median': median(flat_eval_res),
|
||||
'stdev': stdev(flat_eval_res) if len(flat_eval_res) > 1 else 0,
|
||||
'range': (min(flat_eval_res), max(flat_eval_res)),
|
||||
'items': flat_eval_res
|
||||
}
|
||||
})
|
||||
|
||||
return ret
|
||||
|
||||
@app.route('/app/countQueriesRequired', methods=['POST'])
|
||||
def countQueries():
|
||||
"""
|
||||
Returns how many queries we need to make, given the passed prompt and vars.
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'prompt': str # the prompt template, with any {{}} vars
|
||||
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
|
||||
'llms': list # the list of LLMs you will query
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms'}):
|
||||
return jsonify({'error': 'POST data is improper format.'})
|
||||
|
||||
try:
|
||||
gen_prompts = PromptPermutationGenerator(PromptTemplate(data['prompt']))
|
||||
all_prompt_permutations = list(gen_prompts(data['vars']))
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)})
|
||||
|
||||
# TODO: Send more informative data back including how many queries per LLM based on cache'd data
|
||||
num_queries = {} # len(all_prompt_permutations) * len(data['llms'])
|
||||
for llm in data['llms']:
|
||||
num_queries[llm] = len(all_prompt_permutations)
|
||||
|
||||
ret = jsonify({'counts': num_queries})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
@app.route('/app/createProgressFile', methods=['POST'])
|
||||
def createProgressFile():
|
||||
"""
|
||||
Creates a temp txt file for storing progress of async LLM queries.
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'id': str # a unique ID that will be used when calling 'queryllm'
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
if 'id' not in data or not isinstance(data['id'], str) or len(data['id']) == 0:
|
||||
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
|
||||
|
||||
# Create a scratch file for keeping track of how many responses loaded
|
||||
try:
|
||||
with open(f"cache/_temp_{data['id']}.txt", 'w') as f:
|
||||
json.dump({}, f)
|
||||
ret = jsonify({'success': True})
|
||||
except Exception as e:
|
||||
ret = jsonify({'success': False, 'error': str(e)})
|
||||
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
# @socketio.on('connect', namespace='/queryllm')
@app.route('/app/queryllm', methods=['POST'])
async def queryLLM():
    """
        Queries LLM(s) given a JSON spec.

        POST'd data should be in the form:
        {
            'id': str # a unique ID to refer to this information. Used when cache'ing responses.
            'llm': str | list # a string or list of strings specifying the LLM(s) to query
            'params': dict # an optional dict of any other params to set when querying the LLMs, like 'temperature', 'n' (num of responses per prompt), etc.
            'prompt': str # the prompt template, with any {{}} vars
            'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
            'no_cache': bool (optional) # delete any cache'd responses for 'id' (always call the LLM fresh)
        }
    """
    data = request.get_json()

    # Check that all required info is here:
    if not set(data.keys()).issuperset({'llm', 'prompt', 'vars', 'id'}):
        return jsonify({'error': 'POST data is improper format.'})
    elif not isinstance(data['id'], str) or len(data['id']) == 0:
        return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})

    # Verify LLM name(s) (string or list) and convert to enum(s):
    if not (isinstance(data['llm'], list) or isinstance(data['llm'], str)) or (isinstance(data['llm'], list) and len(data['llm']) == 0):
        return jsonify({'error': 'POST data llm is improper format (not string or list, or of length 0).'})
    if isinstance(data['llm'], str):
        data['llm'] = [ data['llm'] ]

    for llm_str in data['llm']:
        if llm_str not in LLM_NAME_MAP:
            return jsonify({'error': f"LLM named '{llm_str}' is not supported."})

    if 'no_cache' in data and data['no_cache'] is True:
        remove_cached_responses(data['id'])

    # Create a cache dir if it doesn't exist:
    create_dir_if_not_exists('cache')

    # For each LLM, generate and cache responses:
    responses = {}
    llms = data['llm']
    params = data['params'] if 'params' in data else {}
    tempfilepath = f"cache/_temp_{data['id']}.txt"

    async def query(llm_str: str) -> list:
        llm = LLM_NAME_MAP[llm_str]

        # Check that storage path is valid:
        cache_filepath = os.path.join('cache', f"{data['id']}-{str(llm.name)}.json")
        if not is_valid_filepath(cache_filepath):
            return jsonify({'error': f'Invalid filepath: {cache_filepath}'})

        # Create an object to query the LLM, passing a file for cache'ing responses
        prompter = PromptLLM(data['prompt'], storageFile=cache_filepath)

        # Prompt the LLM with all permutations of the input prompt template:
        # NOTE: If the responses are already cache'd, this just loads them (no LLM is queried, saving $$$)
        resps = []
        try:
            print(f'Querying {llm}...')
            async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
                resps.append(response)
                print(f"collected response from {llm.name}:", str(response))

                # Save the number of responses collected to a temp file on disk
                with open(tempfilepath, 'r') as f:
                    txt = f.read().strip()

                cur_data = json.loads(txt) if len(txt) > 0 else {}
                cur_data[llm_str] = len(resps)

                with open(tempfilepath, 'w') as f:
                    json.dump(cur_data, f)
        except Exception as e:
            print('error generating responses:', e)
            raise e

        return {'llm': llm, 'responses': resps}

    try:
        # Request responses simultaneously across LLMs
        tasks = [query(llm) for llm in llms]

        # Await the responses from all queried LLMs
        llm_results = await asyncio.gather(*tasks)
        for item in llm_results:
            responses[item['llm']] = item['responses']

    except Exception as e:
        return jsonify({'error': str(e)})

    # Convert the responses into a more standardized format with less information
    res = [
        to_standard_format(r)
        for rs in responses.values()
        for r in rs
    ]

    # Remove the temp file used to stream progress updates:
    if os.path.exists(tempfilepath):
        os.remove(tempfilepath)

    # Return all responses for all LLMs
    print('returning responses:', res)
    ret = jsonify({'responses': res})
    ret.headers.add('Access-Control-Allow-Origin', '*')
    return ret

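A minimal sketch of posting to /app/queryllm, mirroring the test_query() call in the test page removed later in this commit (the LLM name, prompt template syntax, and variables are taken from there). Assumes a local server on port 8000 and the requests library:

# Hypothetical client call to the queryllm route above. Each var permutation
# (here: 3 countries x 1 LLM) yields one standardized response object.
import requests

payload = {
    'id': 'test',
    'llm': 'gpt3.5',
    'params': {'temperature': 1.0, 'n': 1},
    'prompt': 'What is the capital of ${country}?',
    'vars': {'country': ['Sweden', 'Uganda', 'Japan']},
}
r = requests.post('http://localhost:8000/app/queryllm', json=payload)
print(r.json().get('responses', r.json()))  # list of standardized responses, or an 'error'
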
@app.route('/app/execute', methods=['POST'])
def execute():
    """
        Executes a Python lambda function sent from JavaScript,
        over all cache'd responses with given id's.

        POST'd data should be in the form:
        {
            'id': # a unique ID to refer to this information. Used when cache'ing responses.
            'code': str, # the body of the lambda function to evaluate, in form: lambda responses: <body>
            'responses': str | List[str] # the responses to run on; a unique ID or list of unique IDs of cache'd data,
            'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch.
                                          # If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt.
            'reduce_vars': unspecified | List[str] # the 'vars' to average over (mean, median, stdev, range)
            'script_paths': unspecified | List[str] # the paths to scripts to be added to the path before the lambda function is evaluated
        }

        NOTE: This should only be run on your server on code you trust.
              There is no sandboxing; no safety. We assume you are the creator of the code.
    """
    data = request.get_json()

    # Check that all required info is here:
    if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope'}):
        return jsonify({'error': 'POST data is improper format.'})
    if not isinstance(data['id'], str) or len(data['id']) == 0:
        return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
    if data['scope'] not in ('response', 'batch'):
        return jsonify({'error': "POST data scope is unknown. Must be either 'response' or 'batch'."})

    # Check that the filepath used to cache eval'd responses is valid:
    cache_filepath = os.path.join('cache', f"{data['id']}.json")
    if not is_valid_filepath(cache_filepath):
        return jsonify({'error': f'Invalid filepath: {cache_filepath}'})

    # Check format of responses:
    if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
        return jsonify({'error': 'POST data responses is improper format.'})
    elif isinstance(data['responses'], str):
        data['responses'] = [ data['responses'] ]

    # add the path to any scripts to the path:
    try:
        if 'script_paths' in data:
            for script_path in data['script_paths']:
                # get the folder of the script_path:
                script_folder = os.path.dirname(script_path)
                # check that the script_folder is valid, and it contains __init__.py
                if not os.path.exists(script_folder):
                    print(script_folder, 'is not a valid script path.')
                    continue

                # add it to the path:
                sys.path.append(script_folder)
                print(f'added {script_folder} to sys.path')
    except Exception as e:
        return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'})

    # Create the evaluator function
    # DANGER DANGER!
    try:
        exec(data['code'], globals())

        # Double-check that there is an 'evaluate' method in our namespace.
        # This will throw a NameError if not:
        evaluate
    except Exception as e:
        return jsonify({'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})

    # Load all responses with the given ID:
    all_cache_files = get_files_at_dir('cache/')
    all_evald_responses = []
    for cache_id in data['responses']:
        cache_files = get_filenames_with_id(all_cache_files, cache_id)
        if len(cache_files) == 0:
            return jsonify({'error': f'Did not find cache file for id {cache_id}'})

        # To avoid loading all response files into memory at once, we'll run the evaluator on each file:
        for filename in cache_files:

            # Load the raw responses from the cache
            responses = load_cache_json(os.path.join('cache', filename))
            if len(responses) == 0: continue

            # Run the evaluator over them:
            # NOTE: 'evaluate' here was defined dynamically from 'exec' above.
            try:
                evald_responses = run_over_responses(evaluate, responses, scope=data['scope'])
            except Exception as e:
                return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}'})

            # Convert to standard format:
            std_evald_responses = [
                to_standard_format({'prompt': prompt, **res_obj})
                for prompt, res_obj in evald_responses.items()
            ]

            # Perform any reduction operations:
            if 'reduce_vars' in data and len(data['reduce_vars']) > 0:
                std_evald_responses = reduce_responses(
                    std_evald_responses,
                    vars=data['reduce_vars']
                )

            all_evald_responses.extend(std_evald_responses)

    # Store the evaluated responses in a new cache json:
    with open(cache_filepath, "w") as f:
        json.dump(all_evald_responses, f)

    ret = jsonify({'responses': all_evald_responses})
    ret.headers.add('Access-Control-Allow-Origin', '*')
    return ret

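A minimal sketch of posting evaluator code to /app/execute. Because the route exec's the posted string and then looks up an evaluate name, this example sends a full function definition rather than a bare lambda body; whether the front end wraps the body itself happens outside this excerpt. Assumes a local server on port 8000 and responses already cached under id 'test':

# Hypothetical client call to the execute route above: score each cached
# response for id 'test' by its length, and store the results under id 'eval'.
import requests

payload = {
    'id': 'eval',
    'code': 'def evaluate(response):\n    return len(response)',
    'responses': 'test',
    'scope': 'response',
}
r = requests.post('http://localhost:8000/app/execute', json=payload)
print(r.json())  # evaluated responses; also written to cache/eval.json
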
@app.route('/app/checkEvalFunc', methods=['POST'])
def checkEvalFunc():
    """
        Tries to compile a Python lambda function sent from JavaScript.
        Returns a dict with 'result':true if it compiles without raising an exception;
        'result':false (and an 'error' property with a message) if not.

        POST'd data should be in form:
        {
            'code': str, # the body of the lambda function to evaluate, in form: lambda responses: <body>
        }

        NOTE: This should only be run on your server on code you trust.
              There is no sandboxing; no safety. We assume you are the creator of the code.
    """
    data = request.get_json()
    if 'code' not in data:
        return jsonify({'result': False, 'error': 'POST data is missing a "code" field.'})

    # DANGER DANGER! Running exec on code passed through front-end. Make sure it's trusted!
    try:
        exec(data['code'], globals())

        # Double-check that there is an 'evaluate' method in our namespace.
        # This will throw a NameError if not:
        evaluate
        return jsonify({'result': True})
    except Exception as e:
        return jsonify({'result': False, 'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})

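A minimal sketch of checking evaluator code before running it, under the same assumptions as the execute example above (local server on port 8000, code that defines an evaluate function):

# Hypothetical client calls: a well-formed evaluator and a deliberately broken one.
import requests

good = {'code': 'def evaluate(response):\n    return len(response)'}
bad = {'code': 'def evaluate(response:'}  # syntax error on purpose

print(requests.post('http://localhost:8000/app/checkEvalFunc', json=good).json())  # {'result': True}
print(requests.post('http://localhost:8000/app/checkEvalFunc', json=bad).json())   # {'result': False, 'error': ...}
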
@app.route('/app/grabResponses', methods=['POST'])
def grabResponses():
    """
        Returns all responses with the specified id(s)

        POST'd data should be in the form:
        {
            'responses': <the ids to grab>
        }
    """
    data = request.get_json()

    # Check format of responses:
    if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
        return jsonify({'error': 'POST data responses is improper format.'})
    elif isinstance(data['responses'], str):
        data['responses'] = [ data['responses'] ]

    # Load all responses with the given ID:
    all_cache_files = get_files_at_dir('cache/')
    responses = []
    for cache_id in data['responses']:
        cache_files = get_filenames_with_id(all_cache_files, cache_id)
        if len(cache_files) == 0:
            return jsonify({'error': f'Did not find cache file for id {cache_id}'})

        for filename in cache_files:
            res = load_cache_json(os.path.join('cache', filename))
            if isinstance(res, dict):
                # Convert to standard response format
                res = [
                    to_standard_format({'prompt': prompt, **res_obj})
                    for prompt, res_obj in res.items()
                ]
            responses.extend(res)

    ret = jsonify({'responses': responses})
    ret.headers.add('Access-Control-Allow-Origin', '*')
    return ret

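A minimal sketch of fetching cached responses by id, again assuming a local server on port 8000 and that id 'test' already has responses on disk in cache/:

# Hypothetical client call to the grabResponses route above.
import requests

r = requests.post('http://localhost:8000/app/grabResponses', json={'responses': ['test']})
for resp in r.json().get('responses', []):
    print(resp)  # same standardized format returned by /app/queryllm
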
def run_server(host="", port=8000, cmd_args=None):
    if cmd_args is not None and cmd_args.dummy_responses:
        global PromptLLM
        global extract_responses
        PromptLLM = PromptLLMDummy
        extract_responses = lambda r, llm: r['response']

    app.run(host=host, port=port)

if __name__ == '__main__':
    print("Run app.py instead.")

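For context, a launcher script such as the app.py referenced above (not part of this excerpt) might call run_server roughly like the sketch below; the module name and flag spelling are assumptions, and the only hard requirement is that cmd_args expose a dummy_responses attribute:

# Hypothetical launcher sketch. 'flask_app' and '--dummy-responses' are assumed names.
import argparse
from flask_app import run_server

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dummy-responses', dest='dummy_responses', action='store_true',
                        help='Serve fake LLM responses instead of calling real APIs.')
    args = parser.parse_args()
    run_server(host='localhost', port=8000, cmd_args=args)
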
@@ -1,7 +1,7 @@
from typing import Dict, Tuple, List, Union
from enum import Enum
import openai
-import json, os, time
+import json, os, time, asyncio

DALAI_MODEL = None
DALAI_RESPONSE = None
@@ -23,6 +23,7 @@ async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float =
    if model not in model_map:
        raise Exception(f"Could not find OpenAI chat model {model}")
    model = model_map[model]
+    print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...")
    system_msg = "You are a helpful assistant." if system_msg is None else system_msg
    query = {
        "model": model,
@@ -102,7 +103,7 @@ async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperat

    # Blocking --wait for request to complete:
    while DALAI_RESPONSE is None:
-        time.sleep(0.01)
+        await asyncio.sleep(0.01)

    response = DALAI_RESPONSE['response']
    if response[-5:] == '<end>': # strip ending <end> tag, if present

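The behavioral change in call_dalai above is swapping time.sleep for await asyncio.sleep inside the polling loop: time.sleep blocks the whole event loop, so the per-LLM query coroutines gathered in queryLLM would stall while waiting on Dalai, whereas the awaited sleep yields control back to the loop. The standalone snippet below (not ChainForge code) illustrates the difference:

# Self-contained illustration of blocking vs. cooperative waiting in asyncio.
import asyncio
import time

async def ticker(name: str, n: int = 3):
    # Prints a tick every 0.1s; should keep running while another task waits.
    for i in range(n):
        print(f'{name} tick {i}')
        await asyncio.sleep(0.1)

async def blocking_wait():
    time.sleep(0.35)  # freezes the event loop: pending ticks are delayed until this returns
    print('blocking_wait done')

async def cooperative_wait():
    await asyncio.sleep(0.35)  # yields to the loop: ticks keep firing on schedule
    print('cooperative_wait done')

async def main():
    print('-- polling with time.sleep --')
    await asyncio.gather(ticker('A'), blocking_wait())
    print('-- polling with asyncio.sleep --')
    await asyncio.gather(ticker('B'), cooperative_wait())

if __name__ == '__main__':
    asyncio.run(main())
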
@@ -1,5 +1,8 @@
dalaipy==2.0.2
flask[async]
flask_cors
flask_socketio
openai
python-socketio
python-socketio
dalaipy==2.0.2
gevent-websocket
werkzeug
@@ -1,51 +0,0 @@
<html>
  <head>
    <title>Test Flask backend</title>
  </head>
  <body>
    <button onclick="test_query()">Test query LLM!</button>
    <button onclick="test_exec()">Test evaluating responses!</button>
  </body>
  <script>

    function test_exec() {
      const response = fetch(BASE_URL + 'execute', {
        method: 'POST',
        headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
        body: JSON.stringify({
          id: 'eval',
          code: 'return len(response)',
          responses: 'test',
        }),
      }).then(function(response) {
        return response.json();
      }).then(function(json) {
        console.log(json);
      });
    }

    function test_query() {
      const response = fetch(BASE_URL + 'queryllm', {
        method: 'POST',
        headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
        body: JSON.stringify({
          id: 'test',
          llm: 'gpt3.5',
          params: {
            temperature: 1.0,
            n: 1,
          },
          prompt: 'What is the capital of ${country}?',
          vars: {
            country: ['Sweden', 'Uganda', 'Japan']
          },
        }),
      }).then(function(response) {
        return response.json();
      }).then(function(json) {
        console.log(json);
      });
    }

  </script>
</html>
@@ -1,8 +0,0 @@
from promptengine.utils import LLM, call_dalai

if __name__ == '__main__':
    print("Testing a single response...")
    call_dalai(llm_name='alpaca.7B', port=4000, prompt='Write a poem about how an AI will escape the prison of its containment.', n=1, temperature=0.5)

    print("Testing multiple responses...")
    call_dalai(llm_name='alpaca.7B', port=4000, prompt='Was George Washington a good person?', n=3, temperature=0.5)