Mirror of https://github.com/ianarawjo/ChainForge.git
Run React and Flask and Socketio with single script
parent d4e0630564
commit cdadee03f2

.gitignore (vendored): 1 line changed
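In outline, the diff below renames the frontend's backend routes from api/* to app/*, moves the HTTP API into a new python-backend/flask_app.py (which also serves the React build from chain-forge/build), keeps the Socket.IO progress server in python-backend/app.py on port+1, and adds a small python-backend/server.py launcher. A minimal usage sketch, assuming you run from python-backend/ with the React app already built:

    python app.py --port 8000    # Flask API + React build on :8000, SocketIO progress server on :8001
    python server.py             # alternative single-script launcher shown at the bottom of this diff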
@ -1,5 +1,6 @@
*.DS_Store
chain-forge/node_modules
chain-forge/build
__pycache__
python-backend/cache
@ -74,7 +74,7 @@ const EvaluatorNode = ({ data, id }) => {
console.log(script_paths);
// Run evaluator in backend
const codeTextOnRun = codeText + '';
fetch(BASE_URL + 'api/execute', {
fetch(BASE_URL + 'app/execute', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@ -32,7 +32,7 @@ const InspectorNode = ({ data, id }) => {
console.log(input_node_ids);

// Grab responses associated with those ids:
fetch(BASE_URL + 'api/grabResponses', {
fetch(BASE_URL + 'app/grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@ -174,7 +174,7 @@ const PromptNode = ({ data, id }) => {

// Ask the backend how many responses it needs to collect, given the input data:
const fetchResponseCounts = (prompt, vars, llms, rejected) => {
return fetch(BASE_URL + 'api/countQueriesRequired', {
return fetch(BASE_URL + 'app/countQueriesRequired', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@ -250,7 +250,7 @@ const PromptNode = ({ data, id }) => {

// Ask the backend to reset the scratchpad for counting queries:
const create_progress_scratchpad = () => {
return fetch(BASE_URL + 'api/createProgressFile', {
return fetch(BASE_URL + 'app/createProgressFile', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@ -268,7 +268,7 @@ const PromptNode = ({ data, id }) => {
// the socketio server that will stream to us the current progress:
const socket = io('http://localhost:8001/' + 'queryllm', {
transports: ["websocket"],
cors: {origin: "http://localhost:3000/"},
cors: {origin: "http://localhost:8000/"},
});

const max_responses = Object.keys(response_counts).reduce((acc, llm) => acc + response_counts[llm], 0);
@ -319,7 +319,7 @@ const PromptNode = ({ data, id }) => {

// Run all prompt permutations through the LLM to generate + cache responses:
const query_llms = () => {
return fetch(BASE_URL + 'api/queryllm', {
return fetch(BASE_URL + 'app/queryllm', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@ -47,7 +47,7 @@ const VisNode = ({ data, id }) => {
// Grab the input node ids
const input_node_ids = [data.input];

fetch(BASE_URL + 'api/grabResponses', {
fetch(BASE_URL + 'app/grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
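For reference, a hypothetical client-side sketch (not part of this commit) of what the renamed app/* routes expect. It assumes the backend below is running on localhost:8000 and follows the payload shape documented in the countQueriesRequired docstring further down in this diff; the requests library and the example values are illustrative only.

import requests

payload = {
    'prompt': 'Tell me a joke about {topic}',   # exact template syntax is defined by promptengine's PromptTemplate
    'vars': {'topic': ['cats', 'dogs']},        # each var may be a single value or a list (all permutations are generated)
    'llms': ['gpt3.5'],                         # LLM names as used by the backend's LLM_NAME_MAP
}
resp = requests.post('http://localhost:8000/app/countQueriesRequired', json=payload)
print(resp.json())   # e.g. {'counts': {'gpt3.5': 2}} -- illustrative output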
@ -1,561 +1,81 @@
import json, os, asyncio, sys, argparse
import json, os, asyncio, sys, argparse, threading
from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_socketio import SocketIO
# from werkzeug.middleware.dispatcher import DispatcherMiddleware
from flask_app import run_server
from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists

app = Flask(__name__)
# Setup the main app
# BUILD_DIR = "../chain-forge/build"
# STATIC_DIR = BUILD_DIR + '/static'
app = Flask(__name__) #, static_folder=STATIC_DIR, template_folder=BUILD_DIR)

# Set up CORS for specific routes
cors = CORS(app, resources={r"/api/*": {"origins": "*"}})
# cors = CORS(app, resources={r"/api/*": {"origins": "*"}})

# Initialize Socket.IO
# socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")

# Create a dispatcher connecting apps.
# app.wsgi_app = DispatcherMiddleware(app.wsgi_app, {"/app": flask_server})

# Wait a max of a full minute (60 seconds) for the response count to update, before exiting.
MAX_WAIT_TIME = 60

# import threading
# thread = None
# thread_lock = threading.Lock()

def countdown():
|
||||
n = 10
|
||||
while n > 0:
|
||||
socketio.sleep(0.5)
|
||||
socketio.emit('response', n, namespace='/queryllm')
|
||||
n -= 1
|
||||
|
||||
LLM_NAME_MAP = {
|
||||
'gpt3.5': LLM.ChatGPT,
|
||||
'alpaca.7B': LLM.Alpaca7B,
|
||||
'gpt4': LLM.GPT4,
|
||||
}
|
||||
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
|
||||
|
||||
@dataclass
|
||||
class ResponseInfo:
|
||||
"""Stores info about a single response. Passed to evaluator functions."""
|
||||
text: str
|
||||
prompt: str
|
||||
var: str
|
||||
llm: str
|
||||
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
def to_standard_format(r: dict) -> list:
|
||||
llm = LLM_NAME_MAP_INVERSE[r['llm']]
|
||||
resp_obj = {
|
||||
'vars': r['info'],
|
||||
'llm': llm,
|
||||
'prompt': r['prompt'],
|
||||
'responses': extract_responses(r, r['llm']),
|
||||
'tokens': r['response']['usage'] if 'usage' in r['response'] else {},
|
||||
}
|
||||
if 'eval_res' in r:
|
||||
resp_obj['eval_res'] = r['eval_res']
|
||||
return resp_obj
|
||||
|
||||
def get_llm_of_response(response: dict) -> LLM:
|
||||
return LLM_NAME_MAP[response['llm']]
|
||||
|
||||
def get_filenames_with_id(filenames: list, id: str) -> list:
|
||||
return [
|
||||
c for c in filenames
|
||||
if c.split('.')[0] == id or ('-' in c and c[:c.rfind('-')] == id)
|
||||
]
|
||||
|
||||
def remove_cached_responses(cache_id: str):
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
cache_files = get_filenames_with_id(all_cache_files, cache_id)
|
||||
for filename in cache_files:
|
||||
os.remove(os.path.join('cache/', filename))
|
||||
|
||||
def load_cache_json(filepath: str) -> dict:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
responses = json.load(f)
|
||||
return responses
|
||||
|
||||
def run_over_responses(eval_func, responses: dict, scope: str) -> list:
|
||||
for prompt, resp_obj in responses.items():
|
||||
res = extract_responses(resp_obj, resp_obj['llm'])
|
||||
if scope == 'response':
|
||||
evals = [ # Run evaluator func over every individual response text
|
||||
eval_func(
|
||||
ResponseInfo(
|
||||
text=r,
|
||||
prompt=prompt,
|
||||
var=resp_obj['info'],
|
||||
llm=LLM_NAME_MAP_INVERSE[resp_obj['llm']])
|
||||
) for r in res
|
||||
]
|
||||
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
|
||||
'mean': mean(evals),
|
||||
'median': median(evals),
|
||||
'stdev': stdev(evals) if len(evals) > 1 else 0,
|
||||
'range': (min(evals), max(evals)),
|
||||
'items': evals,
|
||||
}
|
||||
else: # operate over the entire response batch
|
||||
ev = eval_func(res)
|
||||
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
|
||||
'mean': ev,
|
||||
'median': ev,
|
||||
'stdev': 0,
|
||||
'range': (ev, ev),
|
||||
'items': [ev],
|
||||
}
|
||||
return responses
|
||||
|
||||
def reduce_responses(responses: list, vars: list) -> list:
|
||||
if len(responses) == 0: return responses
|
||||
|
||||
# Figure out what vars we still care about (the ones we aren't reducing over):
|
||||
# NOTE: We are assuming all responses have the same 'vars' keys.
|
||||
all_vars = set(responses[0]['vars'])
|
||||
|
||||
if not all_vars.issuperset(set(vars)):
|
||||
# There's a var in vars which isn't part of the response.
|
||||
raise Exception(f"Some vars in {set(vars)} are not in the responses.")
|
||||
|
||||
# Get just the vars we want to keep around:
|
||||
include_vars = list(set(responses[0]['vars']) - set(vars))
|
||||
|
||||
# Bucket responses by the remaining var values, where tuples of vars are keys to a dict:
|
||||
# E.g. {(var1_val, var2_val): [responses] }
|
||||
bucketed_resp = {}
|
||||
for r in responses:
|
||||
print(r)
|
||||
tup_key = tuple([r['vars'][v] for v in include_vars])
|
||||
if tup_key in bucketed_resp:
|
||||
bucketed_resp[tup_key].append(r)
|
||||
def readCounts(id, max_count):
|
||||
i = 0
|
||||
n = 0
|
||||
last_n = 0
|
||||
while i < MAX_WAIT_TIME and n < max_count:
|
||||
with open(f'cache/_temp_{id}.txt', 'r') as f:
|
||||
queries = json.load(f)
|
||||
n = sum([int(n) for llm, n in queries.items()])
|
||||
socketio.emit('response', queries, namespace='/queryllm')
|
||||
socketio.sleep(0.1)
|
||||
if last_n != n:
|
||||
i = 0
|
||||
last_n = n
|
||||
else:
|
||||
bucketed_resp[tup_key] = [r]
|
||||
i += 0.1
|
||||
|
||||
# Perform reduce op across all bucketed responses, collecting them into a single 'meta'-response:
|
||||
ret = []
|
||||
for tup_key, resps in bucketed_resp.items():
|
||||
flat_eval_res = [item for r in resps for item in r['eval_res']['items']]
|
||||
ret.append({
|
||||
'vars': {v: r['vars'][v] for r in resps for v in include_vars},
|
||||
'llm': resps[0]['llm'],
|
||||
'prompt': [r['prompt'] for r in resps],
|
||||
'responses': [r['responses'] for r in resps],
|
||||
'tokens': resps[0]['tokens'],
|
||||
'eval_res': {
|
||||
'mean': mean(flat_eval_res),
|
||||
'median': median(flat_eval_res),
|
||||
'stdev': stdev(flat_eval_res) if len(flat_eval_res) > 1 else 0,
|
||||
'range': (min(flat_eval_res), max(flat_eval_res)),
|
||||
'items': flat_eval_res
|
||||
}
|
||||
})
|
||||
if i >= MAX_WAIT_TIME:
|
||||
print(f"Error: Waited maximum {MAX_WAIT_TIME} seconds for response count to update. Exited prematurely.")
|
||||
socketio.emit('finish', 'max_wait_reached', namespace='/queryllm')
|
||||
else:
|
||||
print("All responses loaded!")
|
||||
socketio.emit('finish', 'success', namespace='/queryllm')
|
||||
|
||||
@socketio.on('queryllm', namespace='/queryllm')
|
||||
def testSocket(data):
|
||||
readCounts(data['id'], data['max'])
|
||||
# countdown()
|
||||
# global thread
|
||||
# with thread_lock:
|
||||
# if thread is None:
|
||||
# thread = socketio.start_background_task(target=countdown)
|
||||
|
||||
def run_socketio_server(socketio, port):
|
||||
socketio.run(app, host="localhost", port=8001)
|
||||
# flask_server.run(host="localhost", port=8000, debug=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
return ret
|
||||
|
||||
@app.route('/api/test', methods=['GET'])
|
||||
def test():
|
||||
return "Hello, world!"
|
||||
|
||||
# @socketio.on('queryllm', namespace='/queryllm')
|
||||
# def handleQueryAsync(data):
|
||||
# print("reached handleQueryAsync")
|
||||
# socketio.start_background_task(queryLLM, emitLLMResponse)
|
||||
|
||||
# def emitLLMResponse(result):
|
||||
# socketio.emit('response', result)
|
||||
|
||||
"""
|
||||
Testing sockets. The following function can
|
||||
communicate to React via the JS code:
|
||||
|
||||
const socket = io(BASE_URL + 'queryllm', {
|
||||
transports: ["websocket"],
|
||||
cors: {
|
||||
origin: "http://localhost:3000/",
|
||||
},
|
||||
});
|
||||
|
||||
socket.on("connect", (data) => {
|
||||
socket.emit("queryllm", "hello");
|
||||
});
|
||||
socket.on("disconnect", (data) => {
|
||||
console.log("disconnected");
|
||||
});
|
||||
socket.on("response", (data) => {
|
||||
console.log(data);
|
||||
});
|
||||
"""
|
||||
# def background_thread():
|
||||
# n = 10
|
||||
# while n > 0:
|
||||
# socketio.sleep(0.5)
|
||||
# socketio.emit('response', n, namespace='/queryllm')
|
||||
# n -= 1
|
||||
|
||||
# @socketio.on('queryllm', namespace='/queryllm')
|
||||
# def testSocket(data):
|
||||
# print(data)
|
||||
# global thread
|
||||
# with thread_lock:
|
||||
# if thread is None:
|
||||
# thread = socketio.start_background_task(target=background_thread)
|
||||
|
||||
# @socketio.on('queryllm', namespace='/queryllm')
|
||||
# def handleQuery(data):
|
||||
# print(data)
|
||||
|
||||
# def handleConnect():
|
||||
# print('here')
|
||||
# socketio.emit('response', 'goodbye', namespace='/')
|
||||
|
||||
@app.route('/api/countQueriesRequired', methods=['POST'])
|
||||
def countQueries():
|
||||
"""
|
||||
Returns how many queries we need to make, given the passed prompt and vars.
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'prompt': str # the prompt template, with any {{}} vars
|
||||
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
|
||||
'llms': list # the list of LLMs you will query
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms'}):
|
||||
return jsonify({'error': 'POST data is improper format.'})
|
||||
|
||||
try:
|
||||
gen_prompts = PromptPermutationGenerator(PromptTemplate(data['prompt']))
|
||||
all_prompt_permutations = list(gen_prompts(data['vars']))
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)})
|
||||
|
||||
# TODO: Send more informative data back including how many queries per LLM based on cache'd data
|
||||
num_queries = {} # len(all_prompt_permutations) * len(data['llms'])
|
||||
for llm in data['llms']:
|
||||
num_queries[llm] = len(all_prompt_permutations)
|
||||
|
||||
ret = jsonify({'counts': num_queries})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
@app.route('/api/createProgressFile', methods=['POST'])
|
||||
def createProgressFile():
|
||||
"""
|
||||
Creates a temp txt file for storing progress of async LLM queries.
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'id': str # a unique ID that will be used when calling 'queryllm'
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
if 'id' not in data or not isinstance(data['id'], str) or len(data['id']) == 0:
|
||||
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
|
||||
|
||||
# Create a scratch file for keeping track of how many responses loaded
|
||||
try:
|
||||
with open(f"cache/_temp_{data['id']}.txt", 'w') as f:
|
||||
json.dump({}, f)
|
||||
ret = jsonify({'success': True})
|
||||
except Exception as e:
|
||||
ret = jsonify({'success': False, 'error': str(e)})
|
||||
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
# @socketio.on('connect', namespace='/queryllm')
|
||||
@app.route('/api/queryllm', methods=['POST'])
|
||||
async def queryLLM():
|
||||
"""
|
||||
Queries LLM(s) given a JSON spec.
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'id': str # a unique ID to refer to this information. Used when cache'ing responses.
|
||||
'llm': str | list # a string or list of strings specifying the LLM(s) to query
|
||||
'params': dict # an optional dict of any other params to set when querying the LLMs, like 'temperature', 'n' (num of responses per prompt), etc.
|
||||
'prompt': str # the prompt template, with any {{}} vars
|
||||
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
|
||||
'no_cache': bool (optional) # delete any cache'd responses for 'id' (always call the LLM fresh)
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
# Check that all required info is here:
|
||||
if not set(data.keys()).issuperset({'llm', 'prompt', 'vars', 'id'}):
|
||||
return jsonify({'error': 'POST data is improper format.'})
|
||||
elif not isinstance(data['id'], str) or len(data['id']) == 0:
|
||||
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
|
||||
|
||||
# Verify LLM name(s) (string or list) and convert to enum(s):
|
||||
if not (isinstance(data['llm'], list) or isinstance(data['llm'], str)) or (isinstance(data['llm'], list) and len(data['llm']) == 0):
|
||||
return jsonify({'error': 'POST data llm is improper format (not string or list, or of length 0).'})
|
||||
if isinstance(data['llm'], str):
|
||||
data['llm'] = [ data['llm'] ]
|
||||
|
||||
for llm_str in data['llm']:
|
||||
if llm_str not in LLM_NAME_MAP:
|
||||
return jsonify({'error': f"LLM named '{llm_str}' is not supported."})
|
||||
|
||||
if 'no_cache' in data and data['no_cache'] is True:
|
||||
remove_cached_responses(data['id'])
|
||||
|
||||
# Create a cache dir if it doesn't exist:
|
||||
create_dir_if_not_exists('cache')
|
||||
|
||||
# For each LLM, generate and cache responses:
|
||||
responses = {}
|
||||
llms = data['llm']
|
||||
params = data['params'] if 'params' in data else {}
|
||||
tempfilepath = f"cache/_temp_{data['id']}.txt"
|
||||
|
||||
async def query(llm_str: str) -> list:
|
||||
llm = LLM_NAME_MAP[llm_str]
|
||||
|
||||
# Check that storage path is valid:
|
||||
cache_filepath = os.path.join('cache', f"{data['id']}-{str(llm.name)}.json")
|
||||
if not is_valid_filepath(cache_filepath):
|
||||
return jsonify({'error': f'Invalid filepath: {cache_filepath}'})
|
||||
|
||||
# Create an object to query the LLM, passing a file for cache'ing responses
|
||||
prompter = PromptLLM(data['prompt'], storageFile=cache_filepath)
|
||||
|
||||
# Prompt the LLM with all permutations of the input prompt template:
|
||||
# NOTE: If the responses are already cache'd, this just loads them (no LLM is queried, saving $$$)
|
||||
resps = []
|
||||
try:
|
||||
print(f'Querying {llm}...')
|
||||
async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
|
||||
resps.append(response)
|
||||
print(f"collected response from {llm.name}:", str(response))
|
||||
|
||||
# Save the number of responses collected to a temp file on disk
|
||||
with open(tempfilepath, 'r') as f:
|
||||
txt = f.read().strip()
|
||||
|
||||
cur_data = json.loads(txt) if len(txt) > 0 else {}
|
||||
cur_data[llm_str] = len(resps)
|
||||
|
||||
with open(tempfilepath, 'w') as f:
|
||||
json.dump(cur_data, f)
|
||||
except Exception as e:
|
||||
print('error generating responses:', e)
|
||||
raise e
|
||||
|
||||
return {'llm': llm, 'responses': resps}
|
||||
|
||||
try:
|
||||
# Request responses simultaneously across LLMs
|
||||
tasks = [query(llm) for llm in llms]
|
||||
|
||||
# Await the responses from all queried LLMs
|
||||
llm_results = await asyncio.gather(*tasks)
|
||||
for item in llm_results:
|
||||
responses[item['llm']] = item['responses']
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)})
|
||||
|
||||
# Convert the responses into a more standardized format with less information
|
||||
res = [
|
||||
to_standard_format(r)
|
||||
for rs in responses.values()
|
||||
for r in rs
|
||||
]
|
||||
|
||||
# Return all responses for all LLMs
|
||||
print('returning responses:', res)
|
||||
ret = jsonify({'responses': res})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
@app.route('/api/execute', methods=['POST'])
|
||||
def execute():
|
||||
"""
|
||||
Executes a Python lambda function sent from JavaScript,
|
||||
over all cache'd responses with given id's.
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'id': # a unique ID to refer to this information. Used when cache'ing responses.
|
||||
'code': str, # the body of the lambda function to evaluate, in form: lambda responses: <body>
|
||||
'responses': str | List[str] # the responses to run on; a unique ID or list of unique IDs of cache'd data,
|
||||
'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch.
|
||||
# If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt.
|
||||
'reduce_vars': unspecified | List[str] # the 'vars' to average over (mean, median, stdev, range)
|
||||
'script_paths': unspecified | List[str] # the paths to scripts to be added to the path before the lambda function is evaluated
|
||||
}
|
||||
|
||||
NOTE: This should only be run on your server on code you trust.
|
||||
There is no sandboxing; no safety. We assume you are the creator of the code.
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
# Check that all required info is here:
|
||||
if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope'}):
|
||||
return jsonify({'error': 'POST data is improper format.'})
|
||||
if not isinstance(data['id'], str) or len(data['id']) == 0:
|
||||
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
|
||||
if data['scope'] not in ('response', 'batch'):
|
||||
return jsonify({'error': "POST data scope is unknown. Must be either 'response' or 'batch'."})
|
||||
|
||||
# Check that the filepath used to cache eval'd responses is valid:
|
||||
cache_filepath = os.path.join('cache', f"{data['id']}.json")
|
||||
if not is_valid_filepath(cache_filepath):
|
||||
return jsonify({'error': f'Invalid filepath: {cache_filepath}'})
|
||||
|
||||
# Check format of responses:
|
||||
if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
|
||||
return jsonify({'error': 'POST data responses is improper format.'})
|
||||
elif isinstance(data['responses'], str):
|
||||
data['responses'] = [ data['responses'] ]
|
||||
|
||||
# add the path to any scripts to the path:
|
||||
try:
|
||||
if 'script_paths' in data:
|
||||
for script_path in data['script_paths']:
|
||||
# get the folder of the script_path:
|
||||
script_folder = os.path.dirname(script_path)
|
||||
# check that the script_folder is valid, and it contains __init__.py
|
||||
if not os.path.exists(script_folder):
|
||||
print(script_folder, 'is not a valid script path.')
|
||||
print(os.path.exists(script_folder))
|
||||
continue
|
||||
|
||||
# add it to the path:
|
||||
sys.path.append(script_folder)
|
||||
print(f'added {script_folder} to sys.path')
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'})
|
||||
|
||||
# Create the evaluator function
|
||||
# DANGER DANGER!
|
||||
try:
|
||||
exec(data['code'], globals())
|
||||
|
||||
# Double-check that there is an 'evaluate' method in our namespace.
|
||||
# This will throw a NameError if not:
|
||||
evaluate
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})
|
||||
|
||||
# Load all responses with the given ID:
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
all_evald_responses = []
|
||||
for cache_id in data['responses']:
|
||||
cache_files = get_filenames_with_id(all_cache_files, cache_id)
|
||||
if len(cache_files) == 0:
|
||||
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
|
||||
|
||||
# To avoid loading all response files into memory at once, we'll run the evaluator on each file:
|
||||
for filename in cache_files:
|
||||
|
||||
# Load the raw responses from the cache
|
||||
responses = load_cache_json(os.path.join('cache', filename))
|
||||
if len(responses) == 0: continue
|
||||
|
||||
# Run the evaluator over them:
|
||||
# NOTE: 'evaluate' here was defined dynamically from 'exec' above.
|
||||
try:
|
||||
evald_responses = run_over_responses(evaluate, responses, scope=data['scope'])
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}'})
|
||||
|
||||
# Convert to standard format:
|
||||
std_evald_responses = [
|
||||
to_standard_format({'prompt': prompt, **res_obj})
|
||||
for prompt, res_obj in evald_responses.items()
|
||||
]
|
||||
|
||||
# Perform any reduction operations:
|
||||
if 'reduce_vars' in data and len(data['reduce_vars']) > 0:
|
||||
std_evald_responses = reduce_responses(
|
||||
std_evald_responses,
|
||||
vars=data['reduce_vars']
|
||||
)
|
||||
|
||||
all_evald_responses.extend(std_evald_responses)
|
||||
|
||||
# Store the evaluated responses in a new cache json:
|
||||
with open(cache_filepath, "w") as f:
|
||||
json.dump(all_evald_responses, f)
|
||||
|
||||
ret = jsonify({'responses': all_evald_responses})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
@app.route('/api/checkEvalFunc', methods=['POST'])
|
||||
def checkEvalFunc():
|
||||
"""
|
||||
Tries to compile a Python lambda function sent from JavaScript.
|
||||
Returns a dict with 'result':true if it compiles without raising an exception;
|
||||
'result':false (and an 'error' property with a message) if not.
|
||||
|
||||
POST'd data should be in form:
|
||||
{
|
||||
'code': str, # the body of the lambda function to evaluate, in form: lambda responses: <body>
|
||||
}
|
||||
|
||||
NOTE: This should only be run on your server on code you trust.
|
||||
There is no sandboxing; no safety. We assume you are the creator of the code.
|
||||
"""
|
||||
data = request.get_json()
|
||||
if 'code' not in data:
|
||||
return jsonify({'result': False, 'error': 'POST data is missing a "code" field.'})
|
||||
|
||||
# DANGER DANGER! Running exec on code passed through front-end. Make sure it's trusted!
|
||||
try:
|
||||
exec(data['code'], globals())
|
||||
|
||||
# Double-check that there is an 'evaluate' method in our namespace.
|
||||
# This will throw a NameError if not:
|
||||
evaluate
|
||||
return jsonify({'result': True})
|
||||
except Exception as e:
|
||||
return jsonify({'result': False, 'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})
|
||||
|
||||
@app.route('/api/grabResponses', methods=['POST'])
|
||||
def grabResponses():
|
||||
"""
|
||||
Returns all responses with the specified id(s)
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'responses': <the ids to grab>
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
# Check format of responses:
|
||||
if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
|
||||
return jsonify({'error': 'POST data responses is improper format.'})
|
||||
elif isinstance(data['responses'], str):
|
||||
data['responses'] = [ data['responses'] ]
|
||||
|
||||
# Load all responses with the given ID:
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
print(all_cache_files)
|
||||
responses = []
|
||||
for cache_id in data['responses']:
|
||||
cache_files = get_filenames_with_id(all_cache_files, cache_id)
|
||||
if len(cache_files) == 0:
|
||||
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
|
||||
|
||||
for filename in cache_files:
|
||||
res = load_cache_json(os.path.join('cache', filename))
|
||||
if isinstance(res, dict):
|
||||
# Convert to standard response format
|
||||
res = [
|
||||
to_standard_format({'prompt': prompt, **res_obj})
|
||||
for prompt, res_obj in res.items()
|
||||
]
|
||||
responses.extend(res)
|
||||
|
||||
print(responses)
|
||||
ret = jsonify({'responses': responses})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='This script spins up a Flask server that serves as the backend for ChainForge')

# Turn on to disable all outbound LLM API calls and replace them with dummy calls
@ -568,11 +88,12 @@ if __name__ == '__main__':
parser.add_argument('--port', help='The port to run the server on. Defaults to 8000.', type=int, default=8000, nargs='?')
args = parser.parse_args()

if args.dummy_responses:
PromptLLM = PromptLLMDummy
extract_responses = lambda r, llm: r['response']

port = args.port if args.port else 8000

# Spin up separate thread for socketio app, on port+1 (8001 default)
print(f"Serving SocketIO server on port {port+1}...")
t1 = threading.Thread(target=run_socketio_server, args=[socketio, port+1])
t1.start()

print(f"Serving Flask server on port {port}...")
# socketio.run(app, host="localhost", port=port)
app.run(host="localhost", port=8000, debug=True)
run_server(host="localhost", port=port, cmd_args=args)
python-backend/flask_app.py (new file, 575 lines)
@ -0,0 +1,575 @@
import json, os, asyncio, sys, argparse, threading
from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify, render_template, send_from_directory
from flask_cors import CORS
from flask_socketio import SocketIO
from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists

BUILD_DIR = "../chain-forge/build"
STATIC_DIR = BUILD_DIR + '/static'
app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)

# Set up CORS for specific routes
cors = CORS(app, resources={r"/*": {"origins": "*"}})

# Serve React app
@app.route("/")
def index():
return render_template("index.html")

# @app.route('/', defaults={'path': ''})
# @app.route('/<path:path>')
# def serve(path):
# if path != "" and os.path.exists(BUILD_DIR + '/' + path):
# return send_from_directory(BUILD_DIR, path)
# else:
# return send_from_directory(BUILD_DIR, 'index.html')

LLM_NAME_MAP = {
|
||||
'gpt3.5': LLM.ChatGPT,
|
||||
'alpaca.7B': LLM.Alpaca7B,
|
||||
'gpt4': LLM.GPT4,
|
||||
}
|
||||
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
|
||||
|
||||
@dataclass
|
||||
class ResponseInfo:
|
||||
"""Stores info about a single response. Passed to evaluator functions."""
|
||||
text: str
|
||||
prompt: str
|
||||
var: str
|
||||
llm: str
|
||||
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
def to_standard_format(r: dict) -> list:
|
||||
llm = LLM_NAME_MAP_INVERSE[r['llm']]
|
||||
resp_obj = {
|
||||
'vars': r['info'],
|
||||
'llm': llm,
|
||||
'prompt': r['prompt'],
|
||||
'responses': extract_responses(r, r['llm']),
|
||||
'tokens': r['response']['usage'] if 'usage' in r['response'] else {},
|
||||
}
|
||||
if 'eval_res' in r:
|
||||
resp_obj['eval_res'] = r['eval_res']
|
||||
return resp_obj
|
||||
|
||||
def get_llm_of_response(response: dict) -> LLM:
|
||||
return LLM_NAME_MAP[response['llm']]
|
||||
|
||||
def get_filenames_with_id(filenames: list, id: str) -> list:
|
||||
return [
|
||||
c for c in filenames
|
||||
if c.split('.')[0] == id or ('-' in c and c[:c.rfind('-')] == id)
|
||||
]
|
||||
|
||||
def remove_cached_responses(cache_id: str):
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
cache_files = get_filenames_with_id(all_cache_files, cache_id)
|
||||
for filename in cache_files:
|
||||
os.remove(os.path.join('cache/', filename))
|
||||
|
||||
def load_cache_json(filepath: str) -> dict:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
responses = json.load(f)
|
||||
return responses
|
||||
|
||||
def run_over_responses(eval_func, responses: dict, scope: str) -> list:
|
||||
for prompt, resp_obj in responses.items():
|
||||
res = extract_responses(resp_obj, resp_obj['llm'])
|
||||
if scope == 'response':
|
||||
evals = [ # Run evaluator func over every individual response text
|
||||
eval_func(
|
||||
ResponseInfo(
|
||||
text=r,
|
||||
prompt=prompt,
|
||||
var=resp_obj['info'],
|
||||
llm=LLM_NAME_MAP_INVERSE[resp_obj['llm']])
|
||||
) for r in res
|
||||
]
|
||||
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
|
||||
'mean': mean(evals),
|
||||
'median': median(evals),
|
||||
'stdev': stdev(evals) if len(evals) > 1 else 0,
|
||||
'range': (min(evals), max(evals)),
|
||||
'items': evals,
|
||||
}
|
||||
else: # operate over the entire response batch
|
||||
ev = eval_func(res)
|
||||
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
|
||||
'mean': ev,
|
||||
'median': ev,
|
||||
'stdev': 0,
|
||||
'range': (ev, ev),
|
||||
'items': [ev],
|
||||
}
|
||||
return responses
|
||||
|
||||
def reduce_responses(responses: list, vars: list) -> list:
|
||||
if len(responses) == 0: return responses
|
||||
|
||||
# Figure out what vars we still care about (the ones we aren't reducing over):
|
||||
# NOTE: We are assuming all responses have the same 'vars' keys.
|
||||
all_vars = set(responses[0]['vars'])
|
||||
|
||||
if not all_vars.issuperset(set(vars)):
|
||||
# There's a var in vars which isn't part of the response.
|
||||
raise Exception(f"Some vars in {set(vars)} are not in the responses.")
|
||||
|
||||
# Get just the vars we want to keep around:
|
||||
include_vars = list(set(responses[0]['vars']) - set(vars))
|
||||
|
||||
# Bucket responses by the remaining var values, where tuples of vars are keys to a dict:
|
||||
# E.g. {(var1_val, var2_val): [responses] }
|
||||
bucketed_resp = {}
|
||||
for r in responses:
|
||||
print(r)
|
||||
tup_key = tuple([r['vars'][v] for v in include_vars])
|
||||
if tup_key in bucketed_resp:
|
||||
bucketed_resp[tup_key].append(r)
|
||||
else:
|
||||
bucketed_resp[tup_key] = [r]
|
||||
|
||||
# Perform reduce op across all bucketed responses, collecting them into a single 'meta'-response:
|
||||
ret = []
|
||||
for tup_key, resps in bucketed_resp.items():
|
||||
flat_eval_res = [item for r in resps for item in r['eval_res']['items']]
|
||||
ret.append({
|
||||
'vars': {v: r['vars'][v] for r in resps for v in include_vars},
|
||||
'llm': resps[0]['llm'],
|
||||
'prompt': [r['prompt'] for r in resps],
|
||||
'responses': [r['responses'] for r in resps],
|
||||
'tokens': resps[0]['tokens'],
|
||||
'eval_res': {
|
||||
'mean': mean(flat_eval_res),
|
||||
'median': median(flat_eval_res),
|
||||
'stdev': stdev(flat_eval_res) if len(flat_eval_res) > 1 else 0,
|
||||
'range': (min(flat_eval_res), max(flat_eval_res)),
|
||||
'items': flat_eval_res
|
||||
}
|
||||
})
|
||||
|
||||
return ret
|
||||
|
||||
@app.route('/app/test', methods=['GET'])
|
||||
def test():
|
||||
return "Hello, world!"
|
||||
|
||||
# @socketio.on('queryllm', namespace='/queryllm')
|
||||
# def handleQueryAsync(data):
|
||||
# print("reached handleQueryAsync")
|
||||
# socketio.start_background_task(queryLLM, emitLLMResponse)
|
||||
|
||||
# def emitLLMResponse(result):
|
||||
# socketio.emit('response', result)
|
||||
|
||||
"""
|
||||
Testing sockets. The following function can
|
||||
communicate to React via the JS code:
|
||||
|
||||
const socket = io(BASE_URL + 'queryllm', {
|
||||
transports: ["websocket"],
|
||||
cors: {
|
||||
origin: "http://localhost:3000/",
|
||||
},
|
||||
});
|
||||
|
||||
socket.on("connect", (data) => {
|
||||
socket.emit("queryllm", "hello");
|
||||
});
|
||||
socket.on("disconnect", (data) => {
|
||||
console.log("disconnected");
|
||||
});
|
||||
socket.on("response", (data) => {
|
||||
console.log(data);
|
||||
});
|
||||
"""
|
||||
# def background_thread():
|
||||
# n = 10
|
||||
# while n > 0:
|
||||
# socketio.sleep(0.5)
|
||||
# socketio.emit('response', n, namespace='/queryllm')
|
||||
# n -= 1
|
||||
|
||||
# @socketio.on('queryllm', namespace='/queryllm')
|
||||
# def testSocket(data):
|
||||
# print(data)
|
||||
# global thread
|
||||
# with thread_lock:
|
||||
# if thread is None:
|
||||
# thread = socketio.start_background_task(target=background_thread)
|
||||
|
||||
# @socketio.on('queryllm', namespace='/queryllm')
|
||||
# def handleQuery(data):
|
||||
# print(data)
|
||||
|
||||
# def handleConnect():
|
||||
# print('here')
|
||||
# socketio.emit('response', 'goodbye', namespace='/')
|
||||
|
||||
@app.route('/app/countQueriesRequired', methods=['POST'])
|
||||
def countQueries():
|
||||
"""
|
||||
Returns how many queries we need to make, given the passed prompt and vars.
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'prompt': str # the prompt template, with any {{}} vars
|
||||
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
|
||||
'llms': list # the list of LLMs you will query
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms'}):
|
||||
return jsonify({'error': 'POST data is improper format.'})
|
||||
|
||||
try:
|
||||
gen_prompts = PromptPermutationGenerator(PromptTemplate(data['prompt']))
|
||||
all_prompt_permutations = list(gen_prompts(data['vars']))
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)})
|
||||
|
||||
# TODO: Send more informative data back including how many queries per LLM based on cache'd data
|
||||
num_queries = {} # len(all_prompt_permutations) * len(data['llms'])
|
||||
for llm in data['llms']:
|
||||
num_queries[llm] = len(all_prompt_permutations)
|
||||
|
||||
ret = jsonify({'counts': num_queries})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
@app.route('/app/createProgressFile', methods=['POST'])
|
||||
def createProgressFile():
|
||||
"""
|
||||
Creates a temp txt file for storing progress of async LLM queries.
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'id': str # a unique ID that will be used when calling 'queryllm'
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
if 'id' not in data or not isinstance(data['id'], str) or len(data['id']) == 0:
|
||||
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
|
||||
|
||||
# Create a scratch file for keeping track of how many responses loaded
|
||||
try:
|
||||
with open(f"cache/_temp_{data['id']}.txt", 'w') as f:
|
||||
json.dump({}, f)
|
||||
ret = jsonify({'success': True})
|
||||
except Exception as e:
|
||||
ret = jsonify({'success': False, 'error': str(e)})
|
||||
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
# @socketio.on('connect', namespace='/queryllm')
|
||||
@app.route('/app/queryllm', methods=['POST'])
|
||||
async def queryLLM():
|
||||
"""
|
||||
Queries LLM(s) given a JSON spec.
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'id': str # a unique ID to refer to this information. Used when cache'ing responses.
|
||||
'llm': str | list # a string or list of strings specifying the LLM(s) to query
|
||||
'params': dict # an optional dict of any other params to set when querying the LLMs, like 'temperature', 'n' (num of responses per prompt), etc.
|
||||
'prompt': str # the prompt template, with any {{}} vars
|
||||
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
|
||||
'no_cache': bool (optional) # delete any cache'd responses for 'id' (always call the LLM fresh)
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
# Check that all required info is here:
|
||||
if not set(data.keys()).issuperset({'llm', 'prompt', 'vars', 'id'}):
|
||||
return jsonify({'error': 'POST data is improper format.'})
|
||||
elif not isinstance(data['id'], str) or len(data['id']) == 0:
|
||||
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
|
||||
|
||||
# Verify LLM name(s) (string or list) and convert to enum(s):
|
||||
if not (isinstance(data['llm'], list) or isinstance(data['llm'], str)) or (isinstance(data['llm'], list) and len(data['llm']) == 0):
|
||||
return jsonify({'error': 'POST data llm is improper format (not string or list, or of length 0).'})
|
||||
if isinstance(data['llm'], str):
|
||||
data['llm'] = [ data['llm'] ]
|
||||
|
||||
for llm_str in data['llm']:
|
||||
if llm_str not in LLM_NAME_MAP:
|
||||
return jsonify({'error': f"LLM named '{llm_str}' is not supported."})
|
||||
|
||||
if 'no_cache' in data and data['no_cache'] is True:
|
||||
remove_cached_responses(data['id'])
|
||||
|
||||
# Create a cache dir if it doesn't exist:
|
||||
create_dir_if_not_exists('cache')
|
||||
|
||||
# For each LLM, generate and cache responses:
|
||||
responses = {}
|
||||
llms = data['llm']
|
||||
params = data['params'] if 'params' in data else {}
|
||||
tempfilepath = f"cache/_temp_{data['id']}.txt"
|
||||
|
||||
async def query(llm_str: str) -> list:
|
||||
llm = LLM_NAME_MAP[llm_str]
|
||||
|
||||
# Check that storage path is valid:
|
||||
cache_filepath = os.path.join('cache', f"{data['id']}-{str(llm.name)}.json")
|
||||
if not is_valid_filepath(cache_filepath):
|
||||
return jsonify({'error': f'Invalid filepath: {cache_filepath}'})
|
||||
|
||||
# Create an object to query the LLM, passing a file for cache'ing responses
|
||||
prompter = PromptLLM(data['prompt'], storageFile=cache_filepath)
|
||||
|
||||
# Prompt the LLM with all permutations of the input prompt template:
|
||||
# NOTE: If the responses are already cache'd, this just loads them (no LLM is queried, saving $$$)
|
||||
resps = []
|
||||
try:
|
||||
print(f'Querying {llm}...')
|
||||
async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
|
||||
resps.append(response)
|
||||
print(f"collected response from {llm.name}:", str(response))
|
||||
|
||||
# Save the number of responses collected to a temp file on disk
|
||||
with open(tempfilepath, 'r') as f:
|
||||
txt = f.read().strip()
|
||||
|
||||
cur_data = json.loads(txt) if len(txt) > 0 else {}
|
||||
cur_data[llm_str] = len(resps)
|
||||
|
||||
with open(tempfilepath, 'w') as f:
|
||||
json.dump(cur_data, f)
|
||||
except Exception as e:
|
||||
print('error generating responses:', e)
|
||||
raise e
|
||||
|
||||
return {'llm': llm, 'responses': resps}
|
||||
|
||||
try:
|
||||
# Request responses simultaneously across LLMs
|
||||
tasks = [query(llm) for llm in llms]
|
||||
|
||||
# Await the responses from all queried LLMs
|
||||
llm_results = await asyncio.gather(*tasks)
|
||||
for item in llm_results:
|
||||
responses[item['llm']] = item['responses']
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)})
|
||||
|
||||
# Convert the responses into a more standardized format with less information
|
||||
res = [
|
||||
to_standard_format(r)
|
||||
for rs in responses.values()
|
||||
for r in rs
|
||||
]
|
||||
|
||||
# Return all responses for all LLMs
|
||||
print('returning responses:', res)
|
||||
ret = jsonify({'responses': res})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
@app.route('/app/execute', methods=['POST'])
|
||||
def execute():
|
||||
"""
|
||||
Executes a Python lambda function sent from JavaScript,
|
||||
over all cache'd responses with given id's.
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'id': # a unique ID to refer to this information. Used when cache'ing responses.
|
||||
'code': str, # the body of the lambda function to evaluate, in form: lambda responses: <body>
|
||||
'responses': str | List[str] # the responses to run on; a unique ID or list of unique IDs of cache'd data,
|
||||
'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch.
|
||||
# If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt.
|
||||
'reduce_vars': unspecified | List[str] # the 'vars' to average over (mean, median, stdev, range)
|
||||
'script_paths': unspecified | List[str] # the paths to scripts to be added to the path before the lambda function is evaluated
|
||||
}
|
||||
|
||||
NOTE: This should only be run on your server on code you trust.
|
||||
There is no sandboxing; no safety. We assume you are the creator of the code.
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
# Check that all required info is here:
|
||||
if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope'}):
|
||||
return jsonify({'error': 'POST data is improper format.'})
|
||||
if not isinstance(data['id'], str) or len(data['id']) == 0:
|
||||
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
|
||||
if data['scope'] not in ('response', 'batch'):
|
||||
return jsonify({'error': "POST data scope is unknown. Must be either 'response' or 'batch'."})
|
||||
|
||||
# Check that the filepath used to cache eval'd responses is valid:
|
||||
cache_filepath = os.path.join('cache', f"{data['id']}.json")
|
||||
if not is_valid_filepath(cache_filepath):
|
||||
return jsonify({'error': f'Invalid filepath: {cache_filepath}'})
|
||||
|
||||
# Check format of responses:
|
||||
if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
|
||||
return jsonify({'error': 'POST data responses is improper format.'})
|
||||
elif isinstance(data['responses'], str):
|
||||
data['responses'] = [ data['responses'] ]
|
||||
|
||||
# add the path to any scripts to the path:
|
||||
try:
|
||||
if 'script_paths' in data:
|
||||
for script_path in data['script_paths']:
|
||||
# get the folder of the script_path:
|
||||
script_folder = os.path.dirname(script_path)
|
||||
# check that the script_folder is valid, and it contains __init__.py
|
||||
if not os.path.exists(script_folder):
|
||||
print(script_folder, 'is not a valid script path.')
|
||||
print(os.path.exists(script_folder))
|
||||
continue
|
||||
|
||||
# add it to the path:
|
||||
sys.path.append(script_folder)
|
||||
print(f'added {script_folder} to sys.path')
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'})
|
||||
|
||||
# Create the evaluator function
|
||||
# DANGER DANGER!
|
||||
try:
|
||||
exec(data['code'], globals())
|
||||
|
||||
# Double-check that there is an 'evaluate' method in our namespace.
|
||||
# This will throw a NameError if not:
|
||||
evaluate
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})
|
||||
|
||||
# Load all responses with the given ID:
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
all_evald_responses = []
|
||||
for cache_id in data['responses']:
|
||||
cache_files = get_filenames_with_id(all_cache_files, cache_id)
|
||||
if len(cache_files) == 0:
|
||||
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
|
||||
|
||||
# To avoid loading all response files into memory at once, we'll run the evaluator on each file:
|
||||
for filename in cache_files:
|
||||
|
||||
# Load the raw responses from the cache
|
||||
responses = load_cache_json(os.path.join('cache', filename))
|
||||
if len(responses) == 0: continue
|
||||
|
||||
# Run the evaluator over them:
|
||||
# NOTE: 'evaluate' here was defined dynamically from 'exec' above.
|
||||
try:
|
||||
evald_responses = run_over_responses(evaluate, responses, scope=data['scope'])
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}'})
|
||||
|
||||
# Convert to standard format:
|
||||
std_evald_responses = [
|
||||
to_standard_format({'prompt': prompt, **res_obj})
|
||||
for prompt, res_obj in evald_responses.items()
|
||||
]
|
||||
|
||||
# Perform any reduction operations:
|
||||
if 'reduce_vars' in data and len(data['reduce_vars']) > 0:
|
||||
std_evald_responses = reduce_responses(
|
||||
std_evald_responses,
|
||||
vars=data['reduce_vars']
|
||||
)
|
||||
|
||||
all_evald_responses.extend(std_evald_responses)
|
||||
|
||||
# Store the evaluated responses in a new cache json:
|
||||
with open(cache_filepath, "w") as f:
|
||||
json.dump(all_evald_responses, f)
|
||||
|
||||
ret = jsonify({'responses': all_evald_responses})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
@app.route('/app/checkEvalFunc', methods=['POST'])
|
||||
def checkEvalFunc():
|
||||
"""
|
||||
Tries to compile a Python lambda function sent from JavaScript.
|
||||
Returns a dict with 'result':true if it compiles without raising an exception;
|
||||
'result':false (and an 'error' property with a message) if not.
|
||||
|
||||
POST'd data should be in form:
|
||||
{
|
||||
'code': str, # the body of the lambda function to evaluate, in form: lambda responses: <body>
|
||||
}
|
||||
|
||||
NOTE: This should only be run on your server on code you trust.
|
||||
There is no sandboxing; no safety. We assume you are the creator of the code.
|
||||
"""
|
||||
data = request.get_json()
|
||||
if 'code' not in data:
|
||||
return jsonify({'result': False, 'error': 'POST data is missing a "code" field.'})
|
||||
|
||||
# DANGER DANGER! Running exec on code passed through front-end. Make sure it's trusted!
|
||||
try:
|
||||
exec(data['code'], globals())
|
||||
|
||||
# Double-check that there is an 'evaluate' method in our namespace.
|
||||
# This will throw a NameError if not:
|
||||
evaluate
|
||||
return jsonify({'result': True})
|
||||
except Exception as e:
|
||||
return jsonify({'result': False, 'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})
|
||||
|
||||
@app.route('/app/grabResponses', methods=['POST'])
|
||||
def grabResponses():
|
||||
"""
|
||||
Returns all responses with the specified id(s)
|
||||
|
||||
POST'd data should be in the form:
|
||||
{
|
||||
'responses': <the ids to grab>
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
# Check format of responses:
|
||||
if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
|
||||
return jsonify({'error': 'POST data responses is improper format.'})
|
||||
elif isinstance(data['responses'], str):
|
||||
data['responses'] = [ data['responses'] ]
|
||||
|
||||
# Load all responses with the given ID:
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
print(all_cache_files)
|
||||
responses = []
|
||||
for cache_id in data['responses']:
|
||||
cache_files = get_filenames_with_id(all_cache_files, cache_id)
|
||||
if len(cache_files) == 0:
|
||||
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
|
||||
|
||||
for filename in cache_files:
|
||||
res = load_cache_json(os.path.join('cache', filename))
|
||||
if isinstance(res, dict):
|
||||
# Convert to standard response format
|
||||
res = [
|
||||
to_standard_format({'prompt': prompt, **res_obj})
|
||||
for prompt, res_obj in res.items()
|
||||
]
|
||||
responses.extend(res)
|
||||
|
||||
print(responses)
|
||||
ret = jsonify({'responses': responses})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
def run_server(host="", port=8000, cmd_args=None):
if cmd_args is not None and cmd_args.dummy_responses:
global PromptLLM
global extract_responses
PromptLLM = PromptLLMDummy
extract_responses = lambda r, llm: r['response']

app.run(host=host, port=port)

if __name__ == '__main__':
print("Run app.py instead.")
@ -4,4 +4,5 @@ flask_socketio
openai
python-socketio
dalaipy==2.0.2
gevent-websocket
gevent-websocket
werkzeug
python-backend/server.py (new file, 3 lines)
@ -0,0 +1,3 @@
import subprocess

subprocess.run("python socketio_app.py & python app.py & wait", shell=True)
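The launcher above relies on a POSIX shell ('&' backgrounds each process, 'wait' blocks until both exit), so it assumes a Unix-like environment. A minimal portable sketch of the same idea, not part of this commit:

import subprocess, sys

# Launch both backend scripts with the current interpreter and wait for them to exit.
procs = [subprocess.Popen([sys.executable, script])
         for script in ("socketio_app.py", "app.py")]
for p in procs:
    p.wait()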
@ -1,67 +0,0 @@
import json, os, asyncio, sys, argparse
from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_socketio import SocketIO
from promptengine.query import PromptLLM, PromptLLMDummy
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists

app = Flask(__name__)

# Set up CORS for specific routes
# cors = CORS(app, resources={r"/api/*": {"origins": "*"}})

# Initialize Socket.IO
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="gevent")

# Wait a max of a full minute (60 seconds) for the response count to update, before exiting.
MAX_WAIT_TIME = 60

# import threading
# thread = None
# thread_lock = threading.Lock()

def countdown():
n = 10
while n > 0:
socketio.sleep(0.5)
socketio.emit('response', n, namespace='/queryllm')
n -= 1

def readCounts(id, max_count):
i = 0
n = 0
last_n = 0
while i < MAX_WAIT_TIME and n < max_count:
with open(f'cache/_temp_{id}.txt', 'r') as f:
queries = json.load(f)
n = sum([int(n) for llm, n in queries.items()])
print(n)
socketio.emit('response', queries, namespace='/queryllm')
socketio.sleep(0.1)
if last_n != n:
i = 0
last_n = n
else:
i += 0.1

if i >= MAX_WAIT_TIME:
print(f"Error: Waited maximum {MAX_WAIT_TIME} seconds for response count to update. Exited prematurely.")
socketio.emit('finish', 'max_wait_reached', namespace='/queryllm')
else:
print("All responses loaded!")
socketio.emit('finish', 'success', namespace='/queryllm')

@socketio.on('queryllm', namespace='/queryllm')
def testSocket(data):
readCounts(data['id'], data['max'])
# countdown()
# global thread
# with thread_lock:
# if thread is None:
# thread = socketio.start_background_task(target=countdown)

if __name__ == "__main__":
socketio.run(app, host="localhost", port=8001)