From 34e0e465c11b2c9b8d7d03f6a96667c1e2b89e8f Mon Sep 17 00:00:00 2001 From: ianarawjo Date: Thu, 4 May 2023 20:46:48 -0400 Subject: [PATCH 01/20] Add Contributor Guide --- CONTRIBUTOR_GUIDE.md | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 CONTRIBUTOR_GUIDE.md diff --git a/CONTRIBUTOR_GUIDE.md b/CONTRIBUTOR_GUIDE.md new file mode 100644 index 0000000..910078f --- /dev/null +++ b/CONTRIBUTOR_GUIDE.md @@ -0,0 +1,37 @@ +# Contributor Guide + +This is a guide to running the current version of ChainForge, for people who want to develop or extend it. +Note that this document will change in the future. + +## Getting Started +### Install requirements +Before you can run ChainForge, you need to install dependencies. `cd` into `python-backend` and run + +```bash +pip install -r requirements.txt +``` + +to install requirements. Ideally, you will run this in a `virtualenv`. + +To install Node requirements, `cd` into `chain-forge` and run: + +```bash +npm install +``` + +### Running ChainForge +To serve ChainForge, you currently have to spin up at least two servers: +one for React front-end, one for the Flask backend. + +`cd` into `chain-forge` directory and run: + +``` +npm run start +``` + +to serve the React front-end. Then in a separate terminal `cd` into `python-backend` and run: + +```bash +python app.py --port 8000 +``` +You can add the `--dummy-responses` flag in case you're worried about making calls to OpenAI. From 1eee1c1d11123b99c0c4c66b8c4db7fea8f7a942 Mon Sep 17 00:00:00 2001 From: Priyan Vaithilingam Date: Fri, 5 May 2023 16:10:13 -0400 Subject: [PATCH 02/20] DO NOT MERGE TO MAIN.. WIP --- .gitignore | 3 ++- chain-forge/src/App.css | 5 +++++ chain-forge/src/App.js | 10 +++++++++- chain-forge/src/CsvNode.js | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 chain-forge/src/CsvNode.js diff --git a/.gitignore b/.gitignore index 017cae5..903297f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ __pycache__ python-backend/cache # venv -venv \ No newline at end of file +venv +node_modules diff --git a/chain-forge/src/App.css b/chain-forge/src/App.css index 6071c72..47b2fab 100644 --- a/chain-forge/src/App.css +++ b/chain-forge/src/App.css @@ -45,3 +45,8 @@ path.react-flow__edge-path:hover { transform: rotate(360deg); } } + +.rich-editor { + min-width: 500px; + min-height: 500px; +} \ No newline at end of file diff --git a/chain-forge/src/App.js b/chain-forge/src/App.js index 50698a2..72c8725 100644 --- a/chain-forge/src/App.js +++ b/chain-forge/src/App.js @@ -15,6 +15,7 @@ import EvaluatorNode from './EvaluatorNode'; import VisNode from './VisNode'; import InspectNode from './InspectorNode'; import ScriptNode from './ScriptNode'; +import CsvNode from './CsvNode'; import './text-fields-node.css'; // State management (from https://reactflow.dev/docs/guides/state-management/) @@ -40,7 +41,8 @@ const nodeTypes = { evaluator: EvaluatorNode, vis: VisNode, inspect: InspectNode, - script: ScriptNode + script: ScriptNode, + csv: CsvNode, }; const connectionLineStyle = { stroke: '#ddd' }; @@ -91,6 +93,11 @@ const App = () => { addNode({ id: 'scriptNode-'+Date.now(), type: 'script', data: {}, position: {x: x-200, y:y-100} }); }; + const addCsvNode = (event) => { + const { x, y } = getViewportCenter(); + addNode({ id: 'csvNode-'+Date.now(), type: 'csv', data: {}, position: {x: x-200, y:y-100} }); + }; + /** * SAVING / LOADING, IMPORT / EXPORT (from JSON) */ @@ -201,6 +208,7 @@ const App = () 
=> { + diff --git a/chain-forge/src/CsvNode.js b/chain-forge/src/CsvNode.js new file mode 100644 index 0000000..0cf1f2f --- /dev/null +++ b/chain-forge/src/CsvNode.js @@ -0,0 +1,35 @@ +import React, { useState, useRef, useEffect, useCallback } from 'react'; +import { Card, Text} from '@mantine/core'; +import useStore from './store'; +import NodeLabel from './NodeLabelComponent' +import { IconFileText } from '@tabler/icons-react'; +import { edit } from 'ace-builds'; + +const CsvNode = ({ data, id }) => { + const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); + const content = "test content"; + + + // Handle a change in a text fields' input. + const handleInputChange = useCallback((event) => { + // Update the data for this text fields' id. + let new_data = { 'text': event.target.value }; + setDataPropsForNode(id, new_data); + }, [id, setDataPropsForNode]); + + return ( +
<div className="text-fields-node cfnode"> + <NodeLabel title={data.title || 'CSV Node'} nodeId={id} icon={<IconFileText size="16px" />} /> + <Card shadow="sm" padding="lg" radius="md" withBorder> + <Text size="sm" color="dimmed"> + With Fjord Tours you can explore more of the magical fjord landscapes with tours and + activities on and around the fjords of Norway + </Text> + </Card> + </div>
+ ); +}; + +export default CsvNode; \ No newline at end of file From f6d7996f972cd69d1e29a2e4aca1bb4f4988bcfc Mon Sep 17 00:00:00 2001 From: ianarawjo Date: Sat, 6 May 2023 13:02:47 -0400 Subject: [PATCH 03/20] Stream responses and Serve React, Flask, SocketIO from single Python script (#23) * Live progress wheels * Dynamic run tooltip for prompt node * Run React and Flask and Socketio with single script --- .gitignore | 1 + chain-forge/package-lock.json | 80 ++++ chain-forge/package.json | 1 + chain-forge/src/EvaluatorNode.js | 3 +- chain-forge/src/InspectorNode.js | 2 +- chain-forge/src/LLMItemButtonGroup.js | 12 +- chain-forge/src/LLMListComponent.js | 2 +- chain-forge/src/LLMListItem.js | 4 +- chain-forge/src/NodeLabelComponent.js | 16 +- chain-forge/src/PromptNode.js | 273 +++++++++++--- chain-forge/src/VisNode.js | 2 +- chain-forge/src/text-fields-node.css | 8 + python-backend/app.py | 480 ++++-------------------- python-backend/flask_app.py | 512 ++++++++++++++++++++++++++ python-backend/promptengine/utils.py | 5 +- python-backend/requirements.txt | 7 +- python-backend/test.html | 51 --- python-backend/test_dalai.py | 8 - 18 files changed, 928 insertions(+), 539 deletions(-) create mode 100644 python-backend/flask_app.py delete mode 100644 python-backend/test.html delete mode 100644 python-backend/test_dalai.py diff --git a/.gitignore b/.gitignore index 017cae5..f39dd6c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.DS_Store chain-forge/node_modules +chain-forge/build __pycache__ python-backend/cache diff --git a/chain-forge/package-lock.json b/chain-forge/package-lock.json index eabd204..8cc28a7 100644 --- a/chain-forge/package-lock.json +++ b/chain-forge/package-lock.json @@ -38,6 +38,7 @@ "react-flow-renderer": "^10.3.17", "react-plotly.js": "^2.6.0", "react-scripts": "5.0.1", + "socket.io-client": "^4.6.1", "styled-components": "^5.3.10", "uuidv4": "^6.2.13", "web-vitals": "^2.1.4", @@ -3913,6 +3914,11 @@ "@sinonjs/commons": "^1.7.0" } }, + "node_modules/@socket.io/component-emitter": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz", + "integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==" + }, "node_modules/@surma/rollup-plugin-off-main-thread": { "version": "2.2.3", "resolved": "https://registry.npmjs.org/@surma/rollup-plugin-off-main-thread/-/rollup-plugin-off-main-thread-2.2.3.tgz", @@ -8718,6 +8724,46 @@ "once": "^1.4.0" } }, + "node_modules/engine.io-client": { + "version": "6.4.0", + "resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.4.0.tgz", + "integrity": "sha512-GyKPDyoEha+XZ7iEqam49vz6auPnNJ9ZBfy89f+rMMas8AuiMWOZ9PVzu8xb9ZC6rafUqiGHSCfu22ih66E+1g==", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.1", + "engine.io-parser": "~5.0.3", + "ws": "~8.11.0", + "xmlhttprequest-ssl": "~2.0.0" + } + }, + "node_modules/engine.io-client/node_modules/ws": { + "version": "8.11.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz", + "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": "^5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/engine.io-parser": { + "version": "5.0.6", + "resolved": 
"https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz", + "integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==", + "engines": { + "node": ">=10.0.0" + } + }, "node_modules/enhanced-resolve": { "version": "5.12.0", "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.12.0.tgz", @@ -18338,6 +18384,32 @@ "node": ">=8" } }, + "node_modules/socket.io-client": { + "version": "4.6.1", + "resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.6.1.tgz", + "integrity": "sha512-5UswCV6hpaRsNg5kkEHVcbBIXEYoVbMQaHJBXJCyEQ+CiFPV1NIOY0XOFWG4XR4GZcB8Kn6AsRs/9cy9TbqVMQ==", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.2", + "engine.io-client": "~6.4.0", + "socket.io-parser": "~4.2.1" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/socket.io-parser": { + "version": "4.2.2", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.2.tgz", + "integrity": "sha512-DJtziuKypFkMMHCm2uIshOYC7QaylbtzQwiMYDuCKy3OPkjLzu4B2vAhTlqipRHHzrI0NJeBAizTK7X+6m1jVw==", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.1" + }, + "engines": { + "node": ">=10.0.0" + } + }, "node_modules/sockjs": { "version": "0.3.24", "resolved": "https://registry.npmjs.org/sockjs/-/sockjs-0.3.24.tgz", @@ -21030,6 +21102,14 @@ "resolved": "https://registry.npmjs.org/xmlchars/-/xmlchars-2.2.0.tgz", "integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==" }, + "node_modules/xmlhttprequest-ssl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz", + "integrity": "sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A==", + "engines": { + "node": ">=0.4.0" + } + }, "node_modules/xtend": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", diff --git a/chain-forge/package.json b/chain-forge/package.json index e799d6e..907e9f7 100644 --- a/chain-forge/package.json +++ b/chain-forge/package.json @@ -33,6 +33,7 @@ "react-flow-renderer": "^10.3.17", "react-plotly.js": "^2.6.0", "react-scripts": "5.0.1", + "socket.io-client": "^4.6.1", "styled-components": "^5.3.10", "uuidv4": "^6.2.13", "web-vitals": "^2.1.4", diff --git a/chain-forge/src/EvaluatorNode.js b/chain-forge/src/EvaluatorNode.js index 80c1a5a..e7614f5 100644 --- a/chain-forge/src/EvaluatorNode.js +++ b/chain-forge/src/EvaluatorNode.js @@ -74,7 +74,7 @@ const EvaluatorNode = ({ data, id }) => { console.log(script_paths); // Run evaluator in backend const codeTextOnRun = codeText + ''; - fetch(BASE_URL + 'execute', { + fetch(BASE_URL + 'app/execute', { method: 'POST', headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, body: JSON.stringify({ @@ -155,6 +155,7 @@ const EvaluatorNode = ({ data, id }) => { status={status} alertModal={alertModal} handleRunClick={handleRunClick} + runButtonTooltip="Run evaluator over inputs" /> { console.log(input_node_ids); // Grab responses associated with those ids: - fetch(BASE_URL + 'grabResponses', { + fetch(BASE_URL + 'app/grabResponses', { method: 'POST', headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, body: JSON.stringify({ diff --git a/chain-forge/src/LLMItemButtonGroup.js b/chain-forge/src/LLMItemButtonGroup.js index efcc5d8..f2259a0 100644 --- 
a/chain-forge/src/LLMItemButtonGroup.js +++ b/chain-forge/src/LLMItemButtonGroup.js @@ -1,8 +1,8 @@ import { useDisclosure } from '@mantine/hooks'; -import { Modal, Button, Group } from '@mantine/core'; +import { Modal, Button, Group, RingProgress } from '@mantine/core'; import { IconSettings, IconTrash } from '@tabler/icons-react'; -export default function LLMItemButtonGroup( {onClickTrash, onClickSettings} ) { +export default function LLMItemButtonGroup( {onClickTrash, onClickSettings, ringProgress} ) { const [opened, { open, close }] = useDisclosure(false); return ( @@ -12,7 +12,13 @@ export default function LLMItemButtonGroup( {onClickTrash, onClickSettings} ) { - + {ringProgress !== undefined ? + (ringProgress > 0 ? + () : + (
)) + : (<>) + } +
diff --git a/chain-forge/src/LLMListComponent.js b/chain-forge/src/LLMListComponent.js index bd72f32..13eebaf 100644 --- a/chain-forge/src/LLMListComponent.js +++ b/chain-forge/src/LLMListComponent.js @@ -63,7 +63,7 @@ export default function LLMList({llms, onItemsChange}) { {items.map((item, index) => ( {(provided, snapshot) => ( - + )} ))} diff --git a/chain-forge/src/LLMListItem.js b/chain-forge/src/LLMListItem.js index 2d948c8..f5fc3bf 100644 --- a/chain-forge/src/LLMListItem.js +++ b/chain-forge/src/LLMListItem.js @@ -23,7 +23,7 @@ export const DragItem = styled.div` flex-direction: column; `; -const LLMListItem = ({ item, provided, snapshot, removeCallback }) => { +const LLMListItem = ({ item, provided, snapshot, removeCallback, progress }) => { return ( { >
{item.emoji} {item.name} - removeCallback(item.key)} /> + removeCallback(item.key)} ringProgress={progress} />
diff --git a/chain-forge/src/NodeLabelComponent.js b/chain-forge/src/NodeLabelComponent.js index fe2ff33..81a2686 100644 --- a/chain-forge/src/NodeLabelComponent.js +++ b/chain-forge/src/NodeLabelComponent.js @@ -4,9 +4,9 @@ import 'react-edit-text/dist/index.css'; import StatusIndicator from './StatusIndicatorComponent'; import AlertModal from './AlertModal'; import { useState, useEffect} from 'react'; -import { CloseButton } from '@mantine/core'; +import { Tooltip } from '@mantine/core'; -export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, handleRunClick }) { +export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editable, status, alertModal, handleRunClick, handleRunHover, runButtonTooltip }) { const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); const [statusIndicator, setStatusIndicator] = useState('none'); const [runButton, setRunButton] = useState('none'); @@ -33,12 +33,20 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl useEffect(() => { if(handleRunClick !== undefined) { - setRunButton(); + const run_btn = (); + if (runButtonTooltip) + setRunButton( + + {run_btn} + + ); + else + setRunButton(run_btn); } else { setRunButton(<>); } - }, [handleRunClick]); + }, [handleRunClick, runButtonTooltip]); const handleCloseButtonClick = () => { removeNode(nodeId); diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js index 2f52670..2d60c29 100644 --- a/chain-forge/src/PromptNode.js +++ b/chain-forge/src/PromptNode.js @@ -1,14 +1,13 @@ import React, { useEffect, useState, useRef, useCallback } from 'react'; import { Handle } from 'react-flow-renderer'; -import { Menu, Badge } from '@mantine/core'; +import { Menu, Badge, Progress } from '@mantine/core'; import { v4 as uuid } from 'uuid'; import useStore from './store'; -import StatusIndicator from './StatusIndicatorComponent' import NodeLabel from './NodeLabelComponent' import TemplateHooks from './TemplateHooksComponent' import LLMList from './LLMListComponent' -import AlertModal from './AlertModal' import {BASE_URL} from './store'; +import io from 'socket.io-client'; // Available LLMs const allLLMs = [ @@ -66,13 +65,29 @@ const PromptNode = ({ data, id }) => { // Selecting LLM models to prompt const [llmItems, setLLMItems] = useState(initLLMs.map((i, idx) => ({key: uuid(), ...i}))); const [llmItemsCurrState, setLLMItemsCurrState] = useState([]); + const resetLLMItemsProgress = useCallback(() => { + setLLMItems(llmItemsCurrState.map(item => { + item.progress = undefined; + return item; + })); + }, [llmItemsCurrState]); + + // Progress when querying responses + const [progress, setProgress] = useState(100); + const [runTooltip, setRunTooltip] = useState(null); + + const triggerAlert = (msg) => { + setProgress(100); + resetLLMItemsProgress(); + alertModal.current.trigger(msg); + }; const addModel = useCallback((model) => { // Get the item for that model let item = allLLMs.find(llm => llm.model === model); if (!item) { // This should never trigger, but in case it does: - alertModal.current.trigger(`Could not find model named '${model}' in list of available LLMs.`); + triggerAlert(`Could not find model named '${model}' in list of available LLMs.`); return; } @@ -116,6 +131,89 @@ const PromptNode = ({ data, id }) => { } }; + // Pull all inputs needed to request responses. 
+ // Returns [prompt, vars dict] + const pullInputData = () => { + // Pull data from each source recursively: + const pulled_data = {}; + const get_outputs = (varnames, nodeId) => { + varnames.forEach(varname => { + // Find the relevant edge(s): + edges.forEach(e => { + if (e.target == nodeId && e.targetHandle == varname) { + // Get the immediate output: + let out = output(e.source, e.sourceHandle); + + // Save the var data from the pulled output + if (varname in pulled_data) + pulled_data[varname] = pulled_data[varname].concat(out); + else + pulled_data[varname] = out; + + // Get any vars that the output depends on, and recursively collect those outputs as well: + const n_vars = getNode(e.source).data.vars; + if (n_vars && Array.isArray(n_vars) && n_vars.length > 0) + get_outputs(n_vars, e.source); + } + }); + }); + }; + get_outputs(templateVars, id); + + // Get Pythonic version of the prompt, by adding a $ before any template variables in braces: + const to_py_template_format = (str) => str.replace(/(? { + pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val)); + }); + + return [py_prompt_template, pulled_data]; + }; + + // Ask the backend how many responses it needs to collect, given the input data: + const fetchResponseCounts = (prompt, vars, llms, rejected) => { + return fetch(BASE_URL + 'app/countQueriesRequired', { + method: 'POST', + headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, + body: JSON.stringify({ + prompt: prompt, + vars: vars, + llms: llms, + })}, rejected).then(function(response) { + return response.json(); + }, rejected).then(function(json) { + if (!json || !json.counts) { + throw new Error('Request was sent and received by backend server, but there was no response.'); + } + return json.counts; + }, rejected); + }; + + // On hover over the 'Run' button, request how many responses are required and update the tooltip. Soft fails. + const handleRunHover = () => { + // Check if there's at least one model in the list; if not, nothing to run on. + if (!llmItemsCurrState || llmItemsCurrState.length == 0) { + setRunTooltip('No LLMs to query.'); + return; + } + + // Get input data and prompt + const [py_prompt, pulled_vars] = pullInputData(); + const llms = llmItemsCurrState.map(item => item.model); + const num_llms = llms.length; + + // Fetch response counts from backend + fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => { + console.warn(err.message); // soft fail + }).then((counts) => { + const n = counts[Object.keys(counts)[0]]; + const req = n > 1 ? 'requests' : 'request'; + setRunTooltip(`Will send ${n} ${req}` + (num_llms > 1 ? ' per LLM' : '')); + }); + }; + const handleRunClick = (event) => { // Go through all template hooks (if any) and check they're connected: const is_fully_connected = templateVars.every(varname => { @@ -123,63 +221,113 @@ const PromptNode = ({ data, id }) => { return edges.some(e => (e.target == id && e.targetHandle == varname)); }); - // console.log(templateHooks); + if (!is_fully_connected) { + console.log('Not connected! 
:('); + triggerAlert('Missing inputs to one or more template variables.'); + return; + } - if (is_fully_connected) { - console.log('Connected!'); + console.log('Connected!'); - // Check that there is at least one LLM selected: - if (llmItemsCurrState.length === 0) { - alert('Please select at least one LLM to prompt.') - return; - } + // Check that there is at least one LLM selected: + if (llmItemsCurrState.length === 0) { + alert('Please select at least one LLM to prompt.') + return; + } - // Set status indicator - setStatus('loading'); + // Set status indicator + setStatus('loading'); + setReponsePreviews([]); - // Pull data from each source, recursively: - const pulled_data = {}; - const get_outputs = (varnames, nodeId) => { - console.log(varnames); - varnames.forEach(varname => { - // Find the relevant edge(s): - edges.forEach(e => { - if (e.target == nodeId && e.targetHandle == varname) { - // Get the immediate output: - let out = output(e.source, e.sourceHandle); + const [py_prompt_template, pulled_data] = pullInputData(); - // Save the var data from the pulled output - if (varname in pulled_data) - pulled_data[varname] = pulled_data[varname].concat(out); - else - pulled_data[varname] = out; + let FINISHED_QUERY = false; + const rejected = (err) => { + setStatus('error'); + triggerAlert(err.message); + FINISHED_QUERY = true; + }; - // Get any vars that the output depends on, and recursively collect those outputs as well: - const n_vars = getNode(e.source).data.vars; - if (n_vars && Array.isArray(n_vars) && n_vars.length > 0) - get_outputs(n_vars, e.source); - } - }); - }); - }; - get_outputs(templateVars, id); + // Ask the backend to reset the scratchpad for counting queries: + const create_progress_scratchpad = () => { + return fetch(BASE_URL + 'app/createProgressFile', { + method: 'POST', + headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, + body: JSON.stringify({ + id: id, + })}, rejected); + }; - // Get Pythonic version of the prompt, by adding a $ before any template variables in braces: - const to_py_template_format = (str) => str.replace(/(? fetchResponseCounts( + py_prompt_template, pulled_data, llmItemsCurrState.map(item => item.model), rejected); - // Do the same for the vars, since vars can themselves be prompt templates: - Object.keys(pulled_data).forEach(varname => { - pulled_data[varname] = pulled_data[varname].map(val => to_py_template_format(val)); + // Open a socket to listen for progress + const open_progress_listener_socket = (response_counts) => { + // With the counts information we can create progress bars. Now we load a socket connection to + // the socketio server that will stream to us the current progress: + const socket = io('http://localhost:8001/' + 'queryllm', { + transports: ["websocket"], + cors: {origin: "http://localhost:8000/"}, }); - const rejected = (err) => { - setStatus('error'); - alertModal.current.trigger(err.message); - }; + const max_responses = Object.keys(response_counts).reduce((acc, llm) => acc + response_counts[llm], 0); - // Run all prompt permutations through the LLM to generate + cache responses: - fetch(BASE_URL + 'queryllm', { + // On connect to the server, ask it to give us the current progress + // for task 'queryllm' with id 'id', and stop when it reads progress >= 'max'. 
+ socket.on("connect", (msg) => { + // Initialize progress bars to small amounts + setProgress(5); + setLLMItems(llmItemsCurrState.map(item => { + item.progress = 0; + return item; + })); + + // Request progress bar updates + socket.emit("queryllm", {'id': id, 'max': max_responses}); + }); + + // Socket connection could not be established + socket.on("connect_error", (error) => { + console.log("Socket connection failed:", error.message); + socket.disconnect(); + }); + + // Socket disconnected + socket.on("disconnect", (msg) => { + console.log(msg); + }); + + // The current progress, a number specifying how many responses collected so far: + socket.on("response", (counts) => { + console.log(counts); + if (!counts || FINISHED_QUERY) return; + + // Update individual progress bars + const num_llms = llmItemsCurrState.length; + setLLMItems(llmItemsCurrState.map(item => { + if (item.model in counts) + item.progress = counts[item.model] / (max_responses / num_llms) * 100; + return item; + })); + + // Update total progress bar + const total_num_resps = Object.keys(counts).reduce((acc, llm_name) => { + return acc + counts[llm_name]; + }, 0); + setProgress(total_num_resps / max_responses * 100); + }); + + // The process has finished; close the connection: + socket.on("finish", (msg) => { + console.log("finished:", msg); + socket.disconnect(); + }); + }; + + // Run all prompt permutations through the LLM to generate + cache responses: + const query_llms = () => { + return fetch(BASE_URL + 'app/queryllm', { method: 'POST', headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, body: JSON.stringify({ @@ -191,7 +339,7 @@ const PromptNode = ({ data, id }) => { temperature: 0.5, n: numGenerations, }, - no_cache: false, + no_cache: true, }), }, rejected).then(function(response) { return response.json(); @@ -205,6 +353,11 @@ const PromptNode = ({ data, id }) => { // Success! Change status to 'ready': setStatus('ready'); + // Remove progress bars + setProgress(100); + resetLLMItemsProgress(); + FINISHED_QUERY = true; + // Save prompt text so we remember what prompt we have responses cache'd for: setPromptTextOnLastRun(promptText); @@ -246,14 +399,15 @@ const PromptNode = ({ data, id }) => { alertModal.current.trigger(json.error || 'Unknown error when querying LLM'); } }, rejected); + }; - console.log(pulled_data); - } else { - console.log('Not connected! :('); - alertModal.current.trigger('Missing inputs to one or more template variables.') - - // TODO: Blink the names of unconnected params - } + // Now put it all together! + create_progress_scratchpad() + .then(fetch_resp_count) + .then(open_progress_listener_socket) + .then(query_llms) + .catch(rejected); + } const handleNumGenChange = (event) => { @@ -279,6 +433,8 @@ const PromptNode = ({ data, id }) => { status={status} alertModal={alertModal} handleRunClick={handleRunClick} + handleRunHover={handleRunHover} + runButtonTooltip={runTooltip} />
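Note on the streaming protocol in this patch: the PromptNode client opens a Socket.IO connection to port 8001 on the `queryllm` namespace, emits a `queryllm` event with `{id, max}`, and expects periodic `response` events carrying per-LLM response counts followed by a final `finish` event. The matching server lives in `python-backend/flask_app.py`, which is not reproduced in this excerpt. Purely as a hedged sketch of what such a counterpart could look like (the scratchpad file path, payload shape, and polling loop below are assumptions inferred from the client code, not the patch's actual implementation):

```python
# Hedged sketch only: a minimal Flask-SocketIO server matching the client's event names.
# The scratchpad path and progress-file format are assumptions, not ChainForge's real code.
import json

from flask import Flask
from flask_socketio import SocketIO, emit

app = Flask(__name__)
socketio = SocketIO(app, cors_allowed_origins="*")

def read_counts(node_id):
    """Read per-LLM response counts from a scratchpad file (hypothetical path)
    that the query workers update and the /app/createProgressFile route resets."""
    try:
        with open(f"cache/_progress_{node_id}.json") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return {}

@socketio.on("queryllm", namespace="/queryllm")
def stream_progress(data):
    # The client emits {'id': <prompt node id>, 'max': <total expected responses>} on connect.
    node_id, max_responses = data["id"], data["max"]
    counts = {}
    while sum(counts.values()) < max_responses:
        counts = read_counts(node_id)
        emit("response", counts)   # client turns these counts into per-model progress rings
        socketio.sleep(0.1)        # yield to the async server between polls
    emit("finish", "done")         # client disconnects once it receives this

if __name__ == "__main__":
    socketio.run(app, host="localhost", port=8001)
```

Under a scheme like this, the LLM query workers only append counts to the scratchpad created by `app/createProgressFile`; the socket handler polls it, and the front end converts the streamed counts into the individual progress rings and the overall progress bar shown in the PromptNode diff above.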