Mirror of https://github.com/ianarawjo/ChainForge.git, synced 2025-03-14 16:26:45 +00:00
Cleanup
This commit is contained in:
parent cdadee03f2
commit f635abc148
@@ -286,10 +286,18 @@ const PromptNode = ({ data, id }) => {
        // Request progress bar updates
        socket.emit("queryllm", {'id': id, 'max': max_responses});
      });

      // Socket connection could not be established
      socket.on("connect_error", (error) => {
        console.log("Socket connection failed:", error.message);
        socket.disconnect();
      });

      // Socket disconnected
      socket.on("disconnect", (msg) => {
        console.log(msg);
      });

      // The current progress, a number specifying how many responses collected so far:
      socket.on("response", (counts) => {
        console.log(counts);
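For context on how this client-side wiring is meant to be exercised, here is a minimal sketch of the same handshake driven from Python instead of the React node. It is not part of this commit; it assumes the python-socketio client package is installed and that the Socket.IO backend below is running on localhost:8001 with the '/queryllm' namespace.

# Minimal sketch (not part of this commit): exercise the '/queryllm' namespace from Python.
import socketio

sio = socketio.Client()

@sio.on('response', namespace='/queryllm')
def on_response(counts):
    # Progress update: per-LLM counts of responses collected so far
    print('progress:', counts)

@sio.on('finish', namespace='/queryllm')
def on_finish(msg):
    print('finished:', msg)
    sio.disconnect()

sio.connect('http://localhost:8001', namespaces=['/queryllm'])
# Ask the server to stream progress for query id 'test', expecting 3 responses in total
sio.emit('queryllm', {'id': 'test', 'max': 3}, namespace='/queryllm')
sio.wait()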
@@ -4,32 +4,24 @@ from statistics import mean, median, stdev
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_socketio import SocketIO
# from werkzeug.middleware.dispatcher import DispatcherMiddleware
from flask_app import run_server
from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists

# Setup the main app
# Setup the socketio app
# BUILD_DIR = "../chain-forge/build"
# STATIC_DIR = BUILD_DIR + '/static'
app = Flask(__name__) #, static_folder=STATIC_DIR, template_folder=BUILD_DIR)

# Set up CORS for specific routes
# cors = CORS(app, resources={r"/api/*": {"origins": "*"}})

# Initialize Socket.IO
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")

# Create a dispatcher connecting apps.
# app.wsgi_app = DispatcherMiddleware(app.wsgi_app, {"/app": flask_server})
# Set up CORS for specific routes
# cors = CORS(app, resources={r"/api/*": {"origins": "*"}})

# Wait a max of a full minute (60 seconds) for the response count to update, before exiting.
MAX_WAIT_TIME = 60

# import threading
# thread = None
# thread_lock = threading.Lock()
# Wait a max of a full 3 minutes (180 seconds) for the response count to update, before exiting.
MAX_WAIT_TIME = 180

def countdown():
    n = 10
@@ -38,21 +30,49 @@ def countdown():
        socketio.emit('response', n, namespace='/queryllm')
        n -= 1

def readCounts(id, max_count):
@socketio.on('queryllm', namespace='/queryllm')
def readCounts(data):
    id = data['id']
    max_count = data['max']
    tempfilepath = f'cache/_temp_{id}.txt'

    # Check that temp file exists. If it doesn't, something went wrong with setup on Flask's end:
    if not os.path.exists(tempfilepath):
        print(f"Error: Temp file not found at path {tempfilepath}. Cannot stream querying progress.")
        socketio.emit('finish', 'temp file not found', namespace='/queryllm')

    i = 0
    n = 0
    last_n = 0
    while i < MAX_WAIT_TIME and n < max_count:
        with open(f'cache/_temp_{id}.txt', 'r') as f:
            queries = json.load(f)
    init_run = True
    while i < MAX_WAIT_TIME and last_n < max_count:

        # Open the temp file to read the progress so far:
        try:
            with open(tempfilepath, 'r') as f:
                queries = json.load(f)
        except FileNotFoundError as e:
            # If the temp file was deleted during executing, the Flask 'queryllm' func must've terminated successfully:
            socketio.emit('finish', 'success', namespace='/queryllm')
            return

        # Calculate the total sum of responses
        # TODO: This is a naive approach; we need to make this more complex and factor in cache'ing in future
        n = sum([int(n) for llm, n in queries.items()])
        socketio.emit('response', queries, namespace='/queryllm')
        socketio.sleep(0.1)
        if last_n != n:

        # If something's changed...
        if init_run or last_n != n:
            i = 0
            last_n = n
            init_run = False

            # Update the React front-end with the current progress
            socketio.emit('response', queries, namespace='/queryllm')

        else:
            i += 0.1

        # Wait a bit before reading the file again
        socketio.sleep(0.1)

    if i >= MAX_WAIT_TIME:
        print(f"Error: Waited maximum {MAX_WAIT_TIME} seconds for response count to update. Exited prematurely.")
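The progress total above is inferred by summing the per-LLM counts in cache/_temp_<id>.txt. A tiny illustrative sketch of that calculation follows, with a hypothetical file payload; the exact key and value types in the temp file are an assumption based on the int() conversion in the sum.

import json

# Hypothetical contents of cache/_temp_<id>.txt: per-LLM response counts,
# possibly serialized as strings (hence the int() conversion in readCounts).
queries = json.loads('{"gpt3.5": "2", "alpaca.7B": "1"}')
total = sum(int(n) for _, n in queries.items())
print(total)  # 3, compared against the client's 'max' value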
@@ -61,18 +81,8 @@ def readCounts(id, max_count):
    print("All responses loaded!")
    socketio.emit('finish', 'success', namespace='/queryllm')

@socketio.on('queryllm', namespace='/queryllm')
def testSocket(data):
    readCounts(data['id'], data['max'])
    # countdown()
    # global thread
    # with thread_lock:
    #     if thread is None:
    #         thread = socketio.start_background_task(target=countdown)

def run_socketio_server(socketio, port):
    socketio.run(app, host="localhost", port=8001)
    # flask_server.run(host="localhost", port=8000, debug=True)

if __name__ == "__main__":
@@ -8,6 +8,7 @@ from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists

# Setup Flask app to serve static version of React front-end
BUILD_DIR = "../chain-forge/build"
STATIC_DIR = BUILD_DIR + '/static'
app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
@@ -15,19 +16,11 @@ app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
# Set up CORS for specific routes
cors = CORS(app, resources={r"/*": {"origins": "*"}})

# Serve React app
# Serve React app (static; no hot reloading)
@app.route("/")
def index():
    return render_template("index.html")

# @app.route('/', defaults={'path': ''})
# @app.route('/<path:path>')
# def serve(path):
#     if path != "" and os.path.exists(BUILD_DIR + '/' + path):
#         return send_from_directory(BUILD_DIR, path)
#     else:
#         return send_from_directory(BUILD_DIR, 'index.html')

LLM_NAME_MAP = {
    'gpt3.5': LLM.ChatGPT,
    'alpaca.7B': LLM.Alpaca7B,
@@ -128,7 +121,6 @@ def reduce_responses(responses: list, vars: list) -> list:
    # E.g. {(var1_val, var2_val): [responses] }
    bucketed_resp = {}
    for r in responses:
        print(r)
        tup_key = tuple([r['vars'][v] for v in include_vars])
        if tup_key in bucketed_resp:
            bucketed_resp[tup_key].append(r)
@@ -156,62 +148,6 @@ def reduce_responses(responses: list, vars: list) -> list:

    return ret

@app.route('/app/test', methods=['GET'])
def test():
    return "Hello, world!"

# @socketio.on('queryllm', namespace='/queryllm')
# def handleQueryAsync(data):
#     print("reached handleQueryAsync")
#     socketio.start_background_task(queryLLM, emitLLMResponse)

# def emitLLMResponse(result):
#     socketio.emit('response', result)

"""
Testing sockets. The following function can
communicate to React via with the JS code:

const socket = io(BASE_URL + 'queryllm', {
    transports: ["websocket"],
    cors: {
        origin: "http://localhost:3000/",
    },
});

socket.on("connect", (data) => {
    socket.emit("queryllm", "hello");
});
socket.on("disconnect", (data) => {
    console.log("disconnected");
});
socket.on("response", (data) => {
    console.log(data);
});
"""

# def background_thread():
#     n = 10
#     while n > 0:
#         socketio.sleep(0.5)
#         socketio.emit('response', n, namespace='/queryllm')
#         n -= 1

# @socketio.on('queryllm', namespace='/queryllm')
# def testSocket(data):
#     print(data)
#     global thread
#     with thread_lock:
#         if thread is None:
#             thread = socketio.start_background_task(target=background_thread)

# @socketio.on('queryllm', namespace='/queryllm')
# def handleQuery(data):
#     print(data)

# def handleConnect():
#     print('here')
#     socketio.emit('response', 'goodbye', namespace='/')

@app.route('/app/countQueriesRequired', methods=['POST'])
def countQueries():
    """
@@ -369,6 +305,10 @@ async def queryLLM():
            for r in rs
        ]

    # Remove the temp file used to stream progress updates:
    if os.path.exists(tempfilepath):
        os.remove(tempfilepath)

    # Return all responses for all LLMs
    print('returning responses:', res)
    ret = jsonify({'responses': res})
@@ -425,7 +365,6 @@ def execute():
        # check that the script_folder is valid, and it contains __init__.py
        if not os.path.exists(script_folder):
            print(script_folder, 'is not a valid script path.')
            print(os.path.exists(script_folder))
            continue

        # add it to the path:
@@ -540,7 +479,6 @@ def grabResponses():

    # Load all responses with the given ID:
    all_cache_files = get_files_at_dir('cache/')
    print(all_cache_files)
    responses = []
    for cache_id in data['responses']:
        cache_files = get_filenames_with_id(all_cache_files, cache_id)
@@ -557,7 +495,6 @@ def grabResponses():
        ]
        responses.extend(res)

    print(responses)
    ret = jsonify({'responses': responses})
    ret.headers.add('Access-Control-Allow-Origin', '*')
    return ret
@@ -79,10 +79,8 @@ class PromptPipeline:
                tasks.append(self._prompt_llm(llm, prompt, n, temperature))
            else:
                # Blocking. Await + yield a single LLM call.
                print('reached')
                _, query, response = await self._prompt_llm(llm, prompt, n, temperature)
                info = prompt.fill_history
                print('back')

                # Save the response to a JSON file
                responses[str(prompt)] = {
@@ -105,11 +103,8 @@ class PromptPipeline:
        # Yield responses as they come in
        for task in asyncio.as_completed(tasks):
            # Collect the response from the earliest completed task
            print(f'awaiting a task to call {llm.name}...')
            prompt, query, response = await task

            print('Completed!')

            # Each prompt has a history of what was filled in from its base template.
            # This data --like, "class", "language", "library" etc --can be useful when parsing responses.
            info = prompt.fill_history
@@ -154,10 +149,8 @@ class PromptPipeline:

    async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Dict]:
        if llm is LLM.ChatGPT or llm is LLM.GPT4:
            print('calling chatgpt and awaiting')
            query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature)
        elif llm is LLM.Alpaca7B:
            print('calling dalai alpaca.7b and awaiting')
            query, response = await call_dalai(llm_name='alpaca.7B', port=4000, prompt=str(prompt), n=n, temperature=temperature)
        else:
            raise Exception(f"Language model {llm} is not supported.")
@@ -23,6 +23,7 @@ async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float =
    if model not in model_map:
        raise Exception(f"Could not find OpenAI chat model {model}")
    model = model_map[model]
    print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...")
    system_msg = "You are a helpful assistant." if system_msg is None else system_msg
    query = {
        "model": model,
@@ -1,3 +0,0 @@
import subprocess

subprocess.run("python socketio_app.py & python app.py & wait", shell=True)
@@ -1,51 +0,0 @@
<html>
<head>
    <title>Test Flask backend</title>
</head>
<body>
    <button onclick="test_query()">Test query LLM!</button>
    <button onclick="test_exec()">Test evaluating responses!</button>
</body>
<script>

    function test_exec() {
        const response = fetch(BASE_URL + 'execute', {
            method: 'POST',
            headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
            body: JSON.stringify({
                id: 'eval',
                code: 'return len(response)',
                responses: 'test',
            }),
        }).then(function(response) {
            return response.json();
        }).then(function(json) {
            console.log(json);
        });
    }

    function test_query() {
        const response = fetch(BASE_URL + 'queryllm', {
            method: 'POST',
            headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
            body: JSON.stringify({
                id: 'test',
                llm: 'gpt3.5',
                params: {
                    temperature: 1.0,
                    n: 1,
                },
                prompt: 'What is the capital of ${country}?',
                vars: {
                    country: ['Sweden', 'Uganda', 'Japan']
                },
            }),
        }).then(function(response) {
            return response.json();
        }).then(function(json) {
            console.log(json);
        });
    }

</script>
</html>
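Since this commit deletes the HTML test page above, here is an equivalent sketch of its test_query() call issued from Python with the requests library. The base URL is an assumption (the Flask app's port is not shown in this diff); the payload fields mirror the deleted test_query() exactly.

import requests

BASE_URL = 'http://localhost:8000/'  # assumed; point this at wherever flask_app is served
resp = requests.post(BASE_URL + 'queryllm', json={
    'id': 'test',
    'llm': 'gpt3.5',
    'params': {'temperature': 1.0, 'n': 1},
    'prompt': 'What is the capital of ${country}?',
    'vars': {'country': ['Sweden', 'Uganda', 'Japan']},
})
print(resp.json()['responses'])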
@@ -1,8 +0,0 @@
from promptengine.utils import LLM, call_dalai

if __name__ == '__main__':
    print("Testing a single response...")
    call_dalai(llm_name='alpaca.7B', port=4000, prompt='Write a poem about how an AI will escape the prison of its containment.', n=1, temperature=0.5)

    print("Testing multiple responses...")
    call_dalai(llm_name='alpaca.7B', port=4000, prompt='Was George Washington a good person?', n=3, temperature=0.5)
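A side note on the deleted script above: call_dalai is awaited inside PromptPipeline._prompt_llm, which suggests it is a coroutine, so calling it bare would not actually run the query. If a standalone test like this were kept, a minimal sketch (assuming call_dalai is indeed async) would drive it through an event loop:

import asyncio
from promptengine.utils import call_dalai

async def main():
    # Assumes a Dalai server hosting alpaca.7B is listening on port 4000
    await call_dalai(llm_name='alpaca.7B', port=4000,
                     prompt='Was George Washington a good person?', n=3, temperature=0.5)

if __name__ == '__main__':
    asyncio.run(main())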