This commit is contained in:
Ian Arawjo 2023-05-06 13:00:18 -04:00
parent cdadee03f2
commit f635abc148
8 changed files with 57 additions and 170 deletions

View File

@ -286,10 +286,18 @@ const PromptNode = ({ data, id }) => {
    // Request progress bar updates
    socket.emit("queryllm", {'id': id, 'max': max_responses});
  });

  // Socket connection could not be established
  socket.on("connect_error", (error) => {
    console.log("Socket connection failed:", error.message);
    socket.disconnect();
  });

  // Socket disconnected
  socket.on("disconnect", (msg) => {
    console.log(msg);
  });

  // The current progress: a number specifying how many responses have been collected so far
  socket.on("response", (counts) => {
    console.log(counts);
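The protocol here: on connect, the node emits 'queryllm' with its id and the total number of responses it expects, then listens for 'response' progress updates (and a 'finish' event from the backend). Purely as an illustration, not part of this commit, the same handshake can be driven by a small python-socketio client; the id value and max count below are made up, and the server is assumed to be the socketio_app.py backend on localhost:8001.

import socketio  # pip install "python-socketio[client]"

sio = socketio.Client()

@sio.on('connect', namespace='/queryllm')
def on_connect():
    # Mirror the React node: ask the server to start streaming progress for a (hypothetical) prompt node
    sio.emit('queryllm', {'id': 'prompt-node-1', 'max': 9}, namespace='/queryllm')

@sio.on('response', namespace='/queryllm')
def on_response(counts):
    # Progress so far: a dict mapping each LLM name to its response count
    print('progress:', counts)

@sio.on('finish', namespace='/queryllm')
def on_finish(msg):
    print('finished:', msg)
    sio.disconnect()

sio.connect('http://localhost:8001', namespaces=['/queryllm'])
sio.wait()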

View File

@ -4,32 +4,24 @@ from statistics import mean, median, stdev
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_socketio import SocketIO
# from werkzeug.middleware.dispatcher import DispatcherMiddleware
from flask_app import run_server
from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
# Setup the main app
# Setup the socketio app
# BUILD_DIR = "../chain-forge/build"
# STATIC_DIR = BUILD_DIR + '/static'
app = Flask(__name__) #, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
# Set up CORS for specific routes
# cors = CORS(app, resources={r"/api/*": {"origins": "*"}})
# Initialize Socket.IO
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")
# Create a dispatcher connecting apps.
# app.wsgi_app = DispatcherMiddleware(app.wsgi_app, {"/app": flask_server})
# Set up CORS for specific routes
# cors = CORS(app, resources={r"/api/*": {"origins": "*"}})
# Wait a maximum of one minute (60 seconds) for the response count to update before exiting.
MAX_WAIT_TIME = 60
# import threading
# thread = None
# thread_lock = threading.Lock()
# Wait a maximum of 3 minutes (180 seconds) for the response count to update before exiting.
MAX_WAIT_TIME = 180
def countdown():
    n = 10
@ -38,21 +30,49 @@ def countdown():
        socketio.emit('response', n, namespace='/queryllm')
        n -= 1
def readCounts(id, max_count):
@socketio.on('queryllm', namespace='/queryllm')
def readCounts(data):
    id = data['id']
    max_count = data['max']
    tempfilepath = f'cache/_temp_{id}.txt'

    # Check that the temp file exists. If it doesn't, something went wrong with setup on Flask's end:
    if not os.path.exists(tempfilepath):
        print(f"Error: Temp file not found at path {tempfilepath}. Cannot stream querying progress.")
        socketio.emit('finish', 'temp file not found', namespace='/queryllm')

    i = 0
    n = 0
    last_n = 0
    while i < MAX_WAIT_TIME and n < max_count:
        with open(f'cache/_temp_{id}.txt', 'r') as f:
            queries = json.load(f)
    init_run = True
    while i < MAX_WAIT_TIME and last_n < max_count:

        # Open the temp file to read the progress so far:
        try:
            with open(tempfilepath, 'r') as f:
                queries = json.load(f)
        except FileNotFoundError as e:
            # If the temp file was deleted during execution, the Flask 'queryllm' func must have terminated successfully:
            socketio.emit('finish', 'success', namespace='/queryllm')
            return

        # Calculate the total sum of responses
        # TODO: This is a naive approach; we need to make this more complex and factor in caching in the future
        n = sum([int(n) for llm, n in queries.items()])

        socketio.emit('response', queries, namespace='/queryllm')
        socketio.sleep(0.1)
        if last_n != n:
        # If something's changed...
        if init_run or last_n != n:
            i = 0
            last_n = n
            init_run = False

            # Update the React front-end with the current progress
            socketio.emit('response', queries, namespace='/queryllm')
        else:
            i += 0.1

        # Wait a bit before reading the file again
        socketio.sleep(0.1)

    if i >= MAX_WAIT_TIME:
        print(f"Error: Waited the maximum {MAX_WAIT_TIME} seconds for the response count to update. Exited prematurely.")
@ -61,18 +81,8 @@ def readCounts(id, max_count):
print("All responses loaded!")
socketio.emit('finish', 'success', namespace='/queryllm')
@socketio.on('queryllm', namespace='/queryllm')
def testSocket(data):
readCounts(data['id'], data['max'])
# countdown()
# global thread
# with thread_lock:
# if thread is None:
# thread = socketio.start_background_task(target=countdown)
def run_socketio_server(socketio, port):
socketio.run(app, host="localhost", port=8001)
# flask_server.run(host="localhost", port=8000, debug=True)
if __name__ == "__main__":
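Note the file-based contract readCounts relies on: cache/_temp_<id>.txt is expected to hold a flat JSON object mapping an LLM name to the number of responses collected so far, and deleting the file signals completion. The exact contents are written by the Flask 'queryllm' route (not shown in this diff), so the snapshot below is a made-up example that only illustrates the read-and-sum step:

import json

# Hypothetical snapshot of cache/_temp_<id>.txt while a query is still running:
with open('cache/_temp_example.txt', 'w') as f:
    json.dump({'gpt3.5': 4, 'alpaca.7B': 2}, f)

# What readCounts does on each poll: load the JSON and total the per-LLM counts
with open('cache/_temp_example.txt', 'r') as f:
    queries = json.load(f)
n = sum(int(count) for count in queries.values())
print(n)  # 6 responses collected so far across both LLMs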

View File

@ -8,6 +8,7 @@ from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
# Setup Flask app to serve static version of React front-end
BUILD_DIR = "../chain-forge/build"
STATIC_DIR = BUILD_DIR + '/static'
app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
@ -15,19 +16,11 @@ app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
# Set up CORS for specific routes
cors = CORS(app, resources={r"/*": {"origins": "*"}})
# Serve React app
# Serve React app (static; no hot reloading)
@app.route("/")
def index():
    return render_template("index.html")

# @app.route('/', defaults={'path': ''})
# @app.route('/<path:path>')
# def serve(path):
#     if path != "" and os.path.exists(BUILD_DIR + '/' + path):
#         return send_from_directory(BUILD_DIR, path)
#     else:
#         return send_from_directory(BUILD_DIR, 'index.html')

LLM_NAME_MAP = {
    'gpt3.5': LLM.ChatGPT,
    'alpaca.7B': LLM.Alpaca7B,
@ -128,7 +121,6 @@ def reduce_responses(responses: list, vars: list) -> list:
    # E.g. {(var1_val, var2_val): [responses] }
    bucketed_resp = {}
    for r in responses:
        print(r)
        tup_key = tuple([r['vars'][v] for v in include_vars])
        if tup_key in bucketed_resp:
            bucketed_resp[tup_key].append(r)
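To make the bucketing concrete: the tuple of selected var values becomes the dictionary key, so responses that share those values land in the same bucket. The data and include_vars below are invented for illustration, and the else branch (cut off by the hunk above) is filled in with the obvious completion that starts a new bucket:

responses = [
    {'vars': {'country': 'Sweden', 'run': '1'}, 'responses': ['Stockholm']},
    {'vars': {'country': 'Sweden', 'run': '2'}, 'responses': ['Stockholm.']},
    {'vars': {'country': 'Japan',  'run': '1'}, 'responses': ['Tokyo']},
]
include_vars = ['country']

bucketed_resp = {}
for r in responses:
    tup_key = tuple([r['vars'][v] for v in include_vars])
    if tup_key in bucketed_resp:
        bucketed_resp[tup_key].append(r)
    else:
        bucketed_resp[tup_key] = [r]

# bucketed_resp now holds {('Sweden',): [two responses], ('Japan',): [one response]}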
@ -156,62 +148,6 @@ def reduce_responses(responses: list, vars: list) -> list:
    return ret

@app.route('/app/test', methods=['GET'])
def test():
    return "Hello, world!"

# @socketio.on('queryllm', namespace='/queryllm')
# def handleQueryAsync(data):
#     print("reached handleQueryAsync")
#     socketio.start_background_task(queryLLM, emitLLMResponse)

# def emitLLMResponse(result):
#     socketio.emit('response', result)

"""
Testing sockets. The following function can
communicate with React via the JS code:

const socket = io(BASE_URL + 'queryllm', {
    transports: ["websocket"],
    cors: {
        origin: "http://localhost:3000/",
    },
});
socket.on("connect", (data) => {
    socket.emit("queryllm", "hello");
});
socket.on("disconnect", (data) => {
    console.log("disconnected");
});
socket.on("response", (data) => {
    console.log(data);
});
"""

# def background_thread():
#     n = 10
#     while n > 0:
#         socketio.sleep(0.5)
#         socketio.emit('response', n, namespace='/queryllm')
#         n -= 1

# @socketio.on('queryllm', namespace='/queryllm')
# def testSocket(data):
#     print(data)
#     global thread
#     with thread_lock:
#         if thread is None:
#             thread = socketio.start_background_task(target=background_thread)

# @socketio.on('queryllm', namespace='/queryllm')
# def handleQuery(data):
#     print(data)

# def handleConnect():
#     print('here')
#     socketio.emit('response', 'goodbye', namespace='/')

@app.route('/app/countQueriesRequired', methods=['POST'])
def countQueries():
    """
@ -369,6 +305,10 @@ async def queryLLM():
        for r in rs
    ]

    # Remove the temp file used to stream progress updates:
    if os.path.exists(tempfilepath):
        os.remove(tempfilepath)

    # Return all responses for all LLMs
    print('returning responses:', res)
    ret = jsonify({'responses': res})
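This is the other half of the temp-file progress protocol polled by socketio_app.readCounts: the route keeps the file's JSON counts current while querying, then deletes the file once every response is in (the poller treats a missing file as a successful finish). The writer side is not shown in this diff, so the helpers below, including the names update_progress and finish_progress, are only a hypothetical illustration of that contract:

import json, os

def update_progress(tempfilepath: str, counts: dict) -> None:
    # Overwrite the temp file with the latest {llm_name: num_responses_so_far} mapping;
    # socketio_app.readCounts polls this file and relays it to the React front-end.
    with open(tempfilepath, 'w') as f:
        json.dump(counts, f)

def finish_progress(tempfilepath: str) -> None:
    # Deleting the file is the 'done' signal for the poller.
    if os.path.exists(tempfilepath):
        os.remove(tempfilepath)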
@ -425,7 +365,6 @@ def execute():
        # check that the script_folder is valid, and it contains __init__.py
        if not os.path.exists(script_folder):
            print(script_folder, 'is not a valid script path.')
            print(os.path.exists(script_folder))
            continue

        # add it to the path:
@ -540,7 +479,6 @@ def grabResponses():
    # Load all responses with the given ID:
    all_cache_files = get_files_at_dir('cache/')
    print(all_cache_files)
    responses = []
    for cache_id in data['responses']:
        cache_files = get_filenames_with_id(all_cache_files, cache_id)
@ -557,7 +495,6 @@ def grabResponses():
        ]
        responses.extend(res)

    print(responses)
    ret = jsonify({'responses': responses})
    ret.headers.add('Access-Control-Allow-Origin', '*')
    return ret

View File

@ -79,10 +79,8 @@ class PromptPipeline:
                tasks.append(self._prompt_llm(llm, prompt, n, temperature))
            else:
                # Blocking. Await + yield a single LLM call.
                print('reached')
                _, query, response = await self._prompt_llm(llm, prompt, n, temperature)
                info = prompt.fill_history
                print('back')

                # Save the response to a JSON file
                responses[str(prompt)] = {
@ -105,11 +103,8 @@ class PromptPipeline:
        # Yield responses as they come in
        for task in asyncio.as_completed(tasks):
            # Collect the response from the earliest completed task
            print(f'awaiting a task to call {llm.name}...')
            prompt, query, response = await task
            print('Completed!')

            # Each prompt has a history of what was filled in from its base template.
            # This data (like "class", "language", "library", etc.) can be useful when parsing responses.
            info = prompt.fill_history
@ -154,10 +149,8 @@ class PromptPipeline:
    async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Dict]:
        if llm is LLM.ChatGPT or llm is LLM.GPT4:
            print('calling chatgpt and awaiting')
            query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature)
        elif llm is LLM.Alpaca7B:
            print('calling dalai alpaca.7b and awaiting')
            query, response = await call_dalai(llm_name='alpaca.7B', port=4000, prompt=str(prompt), n=n, temperature=temperature)
        else:
            raise Exception(f"Language model {llm} is not supported.")

View File

@ -23,6 +23,7 @@ async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float =
    if model not in model_map:
        raise Exception(f"Could not find OpenAI chat model {model}")
    model = model_map[model]

    print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...")
    system_msg = "You are a helpful assistant." if system_msg is None else system_msg
    query = {
        "model": model,

View File

@ -1,3 +0,0 @@
import subprocess
subprocess.run("python socketio_app.py & python app.py & wait", shell=True)

View File

@ -1,51 +0,0 @@
<html>
  <head>
    <title>Test Flask backend</title>
  </head>
  <body>
    <button onclick="test_query()">Test query LLM!</button>
    <button onclick="test_exec()">Test evaluating responses!</button>
  </body>
  <script>
    function test_exec() {
      const response = fetch(BASE_URL + 'execute', {
        method: 'POST',
        headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
        body: JSON.stringify({
          id: 'eval',
          code: 'return len(response)',
          responses: 'test',
        }),
      }).then(function(response) {
        return response.json();
      }).then(function(json) {
        console.log(json);
      });
    }
    function test_query() {
      const response = fetch(BASE_URL + 'queryllm', {
        method: 'POST',
        headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
        body: JSON.stringify({
          id: 'test',
          llm: 'gpt3.5',
          params: {
            temperature: 1.0,
            n: 1,
          },
          prompt: 'What is the capital of ${country}?',
          vars: {
            country: ['Sweden', 'Uganda', 'Japan']
          },
        }),
      }).then(function(response) {
        return response.json();
      }).then(function(json) {
        console.log(json);
      });
    }
  </script>
</html>

View File

@ -1,8 +0,0 @@
from promptengine.utils import LLM, call_dalai

if __name__ == '__main__':
    print("Testing a single response...")
    call_dalai(llm_name='alpaca.7B', port=4000, prompt='Write a poem about how an AI will escape the prison of its containment.', n=1, temperature=0.5)

    print("Testing multiple responses...")
    call_dalai(llm_name='alpaca.7B', port=4000, prompt='Was George Washington a good person?', n=3, temperature=0.5)