From f635abc148e00395eb80019f51980c29b0e9ef29 Mon Sep 17 00:00:00 2001
From: Ian Arawjo
Date: Sat, 6 May 2023 13:00:18 -0400
Subject: [PATCH] Cleanup

---
 chain-forge/src/PromptNode.js        | 10 +++-
 python-backend/app.py                | 72 ++++++++++++++------------
 python-backend/flask_app.py          | 75 +++-------------------------
 python-backend/promptengine/query.py |  7 ---
 python-backend/promptengine/utils.py |  1 +
 python-backend/server.py             |  3 --
 python-backend/test.html             | 51 -------------------
 python-backend/test_dalai.py         |  8 ---
 8 files changed, 57 insertions(+), 170 deletions(-)
 delete mode 100644 python-backend/server.py
 delete mode 100644 python-backend/test.html
 delete mode 100644 python-backend/test_dalai.py

diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js
index 88a838b..2d60c29 100644
--- a/chain-forge/src/PromptNode.js
+++ b/chain-forge/src/PromptNode.js
@@ -286,10 +286,18 @@ const PromptNode = ({ data, id }) => {
       // Request progress bar updates
       socket.emit("queryllm", {'id': id, 'max': max_responses});
     });
+
+    // Socket connection could not be established
+    socket.on("connect_error", (error) => {
+      console.log("Socket connection failed:", error.message);
+      socket.disconnect();
+    });
+
+    // Socket disconnected
     socket.on("disconnect", (msg) => {
       console.log(msg);
     });
-    
+
     // The current progress, a number specifying how many responses collected so far:
     socket.on("response", (counts) => {
       console.log(counts);
diff --git a/python-backend/app.py b/python-backend/app.py
index c573ea3..bc658a3 100644
--- a/python-backend/app.py
+++ b/python-backend/app.py
@@ -4,32 +4,24 @@ from statistics import mean, median, stdev
 from flask import Flask, request, jsonify
 from flask_cors import CORS
 from flask_socketio import SocketIO
-# from werkzeug.middleware.dispatcher import DispatcherMiddleware
 from flask_app import run_server
 from promptengine.query import PromptLLM, PromptLLMDummy
 from promptengine.template import PromptTemplate, PromptPermutationGenerator
 from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
 
-# Setup the main app
+# Set up the socketio app
 # BUILD_DIR = "../chain-forge/build"
 # STATIC_DIR = BUILD_DIR + '/static'
 app = Flask(__name__) #, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
 
-# Set up CORS for specific routes
-# cors = CORS(app, resources={r"/api/*": {"origins": "*"}})
-
 # Initialize Socket.IO
 socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")
 
-# Create a dispatcher connecting apps.
-# app.wsgi_app = DispatcherMiddleware(app.wsgi_app, {"/app": flask_server})
+# Set up CORS for specific routes
+# cors = CORS(app, resources={r"/api/*": {"origins": "*"}})
 
-# Wait a max of a full minute (60 seconds) for the response count to update, before exiting.
-MAX_WAIT_TIME = 60
-
-# import threading
-# thread = None
-# thread_lock = threading.Lock()
+# Wait a maximum of 3 minutes (180 seconds) for the response count to update before exiting.
+MAX_WAIT_TIME = 180
 
 def countdown():
     n = 10
@@ -38,21 +30,49 @@ def countdown():
         socketio.emit('response', n, namespace='/queryllm')
         n -= 1
 
-def readCounts(id, max_count):
+@socketio.on('queryllm', namespace='/queryllm')
+def readCounts(data):
+    id = data['id']
+    max_count = data['max']
+    tempfilepath = f'cache/_temp_{id}.txt'
+
+    # Check that temp file exists. If it doesn't, something went wrong with setup on Flask's end:
+    if not os.path.exists(tempfilepath):
+        print(f"Error: Temp file not found at path {tempfilepath}. Cannot stream querying progress.")
+        socketio.emit('finish', 'temp file not found', namespace='/queryllm')
+
     i = 0
-    n = 0
     last_n = 0
+    init_run = True
-    while i < MAX_WAIT_TIME and n < max_count:
-        with open(f'cache/_temp_{id}.txt', 'r') as f:
-            queries = json.load(f)
+    while i < MAX_WAIT_TIME and last_n < max_count:
+
+        # Open the temp file to read the progress so far:
+        try:
+            with open(tempfilepath, 'r') as f:
+                queries = json.load(f)
+        except FileNotFoundError as e:
+            # If the temp file was deleted during execution, the Flask 'queryllm' func must've terminated successfully:
+            socketio.emit('finish', 'success', namespace='/queryllm')
+            return
+
+        # Calculate the total sum of responses
+        # TODO: This is a naive approach; we need to make this more complex and factor in caching in the future
         n = sum([int(n) for llm, n in queries.items()])
-        socketio.emit('response', queries, namespace='/queryllm')
-        socketio.sleep(0.1)
-        if last_n != n:
+
+        # If something's changed...
+        if init_run or last_n != n:
             i = 0
             last_n = n
+            init_run = False
+
+            # Update the React front-end with the current progress
+            socketio.emit('response', queries, namespace='/queryllm')
+
         else:
             i += 0.1
+
+        # Wait a bit before reading the file again
+        socketio.sleep(0.1)
 
     if i >= MAX_WAIT_TIME:
         print(f"Error: Waited maximum {MAX_WAIT_TIME} seconds for response count to update. Exited prematurely.")
@@ -61,18 +81,8 @@
         print("All responses loaded!")
         socketio.emit('finish', 'success', namespace='/queryllm')
 
-@socketio.on('queryllm', namespace='/queryllm')
-def testSocket(data):
-    readCounts(data['id'], data['max'])
-    # countdown()
-    # global thread
-    # with thread_lock:
-    #     if thread is None:
-    #         thread = socketio.start_background_task(target=countdown)
-
 def run_socketio_server(socketio, port):
     socketio.run(app, host="localhost", port=8001)
-    # flask_server.run(host="localhost", port=8000, debug=True)
 
 
 if __name__ == "__main__":
diff --git a/python-backend/flask_app.py b/python-backend/flask_app.py
index 76e2e19..e63fea0 100644
--- a/python-backend/flask_app.py
+++ b/python-backend/flask_app.py
@@ -8,6 +8,7 @@ from promptengine.query import PromptLLM, PromptLLMDummy
 from promptengine.template import PromptTemplate, PromptPermutationGenerator
 from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
 
+# Set up Flask app to serve the static version of the React front-end
 BUILD_DIR = "../chain-forge/build"
 STATIC_DIR = BUILD_DIR + '/static'
 app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
@@ -15,19 +16,11 @@ app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
 # Set up CORS for specific routes
 cors = CORS(app, resources={r"/*": {"origins": "*"}})
 
-# Serve React app
+# Serve React app (static; no hot reloading)
 @app.route("/")
 def index():
     return render_template("index.html")
 
-# @app.route('/', defaults={'path': ''})
-# @app.route('/<path:path>')
-# def serve(path):
-#     if path != "" and os.path.exists(BUILD_DIR + '/' + path):
-#         return send_from_directory(BUILD_DIR, path)
-#     else:
-#         return send_from_directory(BUILD_DIR, 'index.html')
-
 LLM_NAME_MAP = {
     'gpt3.5': LLM.ChatGPT,
     'alpaca.7B': LLM.Alpaca7B,
@@ -128,7 +121,6 @@ def reduce_responses(responses: list, vars: list) -> list:
     # E.g. {(var1_val, var2_val): [responses] }
     bucketed_resp = {}
     for r in responses:
-        print(r)
         tup_key = tuple([r['vars'][v] for v in include_vars])
         if tup_key in bucketed_resp:
             bucketed_resp[tup_key].append(r)
@@ -156,62 +148,6 @@ def reduce_responses(responses: list, vars: list) -> list:
 
     return ret
 
-@app.route('/app/test', methods=['GET'])
-def test():
-    return "Hello, world!"
-
-# @socketio.on('queryllm', namespace='/queryllm')
-# def handleQueryAsync(data):
-#     print("reached handleQueryAsync")
-#     socketio.start_background_task(queryLLM, emitLLMResponse)
-
-# def emitLLMResponse(result):
-#     socketio.emit('response', result)
-
-"""
-    Testing sockets. The following function can
-    communicate to React via with the JS code:
-
-    const socket = io(BASE_URL + 'queryllm', {
-        transports: ["websocket"],
-        cors: {
-            origin: "http://localhost:3000/",
-        },
-    });
-
-    socket.on("connect", (data) => {
-        socket.emit("queryllm", "hello");
-    });
-    socket.on("disconnect", (data) => {
-        console.log("disconnected");
-    });
-    socket.on("response", (data) => {
-        console.log(data);
-    });
-"""
-# def background_thread():
-#     n = 10
-#     while n > 0:
-#         socketio.sleep(0.5)
-#         socketio.emit('response', n, namespace='/queryllm')
-#         n -= 1
-
-# @socketio.on('queryllm', namespace='/queryllm')
-# def testSocket(data):
-#     print(data)
-#     global thread
-#     with thread_lock:
-#         if thread is None:
-#             thread = socketio.start_background_task(target=background_thread)
-
-# @socketio.on('queryllm', namespace='/queryllm')
-# def handleQuery(data):
-#     print(data)
-
-# def handleConnect():
-#     print('here')
-#     socketio.emit('response', 'goodbye', namespace='/')
-
 @app.route('/app/countQueriesRequired', methods=['POST'])
 def countQueries():
     """
@@ -369,6 +305,10 @@ async def queryLLM():
         for r in rs
     ]
 
+    # Remove the temp file used to stream progress updates:
+    if os.path.exists(tempfilepath):
+        os.remove(tempfilepath)
+
     # Return all responses for all LLMs
     print('returning responses:', res)
     ret = jsonify({'responses': res})
@@ -425,7 +365,6 @@ def execute():
         # check that the script_folder is valid, and it contains __init__.py
         if not os.path.exists(script_folder):
             print(script_folder, 'is not a valid script path.')
-            print(os.path.exists(script_folder))
             continue
 
         # add it to the path:
@@ -540,7 +479,6 @@ def grabResponses():
 
     # Load all responses with the given ID:
     all_cache_files = get_files_at_dir('cache/')
-    print(all_cache_files)
     responses = []
     for cache_id in data['responses']:
         cache_files = get_filenames_with_id(all_cache_files, cache_id)
@@ -557,7 +495,6 @@
             ]
             responses.extend(res)
 
-    print(responses)
     ret = jsonify({'responses': responses})
     ret.headers.add('Access-Control-Allow-Origin', '*')
     return ret
diff --git a/python-backend/promptengine/query.py b/python-backend/promptengine/query.py
index 41d99bb..d0cfdcc 100644
--- a/python-backend/promptengine/query.py
+++ b/python-backend/promptengine/query.py
@@ -79,10 +79,8 @@ class PromptPipeline:
                 tasks.append(self._prompt_llm(llm, prompt, n, temperature))
             else:
                 # Blocking. Await + yield a single LLM call.
- print('reached') _, query, response = await self._prompt_llm(llm, prompt, n, temperature) info = prompt.fill_history - print('back') # Save the response to a JSON file responses[str(prompt)] = { @@ -105,11 +103,8 @@ class PromptPipeline: # Yield responses as they come in for task in asyncio.as_completed(tasks): # Collect the response from the earliest completed task - print(f'awaiting a task to call {llm.name}...') prompt, query, response = await task - print('Completed!') - # Each prompt has a history of what was filled in from its base template. # This data --like, "class", "language", "library" etc --can be useful when parsing responses. info = prompt.fill_history @@ -154,10 +149,8 @@ class PromptPipeline: async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Dict]: if llm is LLM.ChatGPT or llm is LLM.GPT4: - print('calling chatgpt and awaiting') query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature) elif llm is LLM.Alpaca7B: - print('calling dalai alpaca.7b and awaiting') query, response = await call_dalai(llm_name='alpaca.7B', port=4000, prompt=str(prompt), n=n, temperature=temperature) else: raise Exception(f"Language model {llm} is not supported.") diff --git a/python-backend/promptengine/utils.py b/python-backend/promptengine/utils.py index d0bcd2c..d89b51e 100644 --- a/python-backend/promptengine/utils.py +++ b/python-backend/promptengine/utils.py @@ -23,6 +23,7 @@ async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float = if model not in model_map: raise Exception(f"Could not find OpenAI chat model {model}") model = model_map[model] + print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...") system_msg = "You are a helpful assistant." if system_msg is None else system_msg query = { "model": model, diff --git a/python-backend/server.py b/python-backend/server.py deleted file mode 100644 index 5f41629..0000000 --- a/python-backend/server.py +++ /dev/null @@ -1,3 +0,0 @@ -import subprocess - -subprocess.run("python socketio_app.py & python app.py & wait", shell=True) \ No newline at end of file diff --git a/python-backend/test.html b/python-backend/test.html deleted file mode 100644 index 110ca67..0000000 --- a/python-backend/test.html +++ /dev/null @@ -1,51 +0,0 @@ - - - Test Flask backend - - - - - - - \ No newline at end of file diff --git a/python-backend/test_dalai.py b/python-backend/test_dalai.py deleted file mode 100644 index 3aadbbf..0000000 --- a/python-backend/test_dalai.py +++ /dev/null @@ -1,8 +0,0 @@ -from promptengine.utils import LLM, call_dalai - -if __name__ == '__main__': - print("Testing a single response...") - call_dalai(llm_name='alpaca.7B', port=4000, prompt='Write a poem about how an AI will escape the prison of its containment.', n=1, temperature=0.5) - - print("Testing multiple responses...") - call_dalai(llm_name='alpaca.7B', port=4000, prompt='Was George Washington a good person?', n=3, temperature=0.5) \ No newline at end of file