from flask import Flask, request, jsonify
from flask_cors import CORS
from promptengine.query import PromptLLM
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir
import json, os

app = Flask(__name__)
CORS(app)

LLM_NAME_MAP = {
    'gpt3.5': LLM.ChatGPT,
}
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}

def to_standard_format(r: dict) -> dict:
    llm = LLM_NAME_MAP_INVERSE[r['llm']]
    resp_obj = {
        'vars': r['info'],
        'llm': llm,
        'prompt': r['prompt'],
        'responses': extract_responses(r, r['llm']),
        'tokens': r['response']['usage'] if 'usage' in r['response'] else {},
    }
    if 'eval_res' in r:
        resp_obj['eval_res'] = r['eval_res']
    return resp_obj
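# For reference, a standardized response object produced above looks roughly like this
# (an illustrative sketch; the actual values depend on the LLM and cached responses):
#
#   {
#       'vars': {'country': 'France'},
#       'llm': 'gpt3.5',
#       'prompt': 'What is the capital of France?',
#       'responses': ['Paris.'],
#       'tokens': {'prompt_tokens': 8, 'completion_tokens': 2, 'total_tokens': 10},
#       'eval_res': [True],  # only present if an evaluator has been run
#   }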
def get_llm_of_response(response: dict) -> LLM:
    return LLM_NAME_MAP[response['llm']]

def get_filenames_with_id(filenames: list, id: str) -> list:
    return [
        c for c in filenames
        if ('-' in c and c.split('-')[0] == id) or ('-' not in c and c.split('.')[0] == id)
    ]

def load_cache_json(filepath: str) -> dict:
    with open(filepath, encoding="utf-8") as f:
        responses = json.load(f)
    return responses

def run_over_responses(eval_func, responses: dict) -> dict:
    for prompt, resp_obj in responses.items():
        res = extract_responses(resp_obj, resp_obj['llm'])
        resp_obj['eval_res'] = [eval_func(r) for r in res]  # run evaluator func over every individual response text
    return responses

@app.route('/queryllm', methods=['POST'])
def queryLLM():
    """
    Queries LLM(s) given a JSON spec.

    POST'd data should be in the form:
    {
        'id': str  # a unique ID to refer to this information. Used when cache'ing responses.
        'llm': str | list  # a string or list of strings specifying the LLM(s) to query
        'params': dict  # an optional dict of any other params to set when querying the LLMs, like 'temperature', 'n' (num of responses per prompt), etc.
        'prompt': str  # the prompt template, with any {{}} vars
        'vars': dict  # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
    }
    """
    data = request.get_json()

    # Check that all required info is here:
    if not set(data.keys()).issuperset({'llm', 'prompt', 'vars', 'id'}):
        return jsonify({'error': 'POST data is improper format.'})
    elif not isinstance(data['id'], str) or len(data['id']) == 0:
        return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})

    # Verify LLM name(s) (string or list) and convert to enum(s):
    if not (isinstance(data['llm'], list) or isinstance(data['llm'], str)) or (isinstance(data['llm'], list) and len(data['llm']) == 0):
        return jsonify({'error': 'POST data llm is improper format (not string or list, or of length 0).'})
    if isinstance(data['llm'], str):
        data['llm'] = [ data['llm'] ]
    llms = []
    for llm_str in data['llm']:
        if llm_str not in LLM_NAME_MAP:
            return jsonify({'error': f"LLM named '{llm_str}' is not supported."})
        llms.append(LLM_NAME_MAP[llm_str])

    # For each LLM, generate and cache responses:
    responses = {}
    params = data['params'] if 'params' in data else {}
    for llm in llms:

        # Check that storage path is valid:
        cache_filepath = os.path.join('cache', f"{data['id']}-{str(llm.name)}.json")
        if not is_valid_filepath(cache_filepath):
            return jsonify({'error': f'Invalid filepath: {cache_filepath}'})

        # Create an object to query the LLM, passing a file for cache'ing responses
        prompter = PromptLLM(data['prompt'], storageFile=cache_filepath)

        # Prompt the LLM with all permutations of the input prompt template:
        # NOTE: If the responses are already cache'd, this just loads them (no LLM is queried, saving $$$)
        responses[llm] = []
        try:
            for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
                responses[llm].append(response)
        except Exception as e:
            return jsonify({'error': str(e)})

    # Convert the responses into a more standardized format with less information
    res = [
        to_standard_format(r)
        for rs in responses.values()
        for r in rs
    ]

    # Return all responses for all LLMs
    ret = jsonify({'responses': res})
    ret.headers.add('Access-Control-Allow-Origin', '*')
    return ret

@app.route('/execute', methods=['POST'])
def execute():
    """
    Executes a Python lambda function sent from JavaScript,
    over all cache'd responses with given id's.

    POST'd data should be in the form:
    {
        'id': str  # a unique ID to refer to this information. Used when cache'ing responses.
        'code': str  # the body of the lambda function to evaluate, in form: lambda response: <body>
        'responses': str | List[str]  # the responses to run on; a unique ID or list of unique IDs of cache'd data
    }

    NOTE: This should only be run on your server on code you trust.
          There is no sandboxing; no safety. We assume you are the creator of the code.
    """
    data = request.get_json()

    # Check that all required info is here:
    if not set(data.keys()).issuperset({'id', 'code', 'responses'}):
        return jsonify({'error': 'POST data is improper format.'})
    if not isinstance(data['id'], str) or len(data['id']) == 0:
        return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})

    # Check that the filepath used to cache eval'd responses is valid:
    cache_filepath = os.path.join('cache', f"{data['id']}.json")
    if not is_valid_filepath(cache_filepath):
        return jsonify({'error': f'Invalid filepath: {cache_filepath}'})

    # Check format of responses:
    if not (isinstance(data['responses'], str) or isinstance(data['responses'], list)):
        return jsonify({'error': 'POST data responses is improper format.'})
    elif isinstance(data['responses'], str):
        data['responses'] = [ data['responses'] ]

    # Create the evaluator function
    # DANGER DANGER!
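    # The POST'd code is exec'd as the body of a module-level `evaluator` function.
    # For instance (an illustrative sketch), if data['code'] were "return len(response) > 0",
    # the generated source would be roughly:
    #
    #   def evaluator(response):
    #       return len(response) > 0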
    try:
        exec('def evaluator(response):\n\t' + '\n\t'.join(data['code'].split('\n')), globals())
    except Exception as e:
        return jsonify({'error': f'Could not evaluate code. Error message:\n{str(e)}'})

    # Load all responses with the given ID:
    all_cache_files = get_files_at_dir('cache/')
    all_evald_responses = []
    for cache_id in data['responses']:
        cache_files = get_filenames_with_id(all_cache_files, cache_id)
        if len(cache_files) == 0:
            return jsonify({'error': f'Did not find cache file for id {cache_id}'})

        # To avoid loading all response files into memory at once, we'll run the evaluator on each file:
        for filename in cache_files:

            # Load the responses from the cache
            responses = load_cache_json(os.path.join('cache', filename))
            if len(responses) == 0: continue

            # Run the evaluator over them:
            evald_responses = run_over_responses(evaluator, responses)

            all_evald_responses.extend([
                to_standard_format({'prompt': prompt, **res_obj})
                for prompt, res_obj in evald_responses.items()
            ])

    # Store the evaluated responses in a new cache json:
    with open(cache_filepath, "w") as f:
        json.dump(all_evald_responses, f)

    ret = jsonify({'responses': all_evald_responses})
    ret.headers.add('Access-Control-Allow-Origin', '*')
    return ret

if __name__ == '__main__':
    app.run(host="localhost", port=5000, debug=True)
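
# Example client call (a sketch only; assumes the server is running locally, the
# `requests` package is installed, and the payload values are hypothetical):
#
#   import requests
#   r = requests.post('http://localhost:5000/queryllm', json={
#       'id': 'example-node-1',
#       'llm': 'gpt3.5',
#       'prompt': 'What is the capital of France?',
#       'vars': {},
#   })
#   print(r.json()['responses'])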