ChainForge/chainforge/flask_app.py
Latest commit beeffd0ebb by ianarawjo (2023-09-25): Better edges, ReactFlow v11, AlephAlpha integration, better .gitignores (#139)

import json, os, sys, asyncio, time
from dataclasses import dataclass
from enum import Enum
from typing import List
from statistics import mean, median, stdev
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from chainforge.providers.dalai import call_dalai
from chainforge.providers import ProviderRegistry
import requests as py_requests
""" =================
SETUP AND GLOBALS
=================
"""
# Setup Flask app to serve static version of React front-end
HOSTNAME = "localhost"
PORT = 8000
# SESSION_TOKEN = secrets.token_hex(32)
BUILD_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'react-server', 'build')
STATIC_DIR = os.path.join(BUILD_DIR, 'static')
app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
# Set up CORS for specific routes
cors = CORS(app, resources={r"/*": {"origins": "*"}})
# The cache and examples files base directories
CACHE_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cache')
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'examples')
class MetricType(Enum):
KeyValue = 0
KeyValue_Numeric = 1
KeyValue_Categorical = 2
KeyValue_Mixed = 3
Numeric = 4
Categorical = 5
Mixed = 6
Unknown = 7
Empty = 8
""" ==============
UTIL FUNCTIONS
==============
"""
HIJACKED_PRINT_LOG_FILE = None
ORIGINAL_PRINT_METHOD = None

def HIJACK_PYTHON_PRINT() -> None:
    # Hijacks Python's print function, so that we can log
    # the outputs when the evaluator is run:
    import builtins
    import tempfile
    global HIJACKED_PRINT_LOG_FILE, ORIGINAL_PRINT_METHOD

    # Create a temporary file for logging and keep it open
    HIJACKED_PRINT_LOG_FILE = tempfile.NamedTemporaryFile(mode='a+', delete=False)

    # Create a wrapper over the original print method, and save the original print
    ORIGINAL_PRINT_METHOD = print
    def hijacked_print(*args, **kwargs):
        if 'file' in kwargs:
            # We don't want to override any library that's using print to a file.
            ORIGINAL_PRINT_METHOD(*args, **kwargs)
        else:
            ORIGINAL_PRINT_METHOD(*args, **kwargs, file=HIJACKED_PRINT_LOG_FILE)

    # Replace the original print function with the custom print function
    builtins.print = hijacked_print

def REVERT_PYTHON_PRINT() -> List[str]:
    # Reverts back to the original Python print method.
    # NOTE: Call this after hijack, and make sure you've caught all exceptions!
    import builtins
    global ORIGINAL_PRINT_METHOD, HIJACKED_PRINT_LOG_FILE

    logs = []
    if HIJACKED_PRINT_LOG_FILE is not None:
        # Read the log file
        HIJACKED_PRINT_LOG_FILE.seek(0)
        logs = HIJACKED_PRINT_LOG_FILE.read().split('\n')

    if ORIGINAL_PRINT_METHOD is not None:
        builtins.print = ORIGINAL_PRINT_METHOD

    HIJACKED_PRINT_LOG_FILE.close()
    HIJACKED_PRINT_LOG_FILE = None

    if len(logs) == 1 and len(logs[0].strip()) == 0:
        logs = []

    return logs

@dataclass
class ResponseInfo:
    """Stores info about a single LLM response. Passed to evaluator functions."""
    text: str  # The text of the LLM response
    prompt: str  # The text of the prompt used to query the LLM
    var: dict  # A dictionary of arguments that filled in the prompt template used to generate the final prompt
    meta: dict  # A dictionary of metadata ('metavars') that is 'carried alongside' data used to generate the prompt
    llm: str  # The name of the LLM queried (the nickname in ChainForge)

    def __str__(self):
        return self.text

    def asMarkdownAST(self):
        import mistune
        md_ast_parser = mistune.create_markdown(renderer='ast')
        return md_ast_parser(self.text)
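
# Illustrative note: evaluator functions passed to /app/executepy receive ResponseInfo objects
# and can read, e.g., `response.text`, `response.var`, `response.meta`, and `response.llm`
# to compute a metric. A hypothetical (not built-in) evaluator:
#   def evaluate(response):
#       return len(response.text)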

def check_typeof_vals(arr: list) -> MetricType:
    if len(arr) == 0: return MetricType.Empty

    def typeof_set(types: set) -> MetricType:
        if len(types) == 0: return MetricType.Empty
        if len(types) == 1 and next(iter(types)) == dict:
            return MetricType.KeyValue
        elif all((t in (int, float) for t in types)):
            # Numeric metrics only
            return MetricType.Numeric
        elif all((t in (str, bool) for t in types)):
            # Categorical metrics only ('bool' is True/False, counts as categorical)
            return MetricType.Categorical
        elif all((t in (int, float, bool, str) for t in types)):
            # Mix of numeric and categorical types
            return MetricType.Mixed
        else:
            # Mix of types beyond basic ones
            return MetricType.Unknown

    def typeof_dict_vals(d):
        dict_val_type = typeof_set(set((type(v) for v in d.values())))
        if dict_val_type == MetricType.Numeric:
            return MetricType.KeyValue_Numeric
        elif dict_val_type == MetricType.Categorical:
            return MetricType.KeyValue_Categorical
        else:
            return MetricType.KeyValue_Mixed

    # Checks type of all values in 'arr' and returns the type
    val_type = typeof_set(set((type(v) for v in arr)))
    if val_type == MetricType.KeyValue:
        # This is a 'KeyValue' pair type. We need to find the more specific type of the values in the dict.
        # First, we check that all dicts have the exact same keys:
        for i in range(len(arr)-1):
            d, e = arr[i], arr[i+1]
            if set(d.keys()) != set(e.keys()):
                raise Exception('The keys and size of dicts for evaluation results must be consistent across evaluations.')

        # Then, we check the consistency of the type of dict values:
        first_dict_val_type = typeof_dict_vals(arr[0])
        for d in arr[1:]:
            if first_dict_val_type != typeof_dict_vals(d):
                raise Exception('Types of values in dicts for evaluation results must be consistent across responses.')

        # If we're here, all checks passed, and we return the more specific KeyValue type:
        return first_dict_val_type
    else:
        return val_type
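
# Illustrative examples of the mapping above (derived from the checks in check_typeof_vals):
#   check_typeof_vals([0.5, 2])                    -> MetricType.Numeric
#   check_typeof_vals([True, 'cat'])               -> MetricType.Categorical
#   check_typeof_vals([{'f1': 0.7}, {'f1': 0.9}])  -> MetricType.KeyValue_Numeric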

def run_over_responses(eval_func, responses: list, scope: str) -> list:
    for resp_obj in responses:
        res = resp_obj['responses']
        if scope == 'response':
            # Run evaluator func over every individual response text
            evals = [eval_func(
                        ResponseInfo(
                            text=r,
                            prompt=resp_obj['prompt'],
                            var=resp_obj['vars'],
                            meta=resp_obj['metavars'] if 'metavars' in resp_obj else {},
                            llm=resp_obj['llm'])
                    ) for r in res]

            # Check the type of evaluation results
            # NOTE: We assume this is consistent across all evaluations, but it may not be.
            eval_res_type = check_typeof_vals(evals)

            if eval_res_type == MetricType.Numeric:
                # Store items with summary of mean, median, etc
                resp_obj['eval_res'] = {
                    'mean': mean(evals),
                    'median': median(evals),
                    'stdev': stdev(evals) if len(evals) > 1 else 0,
                    'range': (min(evals), max(evals)),
                    'items': evals,
                    'dtype': eval_res_type.name,
                }
            elif eval_res_type in (MetricType.Unknown, MetricType.Empty):
                raise Exception('Unsupported types found in evaluation results. Only supported types for metrics are: int, float, bool, str.')
            else:
                # Categorical, KeyValue, etc, we just store the items:
                resp_obj['eval_res'] = {
                    'items': evals,
                    'dtype': eval_res_type.name,
                }
        else:
            # Run evaluator func over the entire response batch
            ev = eval_func([
                    ResponseInfo(text=r,
                                 prompt=resp_obj['prompt'],
                                 var=resp_obj['vars'],
                                 llm=resp_obj['llm'])
                    for r in res])
            ev_type = check_typeof_vals([ev])
            if ev_type == MetricType.Numeric:
                resp_obj['eval_res'] = {
                    'mean': ev,
                    'median': ev,
                    'stdev': 0,
                    'range': (ev, ev),
                    'items': [ev],
                    'type': ev_type.name,
                }
            else:
                resp_obj['eval_res'] = {
                    'items': [ev],
                    'type': ev_type.name,
                }
    return responses
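
# Illustrative example (hypothetical data): with scope='response' and a numeric evaluator
# that returns len(response.text), a response object such as
#   {'responses': ['a', 'bbb'], 'prompt': 'p', 'vars': {}, 'metavars': {}, 'llm': 'GPT-3.5'}
# would gain an 'eval_res' field roughly of the form
#   {'mean': 2, 'median': 2, 'stdev': ..., 'range': (1, 3), 'items': [1, 3], 'dtype': 'Numeric'}.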

async def make_sync_call_async(sync_method, *args, **params):
    """
        Makes a blocking synchronous call asynchronous, so that it can be awaited.
        NOTE: This is necessary for LLM APIs that do not yet support async (e.g. Google PaLM).
    """
    loop = asyncio.get_running_loop()
    method = sync_method
    if len(params) > 0:
        def partial_sync_meth(*a):
            return sync_method(*a, **params)
        method = partial_sync_meth
    return await loop.run_in_executor(None, method, *args)
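
# Usage sketch (illustrative; `some_blocking_sdk_call` is a hypothetical placeholder):
#   response = await make_sync_call_async(some_blocking_sdk_call, prompt, temperature=0.7)
# Keyword arguments are bound first, then the call is dispatched to asyncio's default
# thread pool executor so the Flask async route can await it without blocking the loop.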

def exclude_key(d, key_to_exclude):
    return {k: v for k, v in d.items() if k != key_to_exclude}
""" ===================
FLASK SERVER ROUTES
===================
"""
# Serve React app (static; no hot reloading)
@app.route("/")
def index():
# Get the index.html HTML code
html_str = render_template("index.html")
# Inject global JS variables __CF_HOSTNAME and __CF_PORT at the top so that the application knows
# that it's running from a Flask server, and what the hostname and port of that server is:
html_str = html_str[:60] + f'<script>window.__CF_HOSTNAME="{HOSTNAME}"; window.__CF_PORT={PORT};</script>' + html_str[60:]
return html_str

@app.route('/app/executepy', methods=['POST'])
def executepy():
    """
        Executes a Python function sent from JavaScript,
        over all the `StandardizedLLMResponse` objects passed in from the front-end.

        POST'd data should be in the form:
        {
            'id': str,  # a unique ID to refer to this information. Used when caching responses.
            'code': str,  # the body of the lambda function to evaluate, in form: lambda responses: <body>
            'responses': List[StandardizedLLMResponse],  # the responses to run on
            'scope': 'response' | 'batch',  # the scope of responses to run on: either a single response, or all across each batch.
                                            # If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt.
            'script_paths': unspecified | List[str],  # the paths to scripts to be added to the path before the lambda function is evaluated
        }

        NOTE: This should only be run on your server on code you trust.
              There is no sandboxing; no safety. We assume you are the creator of the code.
    """
    data = request.get_json()

    # Check that all required info is here:
    if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope'}):
        return jsonify({'error': 'POST data is improper format.'})
    if not isinstance(data['id'], str) or len(data['id']) == 0:
        return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
    if data['scope'] not in ('response', 'batch'):
        return jsonify({'error': "POST data scope is unknown. Must be either 'response' or 'batch'."})

    # Check format of responses:
    responses = data['responses']
    if (isinstance(responses, str) or not isinstance(responses, list)) or (len(responses) > 0 and any([not isinstance(r, dict) for r in responses])):
        return jsonify({'error': 'POST data responses is improper format.'})

    # Add the path to any scripts to the path:
    try:
        if 'script_paths' in data:
            for script_path in data['script_paths']:
                # get the folder of the script_path:
                script_folder = os.path.dirname(script_path)
                # check that the script_folder is valid, and it contains __init__.py
                if not os.path.exists(script_folder):
                    print(script_folder, 'is not a valid script path.')
                    continue

                # add it to the path:
                sys.path.append(script_folder)
                print(f'added {script_folder} to sys.path')
    except Exception as e:
        return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'})

    # Create the evaluator function
    # DANGER DANGER!
    try:
        exec(data['code'], globals())

        # Double-check that there is an 'evaluate' method in our namespace.
        # This will throw a NameError if not:
        evaluate  # noqa
    except Exception as e:
        return jsonify({'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})

    evald_responses = []
    logs = []
    try:
        HIJACK_PYTHON_PRINT()
        evald_responses = run_over_responses(evaluate, responses, scope=data['scope'])  # noqa
        logs = REVERT_PYTHON_PRINT()
    except Exception as e:
        logs = REVERT_PYTHON_PRINT()
        return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}', 'logs': logs})

    ret = jsonify({'responses': evald_responses, 'logs': logs})
    ret.headers.add('Access-Control-Allow-Origin', '*')
    return ret
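
# Example POST body (illustrative only; values are hypothetical, and the code string is assumed
# to define the `evaluate` function that the check above requires):
#   { "id": "eval-1",
#     "code": "def evaluate(response):\n    return len(response.text)",
#     "responses": [ ...StandardizedLLMResponse objects from the front-end... ],
#     "scope": "response" }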

@app.route('/app/fetchExampleFlow', methods=['POST'])
def fetchExampleFlow():
    """
        Fetches the example flow data, given its filename. The filename should be the
        name of a file in the examples/ folder of the package.

        Used for loading examples in the Example Flow modal.

        POST'd data should be in form:
        {
            name: <str>  # The filename (without .cforge extension)
        }
    """
    # Verify post'd data
    data = request.get_json()
    if 'name' not in data:
        return jsonify({'error': 'Missing "name" parameter to fetchExampleFlow.'})

    # Verify 'examples' directory exists:
    if not os.path.isdir(EXAMPLES_DIR):
        dirpath = os.path.dirname(os.path.realpath(__file__))
        return jsonify({'error': f'Could not find an examples/ directory at path {dirpath}'})

    # Check if the file is there:
    filepath = os.path.join(EXAMPLES_DIR, data['name'] + '.cforge')
    if not os.path.isfile(filepath):
        return jsonify({'error': f"Could not find an example flow named {data['name']}"})

    # Load the file and return its data:
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            filedata = json.load(f)
    except Exception as e:
        return jsonify({'error': f"Error parsing example flow at {filepath}: {str(e)}"})

    ret = jsonify({'data': filedata})
    ret.headers.add('Access-Control-Allow-Origin', '*')
    return ret

@app.route('/app/fetchOpenAIEval', methods=['POST'])
def fetchOpenAIEval():
    """
        Fetches a preconverted OpenAI eval as a .cforge JSON file.

        First detects if the eval is already in the cache. If the eval is already downloaded,
        it will be stored in the examples/ folder of the package under a new oaievals directory.
        If it's not in the cache, it will download it from the ChainForge webserver.

        POST'd data should be in form:
        {
            'name': <str>  # The name of the eval to grab (without .cforge extension)
        }
    """
    # Verify post'd data
    data = request.get_json()
    if 'name' not in data:
        return jsonify({'error': 'Missing "name" parameter to fetchOpenAIEval.'})
    evalname = data['name']

    # Verify 'examples' directory exists:
    if not os.path.isdir(EXAMPLES_DIR):
        dirpath = os.path.dirname(os.path.realpath(__file__))
        return jsonify({'error': f'Could not find an examples/ directory at path {dirpath}'})

    # Check if an oaievals subdirectory exists; if so, check for the file; if not, create it:
    oaievals_cache_dir = os.path.join(EXAMPLES_DIR, "oaievals")
    if os.path.isdir(oaievals_cache_dir):
        filepath = os.path.join(oaievals_cache_dir, evalname + '.cforge')
        if os.path.isfile(filepath):
            # File was already downloaded. Load it from cache:
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    filedata = json.load(f)
            except Exception as e:
                return jsonify({'error': f"Error parsing OpenAI evals flow at {filepath}: {str(e)}"})
            ret = jsonify({'data': filedata})
            ret.headers.add('Access-Control-Allow-Origin', '*')
            return ret
        # File was not downloaded; fall through to download it below.
    else:
        # Directory does not exist yet; create it
        try:
            os.mkdir(oaievals_cache_dir)
        except Exception as e:
            return jsonify({'error': f"Error creating a new directory 'oaievals' at filepath {oaievals_cache_dir}: {str(e)}"})

    # Download the preconverted OpenAI eval from the GitHub main branch for ChainForge
    _url = f"https://raw.githubusercontent.com/ianarawjo/ChainForge/main/chainforge/oaievals/{evalname}.cforge"
    response = py_requests.get(_url)

    # Check if the request was successful (status code 200)
    if response.status_code == 200:
        # Parse the response as JSON
        filedata = response.json()

        # Store to the cache:
        with open(os.path.join(oaievals_cache_dir, evalname + '.cforge'), 'w', encoding='utf8') as f:
            json.dump(filedata, f)
    else:
        print("Error:", response.status_code)
        return jsonify({'error': f"Error downloading OpenAI evals flow from {_url}: status code {response.status_code}"})

    ret = jsonify({'data': filedata})
    ret.headers.add('Access-Control-Allow-Origin', '*')
    return ret

@app.route('/app/fetchEnvironAPIKeys', methods=['POST'])
def fetchEnvironAPIKeys():
    """
        Returns the values of any supported API keys set in the server's environment
        variables, keyed by the provider alias the front-end expects.
    """
    keymap = {
        'OPENAI_API_KEY': 'OpenAI',
        'ANTHROPIC_API_KEY': 'Anthropic',
        'PALM_API_KEY': 'Google',
        'HUGGINGFACE_API_KEY': 'HuggingFace',
        'AZURE_OPENAI_KEY': 'Azure_OpenAI',
        'AZURE_OPENAI_ENDPOINT': 'Azure_OpenAI_Endpoint',
        'ALEPH_ALPHA_API_KEY': 'AlephAlpha'
    }
    d = { alias: os.environ.get(key) for key, alias in keymap.items() }
    ret = jsonify(d)
    ret.headers.add('Access-Control-Allow-Origin', '*')
    return ret
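
# For example (illustrative; the key value is hypothetical), if only OPENAI_API_KEY is set
# in the environment, the JSON response would look like:
#   {"OpenAI": "sk-...", "Anthropic": null, "Google": null, "HuggingFace": null, ...}
# since os.environ.get returns None (-> JSON null) for unset variables.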

@app.route('/app/makeFetchCall', methods=['POST'])
def makeFetchCall():
    """
        Use in place of JavaScript's 'fetch' (with POST method), in cases where
        cross-origin policy blocks responses from client-side fetches.

        POST'd data should be in form:
        {
            'url': <str>  # the url to fetch from
            'headers': <dict>  # a JSON object of the headers
            'body': <dict>  # the request payload, as JSON
        }
    """
    # Verify post'd data
    data = request.get_json()
    if not set(data.keys()).issuperset({'url', 'headers', 'body'}):
        return jsonify({'error': 'POST data is improper format.'})

    url = data['url']
    headers = data['headers']
    body = data['body']

    response = py_requests.post(url, headers=headers, json=body)

    if response.status_code == 200:
        ret = jsonify({'response': response.json()})
        ret.headers.add('Access-Control-Allow-Origin', '*')
        return ret
    else:
        return jsonify({'error': f'API request to {url} failed with status code {response.status_code}'})

@app.route('/app/callDalai', methods=['POST'])
async def callDalai():
    """
        Fetch response from a Dalai-hosted model (Alpaca or Llama).
        Requires the Python backend, since it depends on custom library code to extract the response.

        POST'd data should be a dict of keyword arguments to provide the call_dalai method.
    """
    # Verify post'd data
    data = request.get_json()
    if not set(data.keys()).issuperset({'prompt', 'model', 'server', 'n', 'temperature'}):
        return jsonify({'error': 'POST data is improper format.'})

    try:
        query, response = await call_dalai(**data)
    except Exception as e:
        return jsonify({'error': str(e)})

    ret = jsonify({'query': query, 'response': response})
    ret.headers.add('Access-Control-Allow-Origin', '*')
    return ret
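
# Example POST body (illustrative only; the model name and server URL are hypothetical
# placeholders, not values verified against call_dalai):
#   {"prompt": "Write a haiku about rivers.", "model": "alpaca.7B",
#    "server": "http://localhost:4000", "n": 1, "temperature": 0.5}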

@app.route('/app/initCustomProvider', methods=['POST'])
def initCustomProvider():
    """
        Initializes custom model provider(s) defined in a Python script,
        and returns specs for the front-end UI provider dropdown and the providers' settings window.

        POST'd data should be in form:
        {
            'code': <str>  # the Python script to save + execute
        }
    """
    # Verify post'd data
    data = request.get_json()
    if 'code' not in data:
        return jsonify({'error': 'POST data is improper format.'})

    # Sanity check that the code actually registers a provider
    if '@provider' not in data['code']:
        return jsonify({'error': """Did not detect a @provider decorator. Custom provider scripts should register at least one @provider.
                                    Do `from chainforge.providers import provider` and decorate your provider completion function with @provider."""})

    # Establish the custom provider script cache directory
    provider_scripts_dir = os.path.join(CACHE_DIR, "provider_scripts")
    if not os.path.isdir(provider_scripts_dir):
        # Create the directory
        try:
            os.mkdir(provider_scripts_dir)
        except Exception as e:
            return jsonify({'error': f"Error creating a new directory 'provider_scripts' at filepath {provider_scripts_dir}: {str(e)}"})

    # For keeping track of what script registered providers came from
    script_id = str(round(time.time()*1000))
    ProviderRegistry.set_curr_script_id(script_id)
    ProviderRegistry.watch_next_registered()

    # Attempt to run the Python script, in context
    try:
        exec(data['code'], globals(), None)
        # This should have registered one or more new CustomModelProviders.
    except Exception as e:
        return jsonify({'error': f'Error while executing custom provider code:\n{str(e)}'})

    # Check whether anything was updated, and what
    new_registries = ProviderRegistry.last_registered()

    if len(new_registries) == 0:  # Determine whether there's at least one custom provider.
        return jsonify({'error': 'Did not detect any custom providers added to the registry. Make sure you are registering your provider with @provider correctly.'})

    # At least one provider was registered; detect if it had a past script id and remove those file(s) from the cache
    if any((v is not None for v in new_registries.values())):
        # For every registered provider that was overwritten, remove the cache'd script(s) associated with it:
        past_script_ids = [v for v in new_registries.values() if v is not None]
        for sid in past_script_ids:
            past_script_path = os.path.join(provider_scripts_dir, f"{sid}.py")
            try:
                if os.path.isfile(past_script_path):
                    os.remove(past_script_path)
            except Exception as e:
                return jsonify({'error': f"Error removing cache'd custom provider script at filepath {past_script_path}: {str(e)}"})

    # Get the names and specs of all currently registered CustomModelProviders,
    # and pass that info to the front-end (excluding the func):
    registered_providers = [exclude_key(d, 'func') for d in ProviderRegistry.get_all()]

    # Copy the passed Python script to a local file in the package directory
    try:
        with open(os.path.join(provider_scripts_dir, f"{script_id}.py"), 'w') as f:
            f.write(data['code'])
    except Exception as e:
        return jsonify({'error': f"Error saving script 'provider_scripts' at filepath {provider_scripts_dir}: {str(e)}"})

    # Return all loaded providers
    return jsonify({'providers': registered_providers})
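
# Sketch of a minimal custom provider script that this route might receive (illustrative only;
# the decorator parameters shown are an assumption, check chainforge.providers for the real signature):
#
#   from chainforge.providers import provider
#
#   @provider(name="EchoProvider")  # hypothetical parameter usage
#   def echo_completion(prompt: str, **kwargs) -> str:
#       # Return the completion text for the given prompt
#       return "echo: " + prompt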

@app.route('/app/loadCachedCustomProviders', methods=['POST'])
def loadCachedCustomProviders():
    """
        Initializes all custom model provider(s) found in the local provider_scripts directory.
    """
    provider_scripts_dir = os.path.join(CACHE_DIR, "provider_scripts")
    if not os.path.isdir(provider_scripts_dir):
        # No providers to load.
        return jsonify({'providers': []})

    try:
        for file_name in os.listdir(provider_scripts_dir):
            file_path = os.path.join(provider_scripts_dir, file_name)
            if os.path.isfile(file_path) and os.path.splitext(file_path)[1] == '.py':
                # For keeping track of what script registered providers came from
                ProviderRegistry.set_curr_script_id(os.path.splitext(file_name)[0])

                # Read the Python script
                with open(file_path, 'r') as f:
                    code = f.read()

                # Try to execute it in the global context
                try:
                    exec(code, globals(), None)
                except Exception as code_exc:
                    # Remove the script file associated w the failed execution
                    os.remove(file_path)
                    raise code_exc
    except Exception as e:
        return jsonify({'error': f'Error while loading custom providers from cache: \n{str(e)}'})

    # Get the names and specs of all currently registered CustomModelProviders,
    # and pass that info to the front-end (excluding the func):
    registered_providers = [exclude_key(d, 'func') for d in ProviderRegistry.get_all()]

    return jsonify({'providers': registered_providers})

@app.route('/app/removeCustomProvider', methods=['POST'])
def removeCustomProvider():
    """
        Removes a registered custom model provider from the `ProviderRegistry`,
        and deletes any cached script associated with it.

        POST'd data should be in form:
        {
            'name': <str>  # a name that refers to the registered custom provider in the `ProviderRegistry`
        }
    """
    # Verify post'd data
    data = request.get_json()
    name = data.get('name')
    if name is None:
        return jsonify({'error': 'POST data is improper format.'})
    if not ProviderRegistry.has(name):
        return jsonify({'error': f'Could not find a custom provider named "{name}"'})

    # Get the script id associated with the provider we're about to remove
    script_id = ProviderRegistry.get(name).get('script_id')

    # Remove the custom provider from the registry
    ProviderRegistry.remove(name)

    # Attempt to delete associated script from cache
    if script_id:
        script_path = os.path.join(CACHE_DIR, "provider_scripts", f"{script_id}.py")
        if os.path.isfile(script_path):
            os.remove(script_path)

    return jsonify({'success': True})

@app.route('/app/callCustomProvider', methods=['POST'])
async def callCustomProvider():
    """
        Calls a custom model provider and returns the response.

        POST'd data should be in form:
        {
            'name': <str>  # the name of the provider in the `ProviderRegistry`
            'params': <dict>  # the params (prompt, model, etc) to pass to the provider function.
        }
    """
    # Verify post'd data
    data = request.get_json()
    if not set(data.keys()).issuperset({'name', 'params'}):
        return jsonify({'error': 'POST data is improper format.'})

    # Load the name of the provider
    name = data['name']
    params = data['params']

    # Double-check that the custom provider exists in the registry, and (if passed) a model with that name exists
    provider_spec = ProviderRegistry.get(name)
    if provider_spec is None:
        return jsonify({'error': f'Could not find provider named {name}. Perhaps you need to import a custom provider script?'})

    # Call + await the custom provider function, passing in the JSON payload as kwargs
    try:
        response = await make_sync_call_async(provider_spec.get('func'), **params)
    except Exception as e:
        return jsonify({'error': f'Error encountered while calling custom provider function: {str(e)}'})

    # Return the response
    return jsonify({'response': response})
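
# Example POST body (illustrative; the provider name and params are hypothetical and depend
# on whatever keyword arguments the registered provider function accepts):
#   {"name": "EchoProvider", "params": {"prompt": "Hello!", "model": "default", "temperature": 0.0}}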

def run_server(host="", port=8000, cmd_args=None):
    global HOSTNAME, PORT
    HOSTNAME = host
    PORT = port
    app.run(host=host, port=port)

if __name__ == '__main__':
    print("Run app.py instead.")