Mirror of https://github.com/ianarawjo/ChainForge.git
Custom providers (#122)
* Removed Python backend files that are no longer used (everything in `promptengine`)
* Added a `providers` subpackage, with `CustomProviderProtocol`, the `provider` decorator, and the global singleton `ProviderRegistry`
* Added a tab for custom providers, and a dropzone, in the ChainForge global settings UI
* List custom providers in the Global Settings screen once added
* Added the ability to remove custom providers by clicking X
* Made custom provider functions synchronous, but call them asynchronously
* Added a Cohere custom provider example in examples/
* Cache the custom provider scripts and load them upon page load
* Rebuilt React and updated the package version
* Fixed a bug when a custom provider is deleted while the settings screen is open on that provider
This commit is contained in:
parent f43861f075
commit 0134dbf59b
chainforge/examples/custom_provider_cohere.py (new file, 55 lines)
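For orientation, a minimal sketch of the registration flow this commit introduces, using only the decorator arguments and `ProviderRegistry` calls that appear in the diff below; the `Echo` provider, its model list, and the omission of `settings_schema` are illustrative assumptions, not part of the commit.

from chainforge.providers import provider, ProviderRegistry

# Decorating a completion function registers it (plus its metadata) in the
# global ProviderRegistry singleton; the Flask routes below look it up by name.
@provider(name="Echo", emoji="🔁", models=['echo'], rate_limit="sequential")
def EchoCompletion(prompt: str, model: str, temperature: float = 1.0, **kwargs) -> str:
    return prompt  # a real provider would call an external API here

if ProviderRegistry.has("Echo"):
    spec = ProviderRegistry.get("Echo")
    print(spec.get('func')("Hello, world", model='echo'))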
@@ -0,0 +1,55 @@
"""
    A simple custom model provider to add to the ChainForge interface,
    to support Cohere AI text completions through their Python API.

    NOTE: You must have the `cohere` package installed and an API key.
"""
from chainforge.providers import provider
import cohere

# Init the Cohere client (replace with your Cohere API Key)
co = cohere.Client('<YOUR_API_KEY>')

# JSON schemas to pass react-jsonschema-form, one for this endpoint's settings and one to describe the settings UI.
COHERE_SETTINGS_SCHEMA = {
  "settings": {
    "temperature": {
      "type": "number",
      "title": "temperature",
      "description": "Controls the 'creativity' or randomness of the response.",
      "default": 0.75,
      "minimum": 0,
      "maximum": 5.0,
      "multipleOf": 0.01,
    },
    "max_tokens": {
      "type": "integer",
      "title": "max_tokens",
      "description": "Maximum number of tokens to generate in the response.",
      "default": 100,
      "minimum": 1,
      "maximum": 1024,
    },
  },
  "ui": {
    "temperature": {
      "ui:help": "Defaults to 0.75.",
      "ui:widget": "range"
    },
    "max_tokens": {
      "ui:help": "Defaults to 100.",
      "ui:widget": "range"
    },
  }
}

# Our custom model provider for Cohere's text generation API.
@provider(name="Cohere",
          emoji="🖇",
          models=['command', 'command-nightly', 'command-light', 'command-light-nightly'],
          rate_limit="sequential",  # enter "sequential" for blocking; an integer N > 0 means N is the max number of requests per minute.
          settings_schema=COHERE_SETTINGS_SCHEMA)
def CohereCompletion(prompt: str, model: str, temperature: float = 0.75, **kwargs) -> str:
    print(f"Calling Cohere model {model} with prompt '{prompt}'...")
    response = co.generate(model=model, prompt=prompt, temperature=temperature, **kwargs)
    return response.generations[0].text
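To sanity-check this script outside ChainForge, the provider function can be called directly; this smoke test is an illustrative assumption (it requires the `cohere` package and a real API key substituted above), not part of the committed file.

if __name__ == '__main__':
    # Hypothetical smoke test of the registered completion function.
    print(CohereCompletion("Say hello in one short sentence.", model="command", max_tokens=50))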
@@ -1,11 +1,12 @@
import json, os, asyncio, sys, traceback
import json, os, sys, asyncio, time
from dataclasses import dataclass
from enum import Enum
from typing import List
from statistics import mean, median, stdev
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from chainforge.promptengine.utils import LLM, call_dalai
from chainforge.providers.dalai import call_dalai
from chainforge.providers import ProviderRegistry
import requests as py_requests

""" =================
@@ -16,6 +17,7 @@ import requests as py_requests
# Setup Flask app to serve static version of React front-end
HOSTNAME = "localhost"
PORT = 8000
# SESSION_TOKEN = secrets.token_hex(32)
BUILD_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'react-server', 'build')
STATIC_DIR = os.path.join(BUILD_DIR, 'static')
app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
@@ -27,10 +29,6 @@ cors = CORS(app, resources={r"/*": {"origins": "*"}})
CACHE_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cache')
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'examples')

LLM_NAME_MAP = {}
for model in LLM:
    LLM_NAME_MAP[model.value] = model

class MetricType(Enum):
    KeyValue = 0
    KeyValue_Numeric = 1
@@ -221,6 +219,22 @@ def run_over_responses(eval_func, responses: list, scope: str) -> list:
            }
    return responses

async def make_sync_call_async(sync_method, *args, **params):
    """
        Makes a blocking synchronous call asynchronous, so that it can be awaited.
        NOTE: This is necessary for LLM APIs that do not yet support async (e.g. Google PaLM).
    """
    loop = asyncio.get_running_loop()
    method = sync_method
    if len(params) > 0:
        def partial_sync_meth(*a):
            return sync_method(*a, **params)
        method = partial_sync_meth
    return await loop.run_in_executor(None, method, *args)
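An illustrative usage sketch of the helper above (not part of the diff), with a hypothetical blocking function standing in for a synchronous SDK call:

import asyncio, time

def slow_completion(prompt, model, temperature=1.0):
    time.sleep(1)  # stand-in for a blocking API/SDK call
    return f"[{model}] echo: {prompt}"

async def demo():
    # Positional args are forwarded as-is; keyword args are wrapped into a partial,
    # and the call runs on the default thread-pool executor so it can be awaited.
    return await make_sync_call_async(slow_completion, "Hello", model="test-model", temperature=0.5)

# asyncio.run(demo())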

def exclude_key(d, key_to_exclude):
    return {k: v for k, v in d.items() if k != key_to_exclude}


""" ===================
    FLASK SERVER ROUTES
@@ -261,7 +275,7 @@ def executepy():
    data = request.get_json()

    # Check that all required info is here:
    if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope'}):
    if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope', 'token'}):
        return jsonify({'error': 'POST data is improper format.'})
    if not isinstance(data['id'], str) or len(data['id']) == 0:
        return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
@@ -273,7 +287,7 @@ def executepy():
    if (isinstance(responses, str) or not isinstance(responses, list)) or (len(responses) > 0 and any([not isinstance(r, dict) for r in responses])):
        return jsonify({'error': 'POST data responses is improper format.'})

    # add the path to any scripts to the path:
    # Add the path to any scripts to the path:
    try:
        if 'script_paths' in data:
            for script_path in data['script_paths']:
@@ -367,7 +381,7 @@ def fetchOpenAIEval():

        POST'd data should be in form:
        {
            name: <str>  # The name of the eval to grab (without .cforge extension)
            'name': <str>  # The name of the eval to grab (without .cforge extension)
        }
    """
    # Verify post'd data
@@ -404,9 +418,8 @@ def fetchOpenAIEval():
            return jsonify({'error': f"Error creating a new directory 'oaievals' at filepath {oaievals_cache_dir}: {str(e)}"})

    # Download the preconverted OpenAI eval from the GitHub main branch for ChainForge
    import requests
    _url = f"https://raw.githubusercontent.com/ianarawjo/ChainForge/main/chainforge/oaievals/{evalname}.cforge"
    response = requests.get(_url)
    response = py_requests.get(_url)

    # Check if the request was successful (status code 200)
    if response.status_code == 200:
@@ -449,9 +462,9 @@ def makeFetchCall():

        POST'd data should be in form:
        {
            url: <str>  # the url to fetch from
            headers: <dict>  # a JSON object of the headers
            body: <dict>  # the request payload, as JSON
            'url': <str>  # the url to fetch from
            'headers': <dict>  # a JSON object of the headers
            'body': <dict>  # the request payload, as JSON
        }
    """
    # Verify post'd data
@@ -486,8 +499,6 @@ async def callDalai():
    if not set(data.keys()).issuperset({'prompt', 'model', 'server', 'n', 'temperature'}):
        return jsonify({'error': 'POST data is improper format.'})

    data['model'] = LLM_NAME_MAP[data['model']]

    try:
        query, response = await call_dalai(**data)
    except Exception as e:
@@ -498,6 +509,189 @@ async def callDalai():
    return ret


@app.route('/app/initCustomProvider', methods=['POST'])
def initCustomProvider():
    """
        Initializes custom model provider(s) defined in a Python script,
        and returns specs for the front-end UI provider dropdown and the providers' settings window.

        POST'd data should be in form:
        {
            'code': <str>  # the Python script to save + execute,
        }
    """
    # Verify post'd data
    data = request.get_json()
    if 'code' not in data:
        return jsonify({'error': 'POST data is improper format.'})

    # Sanity check that the code actually registers a provider
    if '@provider' not in data['code']:
        return jsonify({'error': """Did not detect a @provider decorator. Custom provider scripts should register at least one @provider.
                                    Do `from chainforge.providers import provider` and decorate your provider completion function with @provider."""})

    # Establish the custom provider script cache directory
    provider_scripts_dir = os.path.join(CACHE_DIR, "provider_scripts")
    if not os.path.isdir(provider_scripts_dir):
        # Create the directory
        try:
            os.mkdir(provider_scripts_dir)
        except Exception as e:
            return jsonify({'error': f"Error creating a new directory 'provider_scripts' at filepath {provider_scripts_dir}: {str(e)}"})

    # For keeping track of what script registered providers came from
    script_id = str(round(time.time()*1000))
    ProviderRegistry.set_curr_script_id(script_id)
    ProviderRegistry.watch_next_registered()

    # Attempt to run the Python script, in context
    try:
        exec(data['code'], globals(), None)

        # This should have registered one or more new CustomModelProviders.
    except Exception as e:
        return jsonify({'error': f'Error while executing custom provider code:\n{str(e)}'})

    # Check whether anything was updated, and what
    new_registries = ProviderRegistry.last_registered()
    if len(new_registries) == 0:  # Determine whether there's at least one custom provider.
        return jsonify({'error': 'Did not detect any custom providers added to the registry. Make sure you are registering your provider with @provider correctly.'})

    # At least one provider was registered; detect if it had a past script id and remove those file(s) from the cache
    if any((v is not None for v in new_registries.values())):
        # For every registered provider that was overwritten, remove the cache'd script(s) associated with it:
        past_script_ids = [v for v in new_registries.values() if v is not None]
        for sid in past_script_ids:
            past_script_path = os.path.join(provider_scripts_dir, f"{sid}.py")
            try:
                if os.path.isfile(past_script_path):
                    os.remove(past_script_path)
            except Exception as e:
                return jsonify({'error': f"Error removing cache'd custom provider script at filepath {past_script_path}: {str(e)}"})

    # Get the names and specs of all currently registered CustomModelProviders,
    # and pass that info to the front-end (excluding the func):
    registered_providers = [exclude_key(d, 'func') for d in ProviderRegistry.get_all()]

    # Copy the passed Python script to a local file in the package directory
    try:
        with open(os.path.join(provider_scripts_dir, f"{script_id}.py"), 'w') as f:
            f.write(data['code'])
    except Exception as e:
        return jsonify({'error': f"Error saving script 'provider_scripts' at filepath {provider_scripts_dir}: {str(e)}"})

    # Return all loaded providers
    return jsonify({'providers': registered_providers})


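For reference, a short sketch of how a client could exercise this endpoint from Python; the ChainForge UI normally does this via the settings dropzone, so calling it directly like this is only an illustration (it assumes the server is running at the default localhost:8000 and that the example script path below is readable).

import requests

with open('chainforge/examples/custom_provider_cohere.py') as f:
    code = f.read()

# Registers the provider(s) defined in the script and caches the script on disk.
resp = requests.post('http://localhost:8000/app/initCustomProvider', json={'code': code})
print(resp.json())  # either {'providers': [...]} or {'error': ...}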
@app.route('/app/loadCachedCustomProviders', methods=['POST'])
def loadCachedCustomProviders():
    """
        Initializes all custom model provider(s) in the local provider_scripts directory.
    """
    provider_scripts_dir = os.path.join(CACHE_DIR, "provider_scripts")
    if not os.path.isdir(provider_scripts_dir):
        # No providers to load.
        return jsonify({'providers': []})

    try:
        for file_name in os.listdir(provider_scripts_dir):
            file_path = os.path.join(provider_scripts_dir, file_name)
            if os.path.isfile(file_path) and os.path.splitext(file_path)[1] == '.py':
                # For keeping track of what script registered providers came from
                ProviderRegistry.set_curr_script_id(os.path.splitext(file_name)[0])

                # Read the Python script
                with open(file_path, 'r') as f:
                    code = f.read()

                # Try to execute it in the global context
                try:
                    exec(code, globals(), None)
                except Exception as code_exc:
                    # Remove the script file associated w/ the failed execution
                    os.remove(file_path)
                    raise code_exc
    except Exception as e:
        return jsonify({'error': f'Error while loading custom providers from cache: \n{str(e)}'})

    # Get the names and specs of all currently registered CustomModelProviders,
    # and pass that info to the front-end (excluding the func):
    registered_providers = [exclude_key(d, 'func') for d in ProviderRegistry.get_all()]

    return jsonify({'providers': registered_providers})


@app.route('/app/removeCustomProvider', methods=['POST'])
def removeCustomProvider():
    """
        Removes a custom model provider from the `ProviderRegistry`,
        and deletes its cached script file (if any).

        POST'd data should be in form:
        {
            'name': <str>  # a name that refers to the registered custom provider in the `ProviderRegistry`
        }
    """
    # Verify post'd data
    data = request.get_json()
    name = data.get('name')
    if name is None:
        return jsonify({'error': 'POST data is improper format.'})

    if not ProviderRegistry.has(name):
        return jsonify({'error': f'Could not find a custom provider named "{name}"'})

    # Get the script id associated with the provider we're about to remove
    script_id = ProviderRegistry.get(name).get('script_id')

    # Remove the custom provider from the registry
    ProviderRegistry.remove(name)

    # Attempt to delete associated script from cache
    if script_id:
        script_path = os.path.join(CACHE_DIR, "provider_scripts", f"{script_id}.py")
        if os.path.isfile(script_path):
            os.remove(script_path)

    return jsonify({'success': True})


@app.route('/app/callCustomProvider', methods=['POST'])
async def callCustomProvider():
    """
        Calls a custom model provider and returns the response.

        POST'd data should be in form:
        {
            'name': <str>  # the name of the provider in the `ProviderRegistry`
            'params': <dict>  # the params (prompt, model, etc) to pass to the provider function.
        }
    """
    # Verify post'd data
    data = request.get_json()
    if not set(data.keys()).issuperset({'name', 'params'}):
        return jsonify({'error': 'POST data is improper format.'})

    # Load the name of the provider
    name = data['name']
    params = data['params']

    # Double-check that the custom provider exists in the registry, and (if passed) a model with that name exists
    provider_spec = ProviderRegistry.get(name)
    if provider_spec is None:
        return jsonify({'error': f'Could not find provider named {name}. Perhaps you need to import a custom provider script?'})

    # Call + await the custom provider function, passing in the JSON payload as kwargs
    try:
        response = await make_sync_call_async(provider_spec.get('func'), **params)
    except Exception as e:
        return jsonify({'error': f'Error encountered while calling custom provider function: {str(e)}'})

    # Return the response
    return jsonify({'response': response})
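A companion sketch for the endpoint above: once a provider such as the Cohere example has been registered, a client could call it like this. The prompt text and settings are illustrative assumptions; the 'params' keys simply mirror the provider function's keyword arguments.

import requests

payload = {
    'name': 'Cohere',  # must match a provider previously registered via initCustomProvider
    'params': {'prompt': 'Write a haiku about graphs.', 'model': 'command', 'temperature': 0.7}
}
resp = requests.post('http://localhost:8000/app/callCustomProvider', json=payload)
print(resp.json())  # {'response': '...'} on success, {'error': ...} otherwise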


def run_server(host="", port=8000, cmd_args=None):
    global HOSTNAME, PORT
    HOSTNAME = host
@@ -1 +0,0 @@
@@ -1,82 +0,0 @@
"""
    A list of all model APIs natively supported by ChainForge.
"""
from enum import Enum

class LLM(str, Enum):
    """ OpenAI Chat """
    OpenAI_ChatGPT = "gpt-3.5-turbo"
    OpenAI_ChatGPT_16k = "gpt-3.5-turbo-16k"
    OpenAI_ChatGPT_16k_0613 = "gpt-3.5-turbo-16k-0613"
    OpenAI_ChatGPT_0301 = "gpt-3.5-turbo-0301"
    OpenAI_ChatGPT_0613 = "gpt-3.5-turbo-0613"
    OpenAI_GPT4 = "gpt-4"
    OpenAI_GPT4_0314 = "gpt-4-0314"
    OpenAI_GPT4_0613 = "gpt-4-0613"
    OpenAI_GPT4_32k = "gpt-4-32k"
    OpenAI_GPT4_32k_0314 = "gpt-4-32k-0314"
    OpenAI_GPT4_32k_0613 = "gpt-4-32k-0613"

    """ OpenAI Text Completions """
    OpenAI_Davinci003 = "text-davinci-003"
    OpenAI_Davinci002 = "text-davinci-002"

    """ Azure OpenAI Endpoints """
    Azure_OpenAI = "azure-openai"

    """ Dalai-served models (Alpaca and Llama) """
    Dalai_Alpaca_7B = "alpaca.7B"
    Dalai_Alpaca_13B = "alpaca.13B"
    Dalai_Llama_7B = "llama.7B"
    Dalai_Llama_13B = "llama.13B"
    Dalai_Llama_30B = "llama.30B"
    Dalai_Llama_65B = "llama.65B"

    """ Anthropic """
    # Our largest model, ideal for a wide range of more complex tasks. Using this model name
    # will automatically switch you to newer versions of claude-v1 as they are released.
    Claude_v1 = "claude-v1"

    # An earlier version of claude-v1
    Claude_v1_0 = "claude-v1.0"

    # An improved version of claude-v1. It is slightly improved at general helpfulness,
    # instruction following, coding, and other tasks. It is also considerably better with
    # non-English languages. This model also has the ability to role play (in harmless ways)
    # more consistently, and it defaults to writing somewhat longer and more thorough responses.
    Claude_v1_2 = "claude-v1.2"

    # A significantly improved version of claude-v1. Compared to claude-v1.2, it's more robust
    # against red-team inputs, better at precise instruction-following, better at code, and better
    # at non-English dialogue and writing.
    Claude_v1_3 = "claude-v1.3"

    # A smaller model with far lower latency, sampling at roughly 40 words/sec! Its output quality
    # is somewhat lower than claude-v1 models, particularly for complex tasks. However, it is much
    # less expensive and blazing fast. We believe that this model provides more than adequate performance
    # on a range of tasks including text classification, summarization, and lightweight chat applications,
    # as well as search result summarization. Using this model name will automatically switch you to newer
    # versions of claude-instant-v1 as they are released.
    Claude_v1_instant = "claude-instant-v1"

    """ Google models """
    PaLM2_Text_Bison = "text-bison-001"  # it's really models/text-bison-001, but that's confusing
    PaLM2_Chat_Bison = "chat-bison-001"


# LLM APIs often have rate limits, which control the number of requests. E.g., OpenAI: https://platform.openai.com/account/rate-limits
# For a basic organization in OpenAI, GPT3.5 is currently 3500 and GPT4 is 200 RPM (requests per minute).
# For the Anthropic evaluation preview of Claude, we can only send 1 request at a time (synchronously).
# This 'cheap' version of controlling for rate limits is to wait a few seconds between batches of requests being sent off.
# If a model is missing from below, it means we must send and receive only 1 request at a time (synchronous).
# The following is only a guideline, and a bit on the conservative side.
RATE_LIMITS = {
    LLM.OpenAI_ChatGPT: (30, 10),  # max 30 requests a batch; wait 10 seconds between
    LLM.OpenAI_ChatGPT_0301: (30, 10),
    LLM.OpenAI_GPT4: (4, 15),  # max 4 requests a batch; wait 15 seconds between
    LLM.OpenAI_GPT4_0314: (4, 15),
    LLM.OpenAI_GPT4_32k: (4, 15),
    LLM.OpenAI_GPT4_32k_0314: (4, 15),
    LLM.PaLM2_Text_Bison: (4, 10),  # max 30 requests per minute; so do 4 per batch, 10 seconds between (conservative)
    LLM.PaLM2_Chat_Bison: (4, 10),
}
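For context, a small sketch (with illustrative values) of how these (batch size, wait seconds) pairs are consumed by the PromptPipeline batching logic that appears later in this diff:

llm = LLM.OpenAI_GPT4
max_req, wait_secs = RATE_LIMITS[llm] if llm in RATE_LIMITS else (1, 0)  # -> (4, 15)

query_number = 9                          # e.g., the 10th query being sent off
batch_num = int(query_number / max_req)   # -> batch #2
# Before sending this batch, the pipeline awaits asyncio.sleep(wait_secs * batch_num).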
@ -1,276 +0,0 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Dict, Tuple, Iterator, Union, Optional
|
||||
import json, os, asyncio, random, string
|
||||
from chainforge.promptengine.utils import call_chatgpt, call_dalai, call_anthropic, call_google_palm, call_azure_openai, is_valid_filepath, is_valid_json, extract_responses, merge_response_objs
|
||||
from chainforge.promptengine.models import LLM, RATE_LIMITS
|
||||
from chainforge.promptengine.template import PromptTemplate, PromptPermutationGenerator
|
||||
|
||||
class LLMResponseException(Exception):
|
||||
""" Raised when there is an error generating a single response from an LLM """
|
||||
pass
|
||||
|
||||
"""
|
||||
Abstract class that captures a generic querying interface to prompt LLMs
|
||||
"""
|
||||
class PromptPipeline:
|
||||
def __init__(self, storageFile: str):
|
||||
if not is_valid_filepath(storageFile):
|
||||
raise IOError(f"Filepath {storageFile} is invalid, or you do not have write access.")
|
||||
|
||||
self._filepath = storageFile
|
||||
|
||||
@abstractmethod
|
||||
def gen_prompts(self, properties) -> Iterator[PromptTemplate]:
|
||||
raise NotImplementedError("Please Implement the gen_prompts method")
|
||||
|
||||
async def gen_responses(self, properties, llm: LLM, n: int = 1, temperature: float = 1.0, **llm_params) -> Iterator[Union[Dict, LLMResponseException]]:
|
||||
"""
|
||||
Calls LLM 'llm' with all prompts, and yields responses as dicts in format {prompt, query, response, llm, info}.
|
||||
|
||||
Queries are sent off asynchronously (if possible).
|
||||
Yields responses as they come in. All LLM calls that yield errors (e.g., 'rate limit' error)
|
||||
will yield an individual LLMResponseException, so downstream tasks must check for this exception type.
|
||||
|
||||
By default, for each response successfully collected, this also saves reponses to disk as JSON at the filepath given during init.
|
||||
(Very useful for saving money in case something goes awry!)
|
||||
To clear the cached responses, call clear_cached_responses().
|
||||
|
||||
NOTE: The reason we collect, rather than raise, LLMResponseExceptions is because some API calls
|
||||
may still succeed, even if some fail. We don't want to stop listening to pending API calls,
|
||||
because we may lose money. Instead, we fail selectively.
|
||||
|
||||
Do not override this function.
|
||||
"""
|
||||
# Double-check that properties is the correct type (JSON dict):
|
||||
if not is_valid_json(properties):
|
||||
raise ValueError("Properties argument is not valid JSON.")
|
||||
|
||||
# Load any cache'd responses
|
||||
responses = self._load_cached_responses()
|
||||
|
||||
# Query LLM with each prompt, yield + cache the responses
|
||||
tasks = []
|
||||
max_req, wait_secs = RATE_LIMITS[llm] if llm in RATE_LIMITS else (1, 0)
|
||||
num_queries_sent = -1
|
||||
|
||||
for prompt in self.gen_prompts(properties):
|
||||
if isinstance(prompt, PromptTemplate) and not prompt.is_concrete():
|
||||
raise Exception(f"Cannot send a prompt '{prompt}' to LLM: Prompt is a template.")
|
||||
|
||||
prompt_str = str(prompt)
|
||||
info = prompt.fill_history
|
||||
metavars = prompt.metavars
|
||||
|
||||
cached_resp = responses[prompt_str] if prompt_str in responses else None
|
||||
extracted_resps = cached_resp["responses"] if cached_resp is not None else []
|
||||
|
||||
# First check if there is already a response for this item under these settings. If so, we can save an LLM call:
|
||||
if cached_resp and len(extracted_resps) >= n:
|
||||
print(f" - Found cache'd response for prompt {prompt_str}. Using...")
|
||||
yield {
|
||||
"prompt": prompt_str,
|
||||
"query": cached_resp["query"],
|
||||
"responses": extracted_resps[:n],
|
||||
"raw_response": cached_resp["raw_response"],
|
||||
"llm": cached_resp["llm"] if "llm" in cached_resp else LLM.OpenAI_ChatGPT.value,
|
||||
# We want to use the new info, since 'vars' could have changed even though
|
||||
# the prompt text is the same (e.g., "this is a tool -> this is a {x} where x='tool'")
|
||||
"info": info,
|
||||
"metavars": metavars
|
||||
}
|
||||
continue
|
||||
|
||||
num_queries_sent += 1
|
||||
|
||||
if max_req > 1:
|
||||
# Call the LLM asynchronously to generate a response, sending off
|
||||
# requests in batches of size 'max_req' separated by seconds 'wait_secs' to avoid hitting rate limit
|
||||
tasks.append(self._prompt_llm(llm=llm,
|
||||
prompt=prompt,
|
||||
n=n,
|
||||
temperature=temperature,
|
||||
past_resp_obj=cached_resp,
|
||||
query_number=num_queries_sent,
|
||||
rate_limit_batch_size=max_req,
|
||||
rate_limit_wait_secs=wait_secs,
|
||||
**llm_params))
|
||||
else:
|
||||
# Block. Await + yield a single LLM call.
|
||||
_, query, response, past_resp_obj = await self._prompt_llm(llm=llm,
|
||||
prompt=prompt,
|
||||
n=n,
|
||||
temperature=temperature,
|
||||
past_resp_obj=cached_resp,
|
||||
**llm_params)
|
||||
|
||||
# Check for selective failure
|
||||
if query is None and isinstance(response, LLMResponseException):
|
||||
yield response # yield the LLMResponseException
|
||||
continue
|
||||
|
||||
# Create a response obj to represent the response
|
||||
resp_obj = {
|
||||
"prompt": str(prompt),
|
||||
"query": query,
|
||||
"responses": extract_responses(response, llm),
|
||||
"raw_response": response,
|
||||
"llm": llm.value,
|
||||
"info": info,
|
||||
"metavars": metavars
|
||||
}
|
||||
|
||||
# Merge the response obj with the past one, if necessary
|
||||
if past_resp_obj is not None:
|
||||
resp_obj = merge_response_objs(resp_obj, past_resp_obj)
|
||||
|
||||
# Save the current state of cache'd responses to a JSON file
|
||||
responses[resp_obj["prompt"]] = resp_obj
|
||||
self._cache_responses(responses)
|
||||
|
||||
print(f" - collected response from {llm.value} for prompt:", resp_obj['prompt'])
|
||||
|
||||
# Yield the response
|
||||
yield resp_obj
|
||||
|
||||
# Yield responses as they come in
|
||||
for task in asyncio.as_completed(tasks):
|
||||
# Collect the response from the earliest completed task
|
||||
prompt, query, response, past_resp_obj = await task
|
||||
|
||||
# Check for selective failure
|
||||
if query is None and isinstance(response, LLMResponseException):
|
||||
yield response # yield the LLMResponseException
|
||||
continue
|
||||
|
||||
# Each prompt has a history of what was filled in from its base template.
|
||||
# This data --like, "class", "language", "library" etc --can be useful when parsing responses.
|
||||
info = prompt.fill_history
|
||||
metavars = prompt.metavars
|
||||
|
||||
# Create a response obj to represent the response
|
||||
resp_obj = {
|
||||
"prompt": str(prompt),
|
||||
"query": query,
|
||||
"responses": extract_responses(response, llm),
|
||||
"raw_response": response,
|
||||
"llm": llm.value,
|
||||
"info": info,
|
||||
"metavars": metavars,
|
||||
}
|
||||
|
||||
# Merge the response obj with the past one, if necessary
|
||||
if past_resp_obj is not None:
|
||||
resp_obj = merge_response_objs(resp_obj, past_resp_obj)
|
||||
|
||||
# Save the current state of cache'd responses to a JSON file
|
||||
# NOTE: We do this to save money --in case something breaks between calls, can ensure we got the data!
|
||||
responses[resp_obj["prompt"]] = resp_obj
|
||||
self._cache_responses(responses)
|
||||
|
||||
print(f" - collected response from {llm.value} for prompt:", resp_obj['prompt'])
|
||||
|
||||
# Yield the response
|
||||
yield resp_obj
|
||||
|
||||
def _load_cached_responses(self) -> Dict:
|
||||
"""
|
||||
Loads saved responses of JSON at self._filepath.
|
||||
Useful for continuing if computation was interrupted halfway through.
|
||||
"""
|
||||
if os.path.isfile(self._filepath):
|
||||
with open(self._filepath, encoding="utf-8") as f:
|
||||
responses = json.load(f)
|
||||
return responses
|
||||
else:
|
||||
return {}
|
||||
|
||||
def _cache_responses(self, responses) -> None:
|
||||
with open(self._filepath, "w", encoding='utf-8') as f:
|
||||
json.dump(responses, f)
|
||||
|
||||
def clear_cached_responses(self) -> None:
|
||||
self._cache_responses({})
|
||||
|
||||
async def _prompt_llm(self,
|
||||
llm: LLM,
|
||||
prompt: PromptTemplate,
|
||||
n: int = 1,
|
||||
temperature: float = 1.0,
|
||||
past_resp_obj: Optional[Dict] = None,
|
||||
query_number: Optional[int] = None,
|
||||
rate_limit_batch_size: Optional[int] = None,
|
||||
rate_limit_wait_secs: Optional[float] = None,
|
||||
**llm_params) -> Tuple[str, Dict, Union[List, Dict], Union[Dict, None]]:
|
||||
# Detect how many responses we have already (from cache obj past_resp_obj)
|
||||
if past_resp_obj is not None:
|
||||
# How many *new* queries we need to send:
|
||||
# NOTE: The check n > len(past_resp_obj["responses"]) should occur prior to calling this function.
|
||||
n = n - len(past_resp_obj["responses"])
|
||||
|
||||
# Block asynchronously when we exceed rate limits
|
||||
if query_number is not None and rate_limit_batch_size is not None and rate_limit_wait_secs is not None and rate_limit_batch_size >= 1 and rate_limit_wait_secs > 0:
|
||||
batch_num = int(query_number / rate_limit_batch_size)
|
||||
if batch_num > 0:
|
||||
# We've exceeded the estimated batch rate limit and need to wait the appropriate seconds before sending off new API calls:
|
||||
wait_secs = rate_limit_wait_secs * batch_num
|
||||
if query_number % rate_limit_batch_size == 0: # Print when we start blocking, for each batch
|
||||
print(f"Batch rate limit of {rate_limit_batch_size} reached for LLM {llm}. Waiting {wait_secs} seconds until sending request batch #{batch_num}...")
|
||||
await asyncio.sleep(wait_secs)
|
||||
|
||||
# Get the correct API call for the given LLM:
|
||||
call_llm = None
|
||||
if llm.name[:6] == 'OpenAI':
|
||||
call_llm = call_chatgpt
|
||||
elif llm.name[:5] == 'Azure':
|
||||
call_llm = call_azure_openai
|
||||
elif llm.name[:5] == 'PaLM2':
|
||||
call_llm = call_google_palm
|
||||
elif llm.name[:5] == 'Dalai':
|
||||
call_llm = call_dalai
|
||||
elif llm.value[:6] == 'claude':
|
||||
call_llm = call_anthropic
|
||||
else:
|
||||
raise Exception(f"Language model {llm} is not supported.")
|
||||
|
||||
# Now try to call the API. If it fails for whatever reason, 'soft fail' by returning
|
||||
# an LLMResponseException object as the 'response'.
|
||||
try:
|
||||
query, response = await call_llm(prompt=str(prompt), model=llm, n=n, temperature=temperature, **llm_params)
|
||||
except Exception as e:
|
||||
return prompt, None, LLMResponseException(str(e)), None
|
||||
|
||||
return prompt, query, response, past_resp_obj
|
||||
|
||||
"""
|
||||
Most basic prompt pipeline: given a prompt (and any variables, if it's a template),
|
||||
query the LLM with all prompt permutations, and cache responses.
|
||||
"""
|
||||
class PromptLLM(PromptPipeline):
|
||||
def __init__(self, template: str, storageFile: str):
|
||||
self._template = PromptTemplate(template)
|
||||
super().__init__(storageFile)
|
||||
def gen_prompts(self, properties: dict) -> Iterator[PromptTemplate]:
|
||||
gen_prompts = PromptPermutationGenerator(self._template)
|
||||
return gen_prompts(properties)
|
||||
|
||||
|
||||
"""
|
||||
A dummy class that spoofs LLM responses. Used for testing.
|
||||
"""
|
||||
class PromptLLMDummy(PromptLLM):
|
||||
def __init__(self, template: str, storageFile: str):
|
||||
# Hijack the 'extract_responses' method so that for whichever 'llm' parameter,
|
||||
# it will just return the response verbatim (since dummy responses will always be strings)
|
||||
global extract_responses
|
||||
extract_responses = lambda response, llm: response
|
||||
super().__init__(template, storageFile)
|
||||
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0, past_resp_obj: Optional[Dict] = None, **params) -> Tuple[Dict, Dict]:
|
||||
# Wait a random amount of time, to simulate wait times from real queries
|
||||
await asyncio.sleep(random.uniform(0.1, 3))
|
||||
|
||||
if random.random() > 0.2:
|
||||
# Return a random string of characters of random length (within a predefined range)
|
||||
return prompt, {'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(random.randint(25, 80))) for _ in range(n)], past_resp_obj
|
||||
else:
|
||||
# Return a mock 'error' making the API request
|
||||
return prompt, None, LLMResponseException('Dummy error'), past_resp_obj
|
@ -1,257 +0,0 @@
|
||||
import re
|
||||
from string import Template
|
||||
from typing import Dict, List, Union
|
||||
|
||||
def escape_dollar_signs(s: str) -> str:
|
||||
pattern = r'\$(?![{])'
|
||||
replaced_string = re.sub(pattern, '$$', s)
|
||||
return replaced_string
|
||||
|
||||
class PromptTemplate:
|
||||
"""
|
||||
Wrapper around string.Template. Use to generate prompts fast.
|
||||
|
||||
Example usage:
|
||||
prompt_temp = PromptTemplate('Can you list all the cities in the country ${country} by the cheapest ${domain} prices?')
|
||||
concrete_prompt = prompt_temp.fill({
|
||||
"country": "France",
|
||||
"domain": "rent"
|
||||
});
|
||||
print(concrete_prompt)
|
||||
|
||||
# Fill can also fill the prompt only partially, which gives us a new prompt template:
|
||||
partial_prompt = prompt_temp.fill({
|
||||
"domain": "rent"
|
||||
});
|
||||
print(partial_prompt)
|
||||
"""
|
||||
def __init__(self, templateStr):
|
||||
"""
|
||||
Initialize a PromptTemplate with a string in string.Template format.
|
||||
(See https://docs.python.org/3/library/string.html#template-strings for more details.)
|
||||
"""
|
||||
# NOTE: ChainForge only supports placeholders with braces {}
|
||||
# We detect any $ without { to the right of them, and insert a '$' before it to escape the $.
|
||||
templateStr = escape_dollar_signs(templateStr)
|
||||
try:
|
||||
Template(templateStr)
|
||||
except Exception:
|
||||
raise Exception("Invalid template formatting for string:", templateStr)
|
||||
self.template = templateStr
|
||||
self.fill_history = {}
|
||||
self.metavars = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.template
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
def has_var(self, varname) -> bool:
|
||||
""" Returns True if the template has a variable with the given name.
|
||||
"""
|
||||
subbed_str = Template(self.template).safe_substitute({varname: '_'})
|
||||
return subbed_str != self.template # if the strings differ, a replacement occurred
|
||||
|
||||
def is_concrete(self) -> bool:
|
||||
""" Returns True if no template variables are left in template string.
|
||||
"""
|
||||
try:
|
||||
Template(self.template).substitute({})
|
||||
return True # no exception raised means there was nothing to substitute...
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def fill(self, paramDict: Dict[str, Union[str, Dict[str, str]]]) -> 'PromptTemplate':
|
||||
"""
|
||||
Formats the template string with the given parameters, returning a new PromptTemplate.
|
||||
Can return a partial completion.
|
||||
|
||||
NOTE: paramDict values can be in a special form: {text: <str>, fill_history: {varname: <str>}}
|
||||
in order to bundle in any past fill history that is lost in the current text.
|
||||
|
||||
Example usage:
|
||||
prompt = prompt_template.fill({
|
||||
"className": className,
|
||||
"library": "Kivy",
|
||||
"PL": "Python"
|
||||
});
|
||||
"""
|
||||
# Check for special 'past fill history' format:
|
||||
past_fill_history = {}
|
||||
past_metavars = {}
|
||||
if len(paramDict) > 0 and isinstance(next(iter(paramDict.values())), dict):
|
||||
for obj in paramDict.values():
|
||||
if "fill_history" in obj:
|
||||
past_fill_history = {**obj['fill_history'], **past_fill_history}
|
||||
if "metavars" in obj:
|
||||
past_metavars = {**obj['metavars'], **past_metavars}
|
||||
paramDict = {param: obj['text'] for param, obj in paramDict.items()}
|
||||
|
||||
filled_pt = PromptTemplate(
|
||||
Template(self.template).safe_substitute(paramDict)
|
||||
)
|
||||
|
||||
# Deep copy prior fill history of this PromptTemplate from this version over to new one
|
||||
filled_pt.fill_history = { key: val for (key, val) in self.fill_history.items() }
|
||||
filled_pt.metavars = { key: val for (key, val) in self.metavars.items() }
|
||||
|
||||
# Append any past history passed as vars:
|
||||
for key, val in past_fill_history.items():
|
||||
if key in filled_pt.fill_history:
|
||||
print(f"Warning: PromptTemplate already has fill history for key {key}.")
|
||||
filled_pt.fill_history[key] = val
|
||||
|
||||
# Append any metavars, overwriting existing ones with the same key
|
||||
for key, val in past_metavars.items():
|
||||
filled_pt.metavars[key] = val
|
||||
|
||||
# Add the new fill history using the passed parameters that we just filled in
|
||||
for key, val in paramDict.items():
|
||||
if key in filled_pt.fill_history:
|
||||
print(f"Warning: PromptTemplate already has fill history for key {key}.")
|
||||
filled_pt.fill_history[key] = val
|
||||
|
||||
return filled_pt
|
||||
|
||||
|
||||
class PromptPermutationGenerator:
|
||||
"""
|
||||
Given a PromptTemplate and a parameter dict that includes arrays of items,
|
||||
generate all the permutations of the prompt for all permutations of the items.
|
||||
|
||||
NOTE: Items can be in a special form: {text: <str>, fill_history: {varname: <str>}}
|
||||
in order to bundle in any past fill history that is lost in the current text.
|
||||
|
||||
Example usage:
|
||||
prompt_gen = PromptPermutationGenerator('Can you list all the cities in the country ${country} by the cheapest ${domain} prices?')
|
||||
for prompt in prompt_gen({"country":["Canada", "South Africa", "China"],
|
||||
"domain": ["rent", "food", "energy"]}):
|
||||
print(prompt)
|
||||
"""
|
||||
def __init__(self, template: Union[PromptTemplate, str]):
|
||||
if isinstance(template, str):
|
||||
template = PromptTemplate(template)
|
||||
self.template = template
|
||||
|
||||
def _gen_perm(self, template, params_to_fill, paramDict):
|
||||
if len(params_to_fill) == 0: return []
|
||||
|
||||
# Extract the first param that occurs in the current template
|
||||
param = None
|
||||
params_left = params_to_fill
|
||||
for p in params_to_fill:
|
||||
if template.has_var(p):
|
||||
param = p
|
||||
params_left = [_p for _p in params_to_fill if _p != p]
|
||||
break
|
||||
|
||||
if param is None:
|
||||
return [template]
|
||||
|
||||
# Generate new prompts by filling in its value(s) into the PromptTemplate
|
||||
val = paramDict[param]
|
||||
if isinstance(val, list):
|
||||
new_prompt_temps = []
|
||||
for v in val:
|
||||
param_fill_dict = {param: v}
|
||||
|
||||
# If this var has an "associate_id", then it wants to "carry with"
|
||||
# values of other prompt parameters with the same id.
|
||||
# We have to find any parameters with values of the same id,
|
||||
# and fill them in alongside the initial parameter v:
|
||||
if isinstance(v, dict) and "associate_id" in v:
|
||||
v_associate_id = v["associate_id"]
|
||||
for other_param in params_left[:]:
|
||||
if template.has_var(other_param) and isinstance(paramDict[other_param], list):
|
||||
other_vals = paramDict[other_param]
|
||||
for ov in other_vals:
|
||||
if isinstance(ov, dict) and ov.get("associate_id") == v_associate_id:
|
||||
# This is a match. We should add the val to our param_fill_dict:
|
||||
param_fill_dict[other_param] = ov
|
||||
break
|
||||
|
||||
# Fill the template with the param values and append it to the list
|
||||
new_prompt_temps.append(template.fill(param_fill_dict))
|
||||
elif isinstance(val, str):
|
||||
new_prompt_temps = [template.fill({param: val})]
|
||||
else:
|
||||
raise ValueError("Value of prompt template parameter is not a list or a string, but of type " + str(type(val)))
|
||||
|
||||
# Recurse
|
||||
if len(params_left) == 0:
|
||||
return new_prompt_temps
|
||||
else:
|
||||
res = []
|
||||
for p in new_prompt_temps:
|
||||
res.extend(self._gen_perm(p, params_left, paramDict))
|
||||
return res
|
||||
|
||||
def __call__(self, paramDict: Dict[str, Union[str, List[str], Dict[str, str]]]):
|
||||
if len(paramDict) == 0:
|
||||
yield self.template
|
||||
return
|
||||
|
||||
for p in self._gen_perm(self.template, list(paramDict.keys()), paramDict):
|
||||
yield p
|
||||
|
||||
# Test cases
|
||||
if __name__ == '__main__':
|
||||
# Dollar sign escape works
|
||||
tests = ["What is $2 + $2?", "If I have $4 and I want ${dollars} then how many do I have?", "$4 is equal to ${dollars}?", "${what} is the $400?"]
|
||||
escaped_tests = [escape_dollar_signs(t) for t in tests]
|
||||
print(escaped_tests)
|
||||
assert escaped_tests[0] == "What is $$2 + $$2?"
|
||||
assert escaped_tests[1] == "If I have $$4 and I want ${dollars} then how many do I have?"
|
||||
assert escaped_tests[2] == "$$4 is equal to ${dollars}?"
|
||||
assert escaped_tests[3] == "${what} is the $$400?"
|
||||
|
||||
# Single template
|
||||
gen = PromptPermutationGenerator('What is the ${timeframe} when ${person} was born?')
|
||||
res = [r for r in gen({'timeframe': ['year', 'decade', 'century'], 'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding']})]
|
||||
for r in res:
|
||||
print(r)
|
||||
assert len(res) == 9
|
||||
|
||||
# Nested templates
|
||||
gen = PromptPermutationGenerator('${prefix}... ${suffix}')
|
||||
res = [r for r in gen({
|
||||
'prefix': ['Who invented ${tool}?', 'When was ${tool} invented?', 'What can you do with ${tool}?'],
|
||||
'suffix': ['Phrase your answer in the form of a ${response_type}', 'Respond with a ${response_type}'],
|
||||
'tool': ['the flashlight', 'CRISPR', 'rubber'],
|
||||
'response_type': ['question', 'poem', 'nightmare']
|
||||
})]
|
||||
for r in res:
|
||||
print(r)
|
||||
assert len(res) == (3*3)*(2*3)
|
||||
|
||||
# 'Carry together' vars with 'metavar' data attached
|
||||
# NOTE: This feature may be used when passing rows of a table, so that vars that have associated values,
|
||||
# like 'inventor' with 'tool', 'carry together' when being filled into the prompt template.
|
||||
# In addition, 'metavars' may be attached which are, commonly, the values of other columns for that row, but
|
||||
# columns which weren't used to fill in the prompt template explcitly.
|
||||
gen = PromptPermutationGenerator('What ${timeframe} did ${inventor} invent the ${tool}?')
|
||||
res = [r for r in gen({
|
||||
'inventor': [
|
||||
{'text': "Thomas Edison", "fill_history": {}, "associate_id": "A", "metavars": { "year": 1879 }},
|
||||
{'text': "Alexander Fleming", "fill_history": {}, "associate_id": "B", "metavars": { "year": 1928 }},
|
||||
{'text': "William Shockley", "fill_history": {}, "associate_id": "C", "metavars": { "year": 1947 }},
|
||||
],
|
||||
'tool': [
|
||||
{'text': "lightbulb", "fill_history": {}, "associate_id": "A"},
|
||||
{'text': "penicillin", "fill_history": {}, "associate_id": "B"},
|
||||
{'text': "transistor", "fill_history": {}, "associate_id": "C"},
|
||||
],
|
||||
'timeframe': [ "year", "decade", "century" ]
|
||||
})]
|
||||
for r in res:
|
||||
r_str = str(r)
|
||||
print(r_str, r.metavars)
|
||||
assert "year" in r.metavars
|
||||
if "Edison" in r_str:
|
||||
assert "lightbulb" in r_str
|
||||
elif "Fleming" in r_str:
|
||||
assert "penicillin" in r_str
|
||||
elif "Shockley" in r_str:
|
||||
assert "transistor" in r_str
|
||||
assert len(res) == 3*3
|
@ -1,505 +0,0 @@
|
||||
from typing import Dict, Tuple, List, Union, Optional
|
||||
import json, os, time, asyncio
|
||||
from string import Template
|
||||
|
||||
from chainforge.promptengine.models import LLM
|
||||
|
||||
DALAI_MODEL = None
|
||||
DALAI_RESPONSE = None
|
||||
|
||||
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")
|
||||
GOOGLE_PALM_API_KEY = os.environ.get("PALM_API_KEY")
|
||||
AZURE_OPENAI_KEY = os.environ.get("AZURE_OPENAI_KEY")
|
||||
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
|
||||
def set_api_keys(api_keys):
|
||||
"""
|
||||
Sets the local API keys for the revelant LLM API(s).
|
||||
Currently only supports 'OpenAI', 'Anthropic'.
|
||||
"""
|
||||
global ANTHROPIC_API_KEY, GOOGLE_PALM_API_KEY, AZURE_OPENAI_KEY, AZURE_OPENAI_ENDPOINT
|
||||
def key_is_present(name):
|
||||
return name in api_keys and len(api_keys[name].strip()) > 0
|
||||
if key_is_present('OpenAI'):
|
||||
import openai
|
||||
openai.api_key = api_keys['OpenAI']
|
||||
if key_is_present('Anthropic'):
|
||||
ANTHROPIC_API_KEY = api_keys['Anthropic']
|
||||
if key_is_present('Google'):
|
||||
GOOGLE_PALM_API_KEY = api_keys['Google']
|
||||
if key_is_present('Azure_OpenAI'):
|
||||
AZURE_OPENAI_KEY = api_keys['Azure_OpenAI']
|
||||
if key_is_present('Azure_OpenAI_Endpoint'):
|
||||
AZURE_OPENAI_ENDPOINT = api_keys['Azure_OpenAI_Endpoint']
|
||||
# Soft fail for non-present keys
|
||||
|
||||
async def make_sync_call_async(sync_method, *args, **params):
|
||||
"""
|
||||
Makes a blocking synchronous call asynchronous, so that it can be awaited.
|
||||
NOTE: This is necessary for LLM APIs that do not yet support async (e.g. Google PaLM).
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
method = sync_method
|
||||
if len(params) > 0:
|
||||
def partial_sync_meth(*a):
|
||||
return sync_method(*a, **params)
|
||||
method = partial_sync_meth
|
||||
return await loop.run_in_executor(None, method, *args)
|
||||
|
||||
async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float= 1.0,
|
||||
system_msg: Optional[str]=None,
|
||||
**params) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Calls GPT3.5 via OpenAI's API.
|
||||
Returns raw query and response JSON dicts.
|
||||
|
||||
NOTE: It is recommended to set an environment variable OPENAI_API_KEY with your OpenAI API key
|
||||
"""
|
||||
import openai
|
||||
if not openai.api_key:
|
||||
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
model = model.value
|
||||
if 'stop' in params and (not isinstance(params['stop'], list) or len(params['stop']) == 0):
|
||||
del params['stop']
|
||||
if 'functions' in params and (not isinstance(params['functions'], list) or len(params['functions']) == 0):
|
||||
del params['functions']
|
||||
if 'function_call' in params and (not isinstance(params['function_call'], str) or len(params['function_call'].strip()) == 0):
|
||||
del params['function_call']
|
||||
|
||||
print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...")
|
||||
system_msg = "You are a helpful assistant." if system_msg is None else system_msg
|
||||
|
||||
query = {
|
||||
"model": model,
|
||||
"n": n,
|
||||
"temperature": temperature,
|
||||
**params, # 'the rest' of the settings, passed from a front-end app
|
||||
}
|
||||
|
||||
if 'davinci' in model: # text completions model
|
||||
openai_call = openai.Completion.acreate
|
||||
query['prompt'] = prompt
|
||||
else: # chat model
|
||||
openai_call = openai.ChatCompletion.acreate
|
||||
query['messages'] = [
|
||||
{"role": "system", "content": system_msg},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
try:
|
||||
response = await openai_call(**query)
|
||||
except Exception as e:
|
||||
if (isinstance(e, openai.error.AuthenticationError)):
|
||||
raise Exception("Could not authenticate to OpenAI. Double-check that your API key is set in Settings or in your local Python environment.")
|
||||
raise e
|
||||
|
||||
return query, response
|
||||
|
||||
async def call_azure_openai(prompt: str, model: LLM, n: int = 1, temperature: float= 1.0,
|
||||
deployment_name: str = 'gpt-35-turbo',
|
||||
model_type: str = "chat-completion",
|
||||
api_version: str = "2023-05-15",
|
||||
system_msg: Optional[str]=None,
|
||||
**params) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Calls an OpenAI chat model GPT3.5 or GPT4 via Microsoft Azure services.
|
||||
Returns raw query and response JSON dicts.
|
||||
|
||||
NOTE: It is recommended to set an environment variables AZURE_OPENAI_KEY and AZURE_OPENAI_ENDPOINT
|
||||
"""
|
||||
global AZURE_OPENAI_KEY, AZURE_OPENAI_ENDPOINT
|
||||
if AZURE_OPENAI_KEY is None:
|
||||
raise Exception("Could not find an Azure OpenAPI Key to use. Double-check that your key is set in Settings or in your local Python environment.")
|
||||
if AZURE_OPENAI_ENDPOINT is None:
|
||||
raise Exception("Could not find an Azure OpenAI Endpoint to use. Double-check that your endpoint is set in Settings or in your local Python environment.")
|
||||
|
||||
import openai
|
||||
openai.api_type = "azure"
|
||||
openai.api_version = api_version
|
||||
openai.api_key = AZURE_OPENAI_KEY
|
||||
openai.api_base = AZURE_OPENAI_ENDPOINT
|
||||
|
||||
if 'stop' in params and not isinstance(params['stop'], list) or len(params['stop']) == 0:
|
||||
del params['stop']
|
||||
if 'functions' in params and not isinstance(params['functions'], list) or len(params['functions']) == 0:
|
||||
del params['functions']
|
||||
if 'function_call' in params and not isinstance(params['function_call'], str) or len(params['function_call'].strip()) == 0:
|
||||
del params['function_call']
|
||||
|
||||
print(f"Querying Azure OpenAI deployed model '{deployment_name}' at endpoint '{AZURE_OPENAI_ENDPOINT}' with prompt '{prompt}'...")
|
||||
system_msg = "You are a helpful assistant." if system_msg is None else system_msg
|
||||
|
||||
query = {
|
||||
"engine": deployment_name, # this differs from a basic OpenAI call
|
||||
"n": n,
|
||||
"temperature": temperature,
|
||||
**params, # 'the rest' of the settings, passed from a front-end app
|
||||
}
|
||||
|
||||
if model_type == 'text-completion':
|
||||
openai_call = openai.Completion.acreate
|
||||
query['prompt'] = prompt
|
||||
else:
|
||||
openai_call = openai.ChatCompletion.acreate
|
||||
query['messages'] = [
|
||||
{"role": "system", "content": system_msg},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
try:
|
||||
response = await openai_call(**query)
|
||||
except Exception as e:
|
||||
if (isinstance(e, openai.error.AuthenticationError)):
|
||||
raise Exception("Could not authenticate to OpenAI. Double-check that your API key is set in Settings or in your local Python environment.")
|
||||
raise e
|
||||
|
||||
return query, response
|
||||
|
||||
async def call_anthropic(prompt: str, model: LLM, n: int = 1, temperature: float= 1.0,
|
||||
max_tokens_to_sample=1024,
|
||||
async_mode=False,
|
||||
custom_prompt_wrapper: Optional[str]=None,
|
||||
stop_sequences: Optional[List[str]]=["\n\nHuman:"],
|
||||
**params) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Calls Anthropic API with the given model, passing in params.
|
||||
Returns raw query and response JSON dicts.
|
||||
|
||||
Unique parameters:
|
||||
- custom_prompt_wrapper: Anthropic models expect prompts in form "\n\nHuman: ${prompt}\n\nAssistant". If you wish to
|
||||
explore custom prompt wrappers that deviate, write a python Template that maps from 'prompt' to custom wrapper.
|
||||
If set to None, defaults to Anthropic's suggested prompt wrapper.
|
||||
- max_tokens_to_sample: A maximum number of tokens to generate before stopping.
|
||||
- stop_sequences: A list of strings upon which to stop generating. Defaults to ["\n\nHuman:"], the cue for the next turn in the dialog agent.
|
||||
- async_mode: Evaluation access to Claude limits calls to 1 at a time, meaning we can't take advantage of async.
|
||||
If you want to send all 'n' requests at once, you can set async_mode to True.
|
||||
|
||||
NOTE: It is recommended to set an environment variable ANTHROPIC_API_KEY with your Anthropic API key
|
||||
"""
|
||||
if ANTHROPIC_API_KEY is None:
|
||||
raise Exception("Could not find an API key for Anthropic models. Double-check that your API key is set in Settings or in your local Python environment.")
|
||||
|
||||
import anthropic
|
||||
client = anthropic.Client(ANTHROPIC_API_KEY)
|
||||
|
||||
# Wrap the prompt in the provided template, or use the default Anthropic one
|
||||
if custom_prompt_wrapper is None or '${prompt}' not in custom_prompt_wrapper:
|
||||
custom_prompt_wrapper = anthropic.HUMAN_PROMPT + " ${prompt}" + anthropic.AI_PROMPT
|
||||
prompt_wrapper_template = Template(custom_prompt_wrapper)
|
||||
wrapped_prompt = prompt_wrapper_template.substitute(prompt=prompt)
|
||||
|
||||
# Format query
|
||||
query = {
|
||||
'model': model.value,
|
||||
'prompt': wrapped_prompt,
|
||||
'max_tokens_to_sample': max_tokens_to_sample,
|
||||
'stop_sequences': stop_sequences,
|
||||
'temperature': temperature,
|
||||
**params
|
||||
}
|
||||
|
||||
print(f"Calling Anthropic model '{model.value}' with prompt '{prompt}' (n={n}). Please be patient...")
|
||||
|
||||
# Request responses using the passed async_mode
|
||||
responses = []
|
||||
if async_mode:
|
||||
# Gather n responses by firing off all API requests at once
|
||||
tasks = [client.acompletion(**query) for _ in range(n)]
|
||||
responses = await asyncio.gather(*tasks)
|
||||
else:
|
||||
# Repeat call n times, waiting for each response to come in:
|
||||
while len(responses) < n:
|
||||
resp = await client.acompletion(**query)
|
||||
responses.append(resp)
|
||||
print(f'{model.value} response {len(responses)} of {n}:\n', resp)
|
||||
|
||||
return query, responses
|
||||
|
||||
async def call_google_palm(prompt: str, model: LLM, n: int = 1, temperature: float= 0.7,
|
||||
max_output_tokens=800,
|
||||
async_mode=False,
|
||||
**params) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Calls a Google PaLM model.
|
||||
Returns raw query and response JSON dicts.
|
||||
"""
|
||||
if GOOGLE_PALM_API_KEY is None:
|
||||
raise Exception("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local Python environment.")
|
||||
|
||||
import google.generativeai as palm
|
||||
palm.configure(api_key=GOOGLE_PALM_API_KEY)
|
||||
|
||||
is_chat_model = 'chat' in model.value
|
||||
|
||||
query = {
|
||||
'model': f"models/{model.value}",
|
||||
'prompt': prompt,
|
||||
'candidate_count': n,
|
||||
'temperature': temperature,
|
||||
'max_output_tokens': max_output_tokens,
|
||||
**params,
|
||||
}
|
||||
|
||||
# Remove erroneous parameters for text and chat models
|
||||
if 'top_k' in query and query['top_k'] <= 0:
|
||||
del query['top_k']
|
||||
if 'top_p' in query and query['top_p'] <= 0:
|
||||
del query['top_p']
|
||||
if is_chat_model and 'max_output_tokens' in query:
|
||||
del query['max_output_tokens']
|
||||
if is_chat_model and 'stop_sequences' in query:
|
||||
del query['stop_sequences']
|
||||
|
||||
# Get the correct model's completions call
|
||||
palm_call = palm.chat if is_chat_model else palm.generate_text
|
||||
|
||||
# Google PaLM's python API does not currently support async calls.
|
||||
# To make one, we need to wrap it in an asynchronous executor:
|
||||
completion = await make_sync_call_async(palm_call, **query)
|
||||
completion_dict = completion.to_dict()
|
||||
|
||||
# Google PaLM, unlike other chat models, will output empty
|
||||
# responses for any response it deems unsafe (blocks). Although the text completions
|
||||
# API has a (relatively undocumented) 'safety_settings' parameter,
|
||||
# the current chat completions API provides users no control over the blocking.
|
||||
# We need to detect this and fill the response with the safety reasoning:
|
||||
if len(completion.filters) > 0:
|
||||
# Request was blocked. Output why in the response text,
|
||||
# repairing the candidate dict to mock up 'n' responses
|
||||
block_error_msg = f'[[BLOCKED_REQUEST]] Request was blocked because it triggered safety filters: {str(completion.filters)}'
|
||||
completion_dict['candidates'] = [{'author': 1, 'content':block_error_msg}] * n
|
||||
|
||||
# Weirdly, google ignores candidate_count if temperature is 0.
|
||||
# We have to check for this and manually append the n-1 responses:
|
||||
if n > 1 and temperature == 0 and len(completion_dict['candidates']) == 1:
|
||||
copied_candidates = [completion_dict['candidates'][0]] * n
|
||||
completion_dict['candidates'] = copied_candidates
|
||||
|
||||
return query, completion_dict
|
||||
|
||||
async def call_dalai(prompt: str, model: LLM, server: str="http://localhost:4000", n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Calls a Dalai server running LLMs Alpaca, Llama, etc locally.
|
||||
Returns the raw query and response JSON dicts.
|
||||
|
||||
Parameters:
|
||||
- model: The LLM model, whose value is the name known by Dalai; e.g. 'alpaca.7b'
|
||||
- server: The URL of the local server where Dalai is running. Defaults to http://localhost:4000.
|
||||
- prompt: The prompt to pass to the LLM.
|
||||
- n: How many times to query. If n > 1, this will continue to query the LLM 'n' times and collect all responses.
|
||||
- temperature: The temperature to query at
|
||||
- params: Any other Dalai-specific params to pass. For more info, see below or https://cocktailpeanut.github.io/dalai/#/?id=syntax-1
|
||||
|
||||
TODO: Currently, this uses a modified dalaipy library for simplicity; however, in the future we might remove this dependency.
|
||||
"""
|
||||
# Import and load upon first run
|
||||
global DALAI_MODEL, DALAI_RESPONSE
|
||||
if not server or len(server.strip()) == 0: # In case user passed a blank server name, revert to default on port 4000
|
||||
server = "http://localhost:4000"
|
||||
if DALAI_MODEL is None:
|
||||
from chainforge.promptengine.dalaipy import Dalai
|
||||
DALAI_MODEL = Dalai(server)
|
||||
elif DALAI_MODEL.server != server: # if the port has changed, we need to create a new model
|
||||
DALAI_MODEL = Dalai(server)
|
||||
|
||||
# Make sure server is connected
|
||||
DALAI_MODEL.connect()
|
||||
|
||||
# Create settings dict to pass to Dalai as args
|
||||
def_params = {'n_predict':128, 'repeat_last_n':64, 'repeat_penalty':1.3, 'seed':-1, 'threads':4, 'top_k':40, 'top_p':0.9}
|
||||
for key in params:
|
||||
if key in def_params:
|
||||
def_params[key] = params[key]
|
||||
else:
|
||||
print(f"Attempted to pass unsupported param '{key}' to Dalai. Ignoring.")
|
||||
|
||||
# Create full query to Dalai
|
||||
query = {
|
||||
'prompt': prompt,
|
||||
'model': model.value,
|
||||
'id': str(round(time.time()*1000)),
|
||||
'temp': temperature,
|
||||
**def_params
|
||||
}
|
||||
|
||||
# Create spot to put response and a callback that sets it
|
||||
DALAI_RESPONSE = None
|
||||
def on_finish(r):
|
||||
global DALAI_RESPONSE
|
||||
DALAI_RESPONSE = r
|
||||
|
||||
print(f"Calling Dalai model '{query['model']}' with prompt '{query['prompt']}' (n={n}). Please be patient...")
|
||||
|
||||
# Repeat call n times
|
||||
responses = []
|
||||
while len(responses) < n:
|
||||
|
||||
# Call the Dalai model
|
||||
req = DALAI_MODEL.generate_request(**query)
|
||||
sent_req_success = DALAI_MODEL.generate(req, on_finish=on_finish)
|
||||
|
||||
if not sent_req_success:
|
||||
print("Something went wrong pinging the Dalai server. Returning None.")
|
||||
return None, None
|
||||
|
||||
# Blocking --wait for request to complete:
|
||||
while DALAI_RESPONSE is None:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
response = DALAI_RESPONSE['response']
|
||||
if response[-5:] == '<end>': # strip ending <end> tag, if present
|
||||
response = response[:-5]
|
||||
if response.find('\r\n') > -1: # strip off the prompt, which is included in the result up to \r\n:
|
||||
response = response[(response.index('\r\n')+2):]
|
||||
DALAI_RESPONSE = None
|
||||
|
||||
responses.append(response)
|
||||
print(f'Response {len(responses)} of {n}:\n', response)
|
||||
|
||||
# Disconnect from the server
|
||||
DALAI_MODEL.disconnect()
|
||||
|
||||
return query, responses
|
||||
|
||||
def _extract_openai_chat_choice_content(choice: dict) -> str:
|
||||
"""
|
||||
Extracts the relevant portion of an OpenAI chat response.
|
||||
|
||||
Note that chat choice objects can now include 'function_call' and a blank 'content' response.
|
||||
This method detects a 'function_call's presence, prepends [[FUNCTION]] and converts the function call into Python format.
|
||||
"""
|
||||
if choice['finish_reason'] == 'function_call' or choice["message"]["content"] is None or \
|
||||
('function_call' in choice['message'] and len(choice['message']['function_call']) > 0):
|
||||
func = choice['message']['function_call']
|
||||
return '[[FUNCTION]] ' + func['name'] + str(func['arguments'])
|
||||
else:
|
||||
return choice["message"]["content"]
|
||||
|
||||
def _extract_chatgpt_responses(response: dict) -> List[str]:
|
||||
"""
|
||||
Extracts the text part of a response JSON from ChatGPT. If there is more
|
||||
than 1 response (e.g., asking the LLM to generate multiple responses),
|
||||
this produces a list of all returned responses.
|
||||
"""
|
||||
choices = response["choices"]
|
||||
return [
|
||||
_extract_openai_chat_choice_content(c)
|
||||
for c in choices
|
||||
]
|
||||
|
||||
def _extract_openai_completion_responses(response: dict) -> List[str]:
|
||||
"""
|
||||
Extracts the text part of a response JSON from OpenAI completion models like Davinci. If there is more
|
||||
than 1 response (e.g., asking the LLM to generate multiple responses),
|
||||
this produces a list of all returned responses.
|
||||
"""
|
||||
choices = response["choices"]
|
||||
return [
|
||||
c["text"].strip()
|
||||
for c in choices
|
||||
]
|
||||
|
||||
def _extract_openai_responses(response: dict) -> List[str]:
|
||||
"""
|
||||
Deduces the format of an OpenAI model response (completion or chat)
|
||||
and extracts the response text using the appropriate method.
|
||||
"""
|
||||
if len(response["choices"]) == 0: return []
|
||||
first_choice = response["choices"][0]
|
||||
if "message" in first_choice:
|
||||
return _extract_chatgpt_responses(response)
|
||||
else:
|
||||
return _extract_openai_completion_responses(response)
|
||||
|
||||
def _extract_palm_responses(completion) -> List[str]:
|
||||
"""
|
||||
Extracts the text part of a 'Completion' object from Google PaLM2 `generate_text` or `chat`
|
||||
|
||||
NOTE: The candidate object for `generate_text` has a key 'output' which contains the response,
|
||||
while the `chat` API uses a key 'content'. This checks for either.
|
||||
"""
|
||||
return [
|
||||
c['output'] if 'output' in c else c['content']
|
||||
for c in completion['candidates']
|
||||
]
|
||||
|
||||
def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List[str]:
|
||||
"""
|
||||
Given a LLM and a response object from its API, extract the
|
||||
text response(s) part of the response object.
|
||||
"""
|
||||
llm_str = llm.name if isinstance(llm, LLM) else llm
|
||||
if llm_str[:6] == 'OpenAI':
|
||||
if 'davinci' in llm_str.lower():
|
||||
return _extract_openai_completion_responses(response)
|
||||
else:
|
||||
return _extract_chatgpt_responses(response)
|
||||
elif llm_str[:5] == 'Azure':
|
||||
return _extract_openai_responses(response)
|
||||
elif llm_str[:5] == 'PaLM2':
|
||||
return _extract_palm_responses(response)
|
||||
elif llm_str[:5] == 'Dalai':
|
||||
return response
|
||||
elif llm_str[:6] == 'Claude':
|
||||
return [r["completion"] for r in response]
|
||||
else:
|
||||
raise ValueError(f"LLM {llm_str} is unsupported.")
|
||||
|
||||
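As a quick illustration of this dispatcher, here is a hedged sketch using a made-up, truncated ChatGPT-style response dict (the dict and the model string below are illustrative only, not real API output):

```python
# Hypothetical, truncated chat-completion response for illustration:
fake_chat_response = {
    "choices": [
        {"finish_reason": "stop", "message": {"role": "assistant", "content": "Hello!"}}
    ]
}
# The 'OpenAI' prefix routes this to _extract_chatgpt_responses:
print(extract_responses(fake_chat_response, "OpenAI_ChatGPT"))  # -> ['Hello!']
```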
def merge_response_objs(resp_obj_A: Union[dict, None], resp_obj_B: Union[dict, None]) -> dict:
|
||||
if resp_obj_B is None:
|
||||
return resp_obj_A
|
||||
elif resp_obj_A is None:
|
||||
return resp_obj_B
|
||||
raw_resp_A = resp_obj_A["raw_response"]
|
||||
raw_resp_B = resp_obj_B["raw_response"]
|
||||
if not isinstance(raw_resp_A, list):
|
||||
raw_resp_A = [ raw_resp_A ]
|
||||
if not isinstance(raw_resp_B, list):
|
||||
raw_resp_B = [ raw_resp_B ]
|
||||
C = {
|
||||
"responses": resp_obj_A["responses"] + resp_obj_B["responses"],
|
||||
"raw_response": raw_resp_A + raw_resp_B,
|
||||
}
|
||||
return {
|
||||
**C,
|
||||
"prompt": resp_obj_B['prompt'],
|
||||
"query": resp_obj_B['query'],
|
||||
"llm": resp_obj_B['llm'],
|
||||
"info": resp_obj_B['info'],
|
||||
"metavars": resp_obj_B['metavars'],
|
||||
}
|
||||
|
||||
def create_dir_if_not_exists(path: str) -> None:
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
def is_valid_filepath(filepath: str) -> bool:
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8'):
|
||||
pass
|
||||
except IOError:
|
||||
try:
|
||||
# Create the file if it doesn't exist, and write an empty json string to it
|
||||
with open(filepath, 'w+', encoding='utf-8') as f:
|
||||
f.write("{}")
|
||||
pass
|
||||
except IOError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_valid_json(json_dict: dict) -> bool:
|
||||
if isinstance(json_dict, dict):
|
||||
try:
|
||||
json.dumps(json_dict)
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def get_files_at_dir(path: str) -> list:
|
||||
f = []
|
||||
for (dirpath, dirnames, filenames) in os.walk(path):
|
||||
f = filenames
|
||||
break
|
||||
return f
|
2
chainforge/providers/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
__all__ = ['CustomProviderProtocol', 'provider', 'ProviderRegistry']
|
||||
from .protocol import CustomProviderProtocol, provider, ProviderRegistry
|
89
chainforge/providers/dalai.py
Normal file
@ -0,0 +1,89 @@
|
||||
from typing import Tuple, Dict
|
||||
import asyncio, time
|
||||
|
||||
DALAI_MODEL = None
|
||||
DALAI_RESPONSE = None
|
||||
|
||||
async def call_dalai(prompt: str, model: str, server: str="http://localhost:4000", n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Calls a Dalai server running LLMs Alpaca, Llama, etc locally.
|
||||
Returns the raw query and response JSON dicts.
|
||||
|
||||
Parameters:
|
||||
- model: The LLM model, whose value is the name known by Dalai; e.g. 'alpaca.7b'
|
||||
- server: The URL of the local server where Dalai is running. Defaults to http://localhost:4000.
|
||||
- prompt: The prompt to pass to the LLM.
|
||||
- n: How many times to query. If n > 1, this will continue to query the LLM 'n' times and collect all responses.
|
||||
- temperature: The temperature to query at
|
||||
- params: Any other Dalai-specific params to pass. For more info, see below or https://cocktailpeanut.github.io/dalai/#/?id=syntax-1
|
||||
|
||||
TODO: Currently, this uses a modified dalaipy library for simplicity; however, in the future we might remove this dependency.
|
||||
"""
|
||||
# Import and load upon first run
|
||||
global DALAI_MODEL, DALAI_RESPONSE
|
||||
if not server or len(server.strip()) == 0: # In case user passed a blank server name, revert to default on port 4000
|
||||
server = "http://localhost:4000"
|
||||
if DALAI_MODEL is None:
|
||||
from chainforge.providers.dalaipy import Dalai
|
||||
DALAI_MODEL = Dalai(server)
|
||||
elif DALAI_MODEL.server != server: # if the port has changed, we need to create a new model
|
||||
DALAI_MODEL = Dalai(server)
|
||||
|
||||
# Make sure server is connected
|
||||
DALAI_MODEL.connect()
|
||||
|
||||
# Create settings dict to pass to Dalai as args
|
||||
def_params = {'n_predict':128, 'repeat_last_n':64, 'repeat_penalty':1.3, 'seed':-1, 'threads':4, 'top_k':40, 'top_p':0.9}
|
||||
for key in params:
|
||||
if key in def_params:
|
||||
def_params[key] = params[key]
|
||||
else:
|
||||
print(f"Attempted to pass unsupported param '{key}' to Dalai. Ignoring.")
|
||||
|
||||
# Create full query to Dalai
|
||||
query = {
|
||||
'prompt': prompt,
|
||||
'model': model,
|
||||
'id': str(round(time.time()*1000)),
|
||||
'temp': temperature,
|
||||
**def_params
|
||||
}
|
||||
|
||||
# Create spot to put response and a callback that sets it
|
||||
DALAI_RESPONSE = None
|
||||
def on_finish(r):
|
||||
global DALAI_RESPONSE
|
||||
DALAI_RESPONSE = r
|
||||
|
||||
print(f"Calling Dalai model '{query['model']}' with prompt '{query['prompt']}' (n={n}). Please be patient...")
|
||||
|
||||
# Repeat call n times
|
||||
responses = []
|
||||
while len(responses) < n:
|
||||
|
||||
# Call the Dalai model
|
||||
req = DALAI_MODEL.generate_request(**query)
|
||||
sent_req_success = DALAI_MODEL.generate(req, on_finish=on_finish)
|
||||
|
||||
if not sent_req_success:
|
||||
print("Something went wrong pinging the Dalai server. Returning None.")
|
||||
return None, None
|
||||
|
||||
# Blocking --wait for request to complete:
|
||||
while DALAI_RESPONSE is None:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
response = DALAI_RESPONSE['response']
|
||||
if response[-5:] == '<end>': # strip ending <end> tag, if present
|
||||
response = response[:-5]
|
||||
if response.find('\r\n') > -1: # strip off the prompt, which is included in the result up to \r\n:
|
||||
response = response[(response.index('\r\n')+2):]
|
||||
DALAI_RESPONSE = None
|
||||
|
||||
responses.append(response)
|
||||
print(f'Response {len(responses)} of {n}:\n', response)
|
||||
|
||||
# Disconnect from the server
|
||||
DALAI_MODEL.disconnect()
|
||||
|
||||
return query, responses
|
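A minimal usage sketch for this helper, assuming a Dalai server is already running at the default address and the `alpaca.7b` model is installed locally (the prompt and settings below are illustrative):

```python
import asyncio
from chainforge.providers.dalai import call_dalai

async def main():
    # `query` echoes back the request dict; `responses` is a list of raw response strings.
    query, responses = await call_dalai(
        prompt="What is the capital of France?",
        model="alpaca.7b",
        server="http://localhost:4000",
        n=2,
        temperature=0.5,
    )
    print(responses)

asyncio.run(main())
```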
127
chainforge/providers/protocol.py
Normal file
@ -0,0 +1,127 @@
|
||||
from typing import Protocol, Optional, TypedDict, Dict, List, Literal, Union, Any
|
||||
import time
|
||||
|
||||
"""
|
||||
OpenAI chat message format typing
|
||||
"""
|
||||
class ChatMessage(TypedDict):
|
||||
""" A single message, in OpenAI chat message format. """
|
||||
role: str
|
||||
content: str
|
||||
name: Optional[str]
|
||||
function_call: Optional[Dict]
|
||||
|
||||
ChatHistory = List[ChatMessage]
|
||||
|
||||
|
||||
class CustomProviderProtocol(Protocol):
|
||||
"""
|
||||
A Callable protocol to implement for custom model provider completions.
|
||||
See `__call__` for more details.
|
||||
"""
|
||||
def __call__(self,
|
||||
prompt: str,
|
||||
model: Optional[str],
|
||||
chat_history: Optional[ChatHistory],
|
||||
**kwargs: Any) -> str:
|
||||
"""
|
||||
Define a call to your custom provider.
|
||||
|
||||
Parameters:
|
||||
- `prompt`: Text to prompt the model. (If it's a chat model, this is the new message to send.)
|
||||
- `model`: The name of the particular model to use, from the CF settings window. Useful when you have multiple models for a single provider. Optional.
|
||||
- `chat_history`: Providers may be passed a past chat context as a list of chat messages in OpenAI format (see `chainforge.providers.ChatHistory`).
|
||||
Chat history does not include the new message to send off (which is passed instead as the `prompt` parameter).
|
||||
- `kwargs`: Any other parameters to pass to the provider API call, like temperature. Parameter names are the keynames in your provider's settings_schema.
|
||||
Only relevant if you are defining a custom settings_schema JSON to edit provider/model settings in ChainForge.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
"""
|
||||
A registry for custom providers
|
||||
"""
|
||||
class _ProviderRegistry:
|
||||
def __init__(self):
|
||||
self._registry = {}
|
||||
self._curr_script_id = '0'
|
||||
self._last_updated = {}
|
||||
|
||||
def set_curr_script_id(self, id: str):
|
||||
self._curr_script_id = id
|
||||
|
||||
def register(self, cls: CustomProviderProtocol, name: str, **kwargs):
|
||||
if name is None or isinstance(name, str) is False or len(name) == 0:
|
||||
raise Exception("Cannot register custom model provider: No name given. Name must be a string and unique.")
|
||||
self._last_updated[name] = self._registry[name]["script_id"] if name in self._registry else None
|
||||
self._registry[name] = { "name": name, "func": cls, "script_id": self._curr_script_id, **kwargs }
|
||||
|
||||
def get(self, name):
|
||||
return self._registry.get(name)
|
||||
|
||||
def get_all(self):
|
||||
return list(self._registry.values())
|
||||
|
||||
def has(self, name):
|
||||
return name in self._registry
|
||||
|
||||
def remove(self, name):
|
||||
if self.has(name):
|
||||
del self._registry[name]
|
||||
|
||||
def watch_next_registered(self):
|
||||
self._last_updated = {}
|
||||
|
||||
def last_registered(self):
|
||||
return {k: v for k, v in self._last_updated.items()}
|
||||
|
||||
|
||||
# Global instance of the registry.
|
||||
ProviderRegistry = _ProviderRegistry()
|
||||
|
||||
def provider(name: str = 'Custom Provider',
|
||||
emoji: str = '✨',
|
||||
models: Optional[List[str]] = None,
|
||||
rate_limit: Union[int, Literal["sequential"]] = "sequential",
|
||||
settings_schema: Optional[Dict] = None):
|
||||
"""
|
||||
A decorator for registering custom LLM provider methods or classes (Callables)
|
||||
that conform to `CustomProviderProtocol`.
|
||||
|
||||
Parameters:
|
||||
- `name`: The name of your custom provider. Required. (Must be unique; cannot be blank.)
|
||||
- `emoji`: The emoji to use as the default icon for your provider in the CF interface. Required.
|
||||
- `models`: A list of models that your provider supports, that you want to be able to choose between in Settings window.
|
||||
If you're just calling a single model, you can omit this.
|
||||
- `rate_limit`: If an integer, the maximum number of requests to send per minute.
|
||||
To force requests to be sequential (wait until each request returns before sending another), enter "sequential". Default is sequential.
|
||||
- `settings_schema`: a JSON Schema specifying the name of your provider in the ChainForge UI, the available settings, and the UI for those settings.
|
||||
The settings and UI specs are in react-jsonschema-form format: https://rjsf-team.github.io/react-jsonschema-form/.
|
||||
|
||||
Specifically, your `settings_schema` dict should have keys:
|
||||
|
||||
```
|
||||
{
|
||||
"settings": <JSON dict of the schema properties for your settings form, in react-jsonschema-form format (https://rjsf-team.github.io/react-jsonschema-form/docs/)>,
|
||||
"ui": <JSON dict of the UI Schema for your settings form, in react-jsonschema-form (see UISchema example here: https://rjsf-team.github.io/react-jsonschema-form/)
|
||||
}
|
||||
```
|
||||
|
||||
You may look to adapt an existing schema from `ModelSettingsSchemas.js` in `chainforge/react-server/src/`,
|
||||
BUT with the following things to keep in mind:
|
||||
- the value of "settings" should just be the value of "properties" in the full schema
|
||||
- don't include the 'shortname' property; this will be added by default and set to the value of `name`
|
||||
- don't include the 'model' property; this will be populated by the list you passed to `models` (if any)
|
||||
- the keynames of all properties of the schema should be valid as variable names for Python keyword args; i.e., no spaces
|
||||
|
||||
Finally, if you want temperature to appear in the ChainForge UI, you must name your
|
||||
settings schema property `temperature`, and give it `minimum` and `maximum` values.
|
||||
|
||||
NOTE: Only `textarea`, `range`, enum, and text input UI widgets are properly supported from `react-jsonschema-form`;
|
||||
you can try other widget types, but the CSS may not display properly.
|
||||
|
||||
"""
|
||||
def dec(cls: CustomProviderProtocol):
|
||||
ProviderRegistry.register(cls, name=name, emoji=emoji, models=models, rate_limit=rate_limit, settings_schema=settings_schema)
|
||||
return cls
|
||||
return dec
|
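To see how the decorator and registry fit together, here is a minimal, hypothetical provider that simply echoes the prompt back; no real API is called, and the name and emoji are made up for illustration:

```python
from chainforge.providers import provider, ProviderRegistry

@provider(name="Echo", emoji="🔁", rate_limit="sequential")
def EchoCompletion(prompt: str, model=None, chat_history=None, **kwargs) -> str:
    # A stand-in 'model': just return the prompt unchanged.
    return prompt

# The decorator registers the callable in the global singleton registry:
assert ProviderRegistry.has("Echo")
print(ProviderRegistry.get("Echo")["name"])  # -> "Echo"
```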
@ -1,4 +1,4 @@
|
||||
# Getting Started with Create React App
|
||||
# ChainForge React Server
|
||||
|
||||
This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app).
|
||||
|
||||
|
@ -1,15 +1,15 @@
|
||||
{
|
||||
"files": {
|
||||
"main.css": "/static/css/main.37690c8a.css",
|
||||
"main.js": "/static/js/main.c2ab48ed.js",
|
||||
"main.js": "/static/js/main.1642966f.js",
|
||||
"static/js/787.4c72bb55.chunk.js": "/static/js/787.4c72bb55.chunk.js",
|
||||
"index.html": "/index.html",
|
||||
"main.37690c8a.css.map": "/static/css/main.37690c8a.css.map",
|
||||
"main.c2ab48ed.js.map": "/static/js/main.c2ab48ed.js.map",
|
||||
"main.1642966f.js.map": "/static/js/main.1642966f.js.map",
|
||||
"787.4c72bb55.chunk.js.map": "/static/js/787.4c72bb55.chunk.js.map"
|
||||
},
|
||||
"entrypoints": [
|
||||
"static/css/main.37690c8a.css",
|
||||
"static/js/main.c2ab48ed.js"
|
||||
"static/js/main.1642966f.js"
|
||||
]
|
||||
}
|
@ -1 +1 @@
|
||||
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.c2ab48ed.js"></script><link href="/static/css/main.37690c8a.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
|
||||
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.1642966f.js"></script><link href="/static/css/main.37690c8a.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
59
chainforge/react-server/package-lock.json
generated
@ -17,6 +17,7 @@
|
||||
"@google-ai/generativelanguage": "^0.2.0",
|
||||
"@mantine/core": "^6.0.9",
|
||||
"@mantine/dates": "^6.0.13",
|
||||
"@mantine/dropzone": "^6.0.19",
|
||||
"@mantine/form": "^6.0.11",
|
||||
"@mantine/prism": "^6.0.15",
|
||||
"@reactflow/background": "^11.2.0",
|
||||
@ -4013,6 +4014,29 @@
|
||||
"react": ">=16.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@mantine/dropzone": {
|
||||
"version": "6.0.19",
|
||||
"resolved": "https://registry.npmjs.org/@mantine/dropzone/-/dropzone-6.0.19.tgz",
|
||||
"integrity": "sha512-riftzhsXSe84oUHIzMsANtJSIINT08N8ycUuGFxGStf+ytVUZn7TommITnifEEokY6iu4yxjv27FK8maaB1BHA==",
|
||||
"dependencies": {
|
||||
"@mantine/utils": "6.0.19",
|
||||
"react-dropzone": "14.2.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@mantine/core": "6.0.19",
|
||||
"@mantine/hooks": "6.0.19",
|
||||
"react": ">=16.8.0",
|
||||
"react-dom": ">=16.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@mantine/dropzone/node_modules/@mantine/utils": {
|
||||
"version": "6.0.19",
|
||||
"resolved": "https://registry.npmjs.org/@mantine/utils/-/utils-6.0.19.tgz",
|
||||
"integrity": "sha512-duvtnaW1gDR2gnvUqnWhl6DMW7sN0HEWqS8Z/BbwaMi75U+Xp17Q72R9JtiIrxQbzsq+KvH9L9B/pxMVwbLirg==",
|
||||
"peerDependencies": {
|
||||
"react": ">=16.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@mantine/form": {
|
||||
"version": "6.0.11",
|
||||
"resolved": "https://registry.npmjs.org/@mantine/form/-/form-6.0.11.tgz",
|
||||
@ -7080,6 +7104,14 @@
|
||||
"resolved": "https://registry.npmjs.org/atob-lite/-/atob-lite-2.0.0.tgz",
|
||||
"integrity": "sha512-LEeSAWeh2Gfa2FtlQE1shxQ8zi5F9GHarrGKz08TMdODD5T4eH6BMsvtnhbWZ+XQn+Gb6om/917ucvRu7l7ukw=="
|
||||
},
|
||||
"node_modules/attr-accept": {
|
||||
"version": "2.2.2",
|
||||
"resolved": "https://registry.npmjs.org/attr-accept/-/attr-accept-2.2.2.tgz",
|
||||
"integrity": "sha512-7prDjvt9HmqiZ0cl5CRjtS84sEyhsHP2coDkaZKRKVfCDo9s7iw7ChVmar78Gu9pC4SoR/28wFu/G5JJhTnqEg==",
|
||||
"engines": {
|
||||
"node": ">=4"
|
||||
}
|
||||
},
|
||||
"node_modules/autoprefixer": {
|
||||
"version": "10.4.14",
|
||||
"resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.14.tgz",
|
||||
@ -11515,6 +11547,17 @@
|
||||
"webpack": "^4.0.0 || ^5.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/file-selector": {
|
||||
"version": "0.6.0",
|
||||
"resolved": "https://registry.npmjs.org/file-selector/-/file-selector-0.6.0.tgz",
|
||||
"integrity": "sha512-QlZ5yJC0VxHxQQsQhXvBaC7VRJ2uaxTf+Tfpu4Z/OcVQJVpZO+DGU0rkoVW5ce2SccxugvpBJoMvUs59iILYdw==",
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 12"
|
||||
}
|
||||
},
|
||||
"node_modules/filelist": {
|
||||
"version": "1.0.4",
|
||||
"resolved": "https://registry.npmjs.org/filelist/-/filelist-1.0.4.tgz",
|
||||
@ -20265,6 +20308,22 @@
|
||||
"react": "^18.2.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-dropzone": {
|
||||
"version": "14.2.3",
|
||||
"resolved": "https://registry.npmjs.org/react-dropzone/-/react-dropzone-14.2.3.tgz",
|
||||
"integrity": "sha512-O3om8I+PkFKbxCukfIR3QAGftYXDZfOE2N1mr/7qebQJHs7U+/RSL/9xomJNpRg9kM5h9soQSdf0Gc7OHF5Fug==",
|
||||
"dependencies": {
|
||||
"attr-accept": "^2.2.2",
|
||||
"file-selector": "^0.6.0",
|
||||
"prop-types": "^15.8.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 10.13"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">= 16.8 || 18.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-edit-text": {
|
||||
"version": "5.1.0",
|
||||
"resolved": "https://registry.npmjs.org/react-edit-text/-/react-edit-text-5.1.0.tgz",
|
||||
|
@ -12,6 +12,7 @@
|
||||
"@google-ai/generativelanguage": "^0.2.0",
|
||||
"@mantine/core": "^6.0.9",
|
||||
"@mantine/dates": "^6.0.13",
|
||||
"@mantine/dropzone": "^6.0.19",
|
||||
"@mantine/form": "^6.0.11",
|
||||
"@mantine/prism": "^6.0.15",
|
||||
"@reactflow/background": "^11.2.0",
|
||||
|
7
chainforge/react-server/src/App.js
vendored
@ -23,7 +23,7 @@ import GlobalSettingsModal from './GlobalSettingsModal';
|
||||
import ExampleFlowsModal from './ExampleFlowsModal';
|
||||
import AreYouSureModal from './AreYouSureModal';
|
||||
import LLMEvaluatorNode from './LLMEvalNode';
|
||||
import { getDefaultModelFormData, getDefaultModelSettings } from './ModelSettingSchemas';
|
||||
import { getDefaultModelFormData, getDefaultModelSettings, setCustomProviders } from './ModelSettingSchemas';
|
||||
import { v4 as uuid } from 'uuid';
|
||||
import LZString from 'lz-string';
|
||||
import { EXAMPLEFLOW_1 } from './example_flows';
|
||||
@ -122,7 +122,8 @@ const snapGrid = [16, 16];
|
||||
const App = () => {
|
||||
|
||||
// Get nodes, edges, etc. state from the Zustand store:
|
||||
const { nodes, edges, onNodesChange, onEdgesChange, onConnect, addNode, setNodes, setEdges, resetLLMColors } = useStore(selector, shallow);
|
||||
const { nodes, edges, onNodesChange, onEdgesChange,
|
||||
onConnect, addNode, setNodes, setEdges, resetLLMColors } = useStore(selector, shallow);
|
||||
|
||||
// For saving / loading
|
||||
const [rfInstance, setRfInstance] = useState(null);
|
||||
@ -672,7 +673,7 @@ const App = () => {
|
||||
}
|
||||
else return (
|
||||
<div>
|
||||
<GlobalSettingsModal ref={settingsModal} />
|
||||
<GlobalSettingsModal ref={settingsModal} alertModal={alertModal} />
|
||||
<AlertModal ref={alertModal} />
|
||||
<LoadingOverlay visible={isLoading} overlayBlur={1} />
|
||||
<ExampleFlowsModal ref={examplesModal} onSelect={onSelectExampleFlow} />
|
||||
|
292
chainforge/react-server/src/GlobalSettingsModal.js
vendored
@ -1,14 +1,160 @@
|
||||
import React, { useState, forwardRef, useImperativeHandle } from 'react';
|
||||
import { TextInput, Button, Group, Box, Modal, Divider, Text } from '@mantine/core';
|
||||
import React, { useState, forwardRef, useImperativeHandle, useCallback, useEffect } from 'react';
|
||||
import { TextInput, Button, Group, Box, Modal, Divider, Text, Tabs, useMantineTheme, rem, Flex, Center, Badge, Card } from '@mantine/core';
|
||||
import { useDisclosure } from '@mantine/hooks';
|
||||
import { useForm } from '@mantine/form';
|
||||
import useStore from './store';
|
||||
import { IconUpload, IconBrandPython, IconX } from '@tabler/icons-react';
|
||||
import { Dropzone, DropzoneProps } from '@mantine/dropzone';
|
||||
import useStore, { initLLMProviders } from './store';
|
||||
import { APP_IS_RUNNING_LOCALLY } from './backend/utils';
|
||||
import fetch_from_backend from './fetch_from_backend';
|
||||
import { setCustomProviders } from './ModelSettingSchemas';
|
||||
|
||||
const _LINK_STYLE = {color: '#1E90FF', textDecoration: 'none'};
|
||||
|
||||
// To ensure we only call the backend to load custom providers once, upon initialization
|
||||
let LOADED_CUSTOM_PROVIDERS = false;
|
||||
|
||||
// Read a file as text and pass the text to a cb (callback) function
|
||||
const read_file = (file, cb) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = function(event) {
|
||||
const fileContent = event.target.result;
|
||||
cb(fileContent);
|
||||
};
|
||||
reader.onerror = function(event) {
|
||||
console.error("Error reading file:", event);
|
||||
};
|
||||
reader.readAsText(file);
|
||||
};
|
||||
|
||||
/** A Dropzone to load a Python `.py` script that registers a `CustomModelProvider` in the Flask backend.
|
||||
* If successful, the list of custom model providers in the ChainForge UI dropdown is updated.
|
||||
* */
|
||||
const CustomProviderScriptDropzone = ({onError, onSetProviders}) => {
|
||||
const theme = useMantineTheme();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
return (<Dropzone
|
||||
loading={isLoading}
|
||||
onDrop={(files) => {
|
||||
if (files.length === 1) {
|
||||
setIsLoading(true);
|
||||
read_file(files[0], (content) => {
|
||||
// Read the file into text and then send it to backend
|
||||
fetch_from_backend('initCustomProvider', {
|
||||
code: content
|
||||
}).then((response) => {
|
||||
setIsLoading(false);
|
||||
|
||||
if (response.error || !response.providers) {
|
||||
onError(response.error);
|
||||
return;
|
||||
}
|
||||
// Successfully loaded custom providers in backend,
|
||||
// now load them into the ChainForge UI:
|
||||
console.log(response.providers);
|
||||
setCustomProviders(response.providers);
|
||||
onSetProviders(response.providers);
|
||||
}).catch((err) => {
|
||||
setIsLoading(false);
|
||||
onError(err.message);
|
||||
});
|
||||
});
|
||||
} else {
|
||||
console.error('Too many files dropped. Only drop one file at a time.')
|
||||
}
|
||||
}}
|
||||
onReject={(files) => console.log('rejected files', files)}
|
||||
maxSize={3 * 1024 ** 2}
|
||||
accept={{'text/x-python-script': []}}
|
||||
>
|
||||
<Flex pos="center" spacing="md" style={{ minHeight: rem(80), pointerEvents: 'none' }}>
|
||||
<Center>
|
||||
<Dropzone.Accept>
|
||||
<IconUpload
|
||||
size="4.2rem"
|
||||
stroke={1.5}
|
||||
color={theme.colors[theme.primaryColor][theme.colorScheme === 'dark' ? 4 : 6]}
|
||||
/>
|
||||
</Dropzone.Accept>
|
||||
<Dropzone.Reject>
|
||||
<IconX size="4.2rem"
|
||||
stroke={1.5}
|
||||
color={theme.colors.red[theme.colorScheme === 'dark' ? 4 : 6]} />
|
||||
</Dropzone.Reject>
|
||||
<Dropzone.Idle>
|
||||
<IconBrandPython size="4.2rem" stroke={1.5} />
|
||||
</Dropzone.Idle>
|
||||
|
||||
<Box ml='md'>
|
||||
<Text size="md" lh={1.2} inline>
|
||||
Drag a Python script for your custom model provider here
|
||||
</Text>
|
||||
<Text size="sm" color="dimmed" inline mt={7}>
|
||||
Each script should contain one or more registered @provider callables
|
||||
</Text>
|
||||
</Box>
|
||||
</Center>
|
||||
</Flex>
|
||||
|
||||
</Dropzone>);
|
||||
};
|
||||
|
||||
const GlobalSettingsModal = forwardRef((props, ref) => {
|
||||
const [opened, { open, close }] = useDisclosure(false);
|
||||
const setAPIKeys = useStore((state) => state.setAPIKeys);
|
||||
const AvailableLLMs = useStore((state) => state.AvailableLLMs);
|
||||
const setAvailableLLMs = useStore((state) => state.setAvailableLLMs);
|
||||
const nodes = useStore((state) => state.nodes);
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
const alertModal = props?.alertModal;
|
||||
|
||||
const handleError = useCallback((msg) => {
|
||||
if (alertModal && alertModal.current)
|
||||
alertModal.current.trigger(msg);
|
||||
}, [alertModal]);
|
||||
|
||||
const [customProviders, setLocalCustomProviders] = useState([]);
|
||||
const refreshLLMProviderLists = useCallback(() => {
|
||||
// We unfortunately have to force all prompt/chat nodes to refresh their LLM lists, because
|
||||
// apparently the update to the AvailableLLMs list is not immediately propagated to them.
|
||||
const prompt_nodes = nodes.filter(n => n.type === 'prompt' || n.type === 'chat');
|
||||
prompt_nodes.forEach(n => setDataPropsForNode(n.id, { refreshLLMList: true }));
|
||||
}, [nodes, setDataPropsForNode]);
|
||||
|
||||
const removeCustomProvider = useCallback((name) => {
|
||||
fetch_from_backend('removeCustomProvider', {
|
||||
name: name,
|
||||
}).then((response) => {
|
||||
if (response.error || !response.success) {
|
||||
handleError(response.error);
|
||||
return;
|
||||
}
|
||||
// Successfully deleted the custom provider from backend;
|
||||
// now update the front-end UI to reflect this:
|
||||
setAvailableLLMs(AvailableLLMs.filter((p) => p.name !== name));
|
||||
setLocalCustomProviders(customProviders.filter((p) => p.name !== name));
|
||||
refreshLLMProviderLists();
|
||||
}).catch((err) => handleError(err.message));
|
||||
}, [customProviders, handleError, AvailableLLMs, refreshLLMProviderLists]);
|
||||
|
||||
// On init
|
||||
useEffect(() => {
|
||||
if (APP_IS_RUNNING_LOCALLY() && !LOADED_CUSTOM_PROVIDERS) {
|
||||
LOADED_CUSTOM_PROVIDERS = true;
|
||||
// Is running locally; try to load any custom providers.
|
||||
// Soft fails if it encounters error:
|
||||
fetch_from_backend('loadCachedCustomProviders', {}, console.error).then((json) => {
|
||||
if (json?.error || json?.providers === undefined) {
|
||||
console.error(json?.error || "Could not load custom provider scripts: Error contacting backend.");
|
||||
return;
|
||||
}
|
||||
// Success; pass custom providers list to store:
|
||||
setCustomProviders(json.providers);
|
||||
setLocalCustomProviders(json.providers);
|
||||
});
|
||||
}
|
||||
}, []);
|
||||
|
||||
const form = useForm({
|
||||
initialValues: {
|
||||
@ -25,6 +171,7 @@ const GlobalSettingsModal = forwardRef((props, ref) => {
|
||||
},
|
||||
});
|
||||
|
||||
// When the API settings form is submitted
|
||||
const onSubmit = (values) => {
|
||||
setAPIKeys(values);
|
||||
close();
|
||||
@ -39,63 +186,102 @@ const GlobalSettingsModal = forwardRef((props, ref) => {
|
||||
}));
|
||||
|
||||
return (
|
||||
<Modal opened={opened} onClose={close} title="ChainForge Settings" closeOnClickOutside={false} style={{position: 'relative', 'left': '-100px'}}>
|
||||
<Modal keepMounted opened={opened} onClose={close} title="ChainForge Settings" closeOnClickOutside={false} style={{position: 'relative', 'left': '-5%'}}>
|
||||
<Box maw={380} mx="auto">
|
||||
<Text mb="md" fz="xs" lh={1.15} color='dimmed'>
|
||||
Note: <b>We do not store your API keys</b> —not in a cookie, localStorage, or server.
|
||||
Because of this, <b>you must set your API keys every time you load ChainForge.</b> If you prefer not to worry about it,
|
||||
we recommend <a href="https://github.com/ianarawjo/ChainForge" target="_blank" style={_LINK_STYLE}>installing ChainForge locally</a> and
|
||||
<a href="https://github.com/ianarawjo/ChainForge/blob/main/INSTALL_GUIDE.md#2-set-api-keys-openai-anthropic-google-palm" target="_blank" style={_LINK_STYLE}> setting your API keys as environment variables.</a>
|
||||
</Text>
|
||||
<form onSubmit={form.onSubmit(onSubmit)}>
|
||||
<TextInput
|
||||
label="OpenAI API Key"
|
||||
placeholder="Paste your OpenAI API key here"
|
||||
{...form.getInputProps('OpenAI')}
|
||||
/>
|
||||
<Tabs defaultValue="api-keys">
|
||||
|
||||
<br />
|
||||
<TextInput
|
||||
label="HuggingFace API Key"
|
||||
placeholder="Paste your HuggingFace API key here"
|
||||
{...form.getInputProps('HuggingFace')}
|
||||
/>
|
||||
<Tabs.List>
|
||||
<Tabs.Tab value="api-keys" >API Keys</Tabs.Tab>
|
||||
<Tabs.Tab value="custom-providers" >Custom Model Providers</Tabs.Tab>
|
||||
</Tabs.List>
|
||||
|
||||
<br />
|
||||
<TextInput
|
||||
label="Anthropic API Key"
|
||||
placeholder="Paste your Anthropic API key here"
|
||||
{...form.getInputProps('Anthropic')}
|
||||
/>
|
||||
|
||||
<br />
|
||||
<TextInput
|
||||
label="Google PaLM API Key"
|
||||
placeholder="Paste your Google PaLM API key here"
|
||||
{...form.getInputProps('Google')}
|
||||
/>
|
||||
<br />
|
||||
<Tabs.Panel value="api-keys" pt="xs">
|
||||
<Text mb="md" fz="xs" lh={1.15} color='dimmed'>
|
||||
Note: <b>We do not store your API keys</b> —not in a cookie, localStorage, or server.
|
||||
Because of this, <b>you must set your API keys every time you load ChainForge.</b> If you prefer not to worry about it,
|
||||
we recommend <a href="https://github.com/ianarawjo/ChainForge" target="_blank" style={_LINK_STYLE}>installing ChainForge locally</a> and
|
||||
<a href="https://github.com/ianarawjo/ChainForge/blob/main/INSTALL_GUIDE.md#2-set-api-keys-openai-anthropic-google-palm" target="_blank" style={_LINK_STYLE}> setting your API keys as environment variables.</a>
|
||||
</Text>
|
||||
<form onSubmit={form.onSubmit(onSubmit)}>
|
||||
<TextInput
|
||||
label="OpenAI API Key"
|
||||
placeholder="Paste your OpenAI API key here"
|
||||
{...form.getInputProps('OpenAI')}
|
||||
/>
|
||||
|
||||
<Divider my="xs" label="Microsoft Azure" labelPosition="center" />
|
||||
<TextInput
|
||||
label="Azure OpenAI Key"
|
||||
description={<span>For more details on Azure OpenAI, see <a href="https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal" target="_blank" style={{color: '#1E90FF', textDecoration:'none'}}>Microsoft Learn.</a> Note that you will have to set the Deployment Name in the Settings of any Azure OpenAI model you add to a Prompt Node.</span>}
|
||||
|
||||
placeholder="Paste your Azure OpenAI Key here"
|
||||
{...form.getInputProps('Azure_OpenAI')}
|
||||
style={{marginBottom: '8pt'}}
|
||||
/>
|
||||
<br />
|
||||
<TextInput
|
||||
label="HuggingFace API Key"
|
||||
placeholder="Paste your HuggingFace API key here"
|
||||
{...form.getInputProps('HuggingFace')}
|
||||
/>
|
||||
|
||||
<TextInput
|
||||
label="Azure OpenAI Endpoint"
|
||||
placeholder="Paste your Azure OpenAI Endpoint here"
|
||||
{...form.getInputProps('Azure_OpenAI_Endpoint')}
|
||||
/>
|
||||
<br />
|
||||
<TextInput
|
||||
label="Anthropic API Key"
|
||||
placeholder="Paste your Anthropic API key here"
|
||||
{...form.getInputProps('Anthropic')}
|
||||
/>
|
||||
|
||||
<Group position="right" mt="md">
|
||||
<Button type="submit">Submit</Button>
|
||||
</Group>
|
||||
</form>
|
||||
<br />
|
||||
<TextInput
|
||||
label="Google PaLM API Key"
|
||||
placeholder="Paste your Google PaLM API key here"
|
||||
{...form.getInputProps('Google')}
|
||||
/>
|
||||
<br />
|
||||
|
||||
<Divider my="xs" label="Microsoft Azure" labelPosition="center" />
|
||||
<TextInput
|
||||
label="Azure OpenAI Key"
|
||||
description={<span>For more details on Azure OpenAI, see <a href="https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal" target="_blank" style={{color: '#1E90FF', textDecoration:'none'}}>Microsoft Learn.</a> Note that you will have to set the Deployment Name in the Settings of any Azure OpenAI model you add to a Prompt Node.</span>}
|
||||
|
||||
placeholder="Paste your Azure OpenAI Key here"
|
||||
{...form.getInputProps('Azure_OpenAI')}
|
||||
style={{marginBottom: '8pt'}}
|
||||
/>
|
||||
|
||||
<TextInput
|
||||
label="Azure OpenAI Endpoint"
|
||||
placeholder="Paste your Azure OpenAI Endpoint here"
|
||||
{...form.getInputProps('Azure_OpenAI_Endpoint')}
|
||||
/>
|
||||
|
||||
<Group position="right" mt="md">
|
||||
<Button type="submit">Submit</Button>
|
||||
</Group>
|
||||
</form>
|
||||
</Tabs.Panel>
|
||||
|
||||
|
||||
{APP_IS_RUNNING_LOCALLY() ?
|
||||
<Tabs.Panel value="custom-providers" pt="md">
|
||||
<Text mb="md" fz="sm" lh={1.3}>
|
||||
You can add model providers to ChainForge by writing custom completion functions as Python scripts. (You can even make your own settings screen!)
|
||||
To learn more, <a href="https://chainforge.ai/docs/custom_providers/" target="_blank" style={_LINK_STYLE}>see the documentation.</a>
|
||||
</Text>
|
||||
{ customProviders.map(p => (
|
||||
<Card key={p.name} shadow='sm' radius='sm' pt='0px' pb='4px' mb='md' withBorder>
|
||||
<Group position="apart">
|
||||
<Group position="left" mt="md" mb="xs">
|
||||
<Text w='10px'>{p.emoji}</Text>
|
||||
<Text weight={500}>{p.name}</Text>
|
||||
{ p.settings_schema ?
|
||||
<Badge color="blue" variant="light">has settings</Badge>
|
||||
: <></> }
|
||||
</Group>
|
||||
<Button onClick={() => removeCustomProvider(p.name)} color='red' p='0px' mt='4px' variant='subtle'><IconX /></Button>
|
||||
</Group>
|
||||
</Card>
|
||||
)) }
|
||||
<CustomProviderScriptDropzone onError={handleError} onSetProviders={(ps) => {
|
||||
refreshLLMProviderLists();
|
||||
setLocalCustomProviders(ps);
|
||||
}} />
|
||||
</Tabs.Panel>
|
||||
: <></>}
|
||||
</Tabs>
|
||||
</Box>
|
||||
</Modal>
|
||||
);
|
||||
|
9
chainforge/react-server/src/LLMEvalNode.js
vendored
@ -1,23 +1,24 @@
|
||||
import React, { useState, useCallback, useRef, useEffect } from 'react';
|
||||
import { Handle } from 'react-flow-renderer';
|
||||
import { Button, Alert, Progress, Textarea } from '@mantine/core';
|
||||
import { Alert, Progress, Textarea } from '@mantine/core';
|
||||
import { IconAlertTriangle, IconRobot, IconSearch } from "@tabler/icons-react";
|
||||
import { v4 as uuid } from 'uuid';
|
||||
import useStore from './store';
|
||||
import NodeLabel from './NodeLabelComponent';
|
||||
import fetch_from_backend from './fetch_from_backend';
|
||||
import { AvailableLLMs, getDefaultModelSettings } from './ModelSettingSchemas';
|
||||
import { getDefaultModelSettings } from './ModelSettingSchemas';
|
||||
import { LLMListContainer } from './LLMListComponent';
|
||||
import LLMResponseInspectorModal from './LLMResponseInspectorModal';
|
||||
import InspectFooter from './InspectFooter';
|
||||
import { initLLMProviders } from './store';
|
||||
|
||||
// The default prompt shown in gray highlights to give people a good example of an evaluation prompt.
|
||||
const PLACEHOLDER_PROMPT = "Respond with 'true' if the text below has a positive sentiment, and 'false' if not. Do not reply with anything else.";
|
||||
|
||||
// The default LLM annotator is GPT-4 at temperature 0.
|
||||
const DEFAULT_LLM_ITEM = (() => {
|
||||
let item = [AvailableLLMs.find(i => i.base_model === 'gpt-4')]
|
||||
.map((i) => ({key: uuid(), settings: getDefaultModelSettings(i.base_model), ...i}))[0];
|
||||
let item = [initLLMProviders.find(i => i.base_model === 'gpt-4')]
|
||||
.map((i) => ({key: uuid(), settings: getDefaultModelSettings(i.base_model), ...i}))[0];
|
||||
item.settings.temperature = 0.0;
|
||||
return item;
|
||||
})();
|
||||
|
30
chainforge/react-server/src/LLMListComponent.js
vendored
@ -1,15 +1,16 @@
|
||||
import { useState, useEffect, useCallback, useRef, forwardRef, useImperativeHandle } from "react";
|
||||
import { useState, useEffect, useCallback, useRef, forwardRef, useImperativeHandle, useReducer } from "react";
|
||||
import { DragDropContext, Draggable } from "react-beautiful-dnd";
|
||||
import { Menu } from "@mantine/core";
|
||||
import { v4 as uuid } from 'uuid';
|
||||
import LLMListItem, { LLMListItemClone } from "./LLMListItem";
|
||||
import { StrictModeDroppable } from './StrictModeDroppable';
|
||||
import ModelSettingsModal from "./ModelSettingsModal";
|
||||
import { getDefaultModelSettings, AvailableLLMs } from './ModelSettingSchemas';
|
||||
import { getDefaultModelSettings } from './ModelSettingSchemas';
|
||||
import useStore, { initLLMProviders } from "./store";
|
||||
|
||||
// The LLM(s) to include by default on a PromptNode whenever one is created.
|
||||
// Defaults to ChatGPT (GPT3.5) when running locally, and HF-hosted falcon-7b for online version since it's free.
|
||||
const DEFAULT_INIT_LLMS = [AvailableLLMs[0]];
|
||||
const DEFAULT_INIT_LLMS = [initLLMProviders[0]];
|
||||
|
||||
// Helper funcs
|
||||
// Ensure that a name is 'unique'; if not, return an amended version with a count tacked on (e.g. "GPT-4 (2)")
|
||||
@ -68,8 +69,13 @@ export function LLMList({llms, onItemsChange}) {
|
||||
updated_item.formData = {...formData};
|
||||
updated_item.settings = {...settingsData};
|
||||
|
||||
if ('model' in formData) // Update the name of the specific model to call
|
||||
updated_item.model = formData['model'];
|
||||
if ('model' in formData) { // Update the name of the specific model to call
|
||||
if (item.base_model.startsWith('__custom'))
|
||||
// Custom models must always have their base name, to avoid name collisions
|
||||
updated_item.model = item.base_model + '/' + formData['model'];
|
||||
else
|
||||
updated_item.model = formData['model'];
|
||||
}
|
||||
if ('shortname' in formData) {
|
||||
// Change the name, amending any name that isn't unique to ensure it is unique:
|
||||
const unique_name = ensureUniqueName(formData['shortname'], prev_names);
|
||||
@ -158,6 +164,17 @@ export function LLMList({llms, onItemsChange}) {
|
||||
|
||||
export const LLMListContainer = forwardRef(({description, modelSelectButtonText, initLLMItems, onSelectModel, selectModelAction, onItemsChange}, ref) => {
|
||||
|
||||
// All available LLM providers, for the dropdown list
|
||||
const AvailableLLMs = useStore((state) => state.AvailableLLMs);
|
||||
|
||||
// For some reason, when the AvailableLLMs list is updated in the store, it is not
|
||||
// immediately updated here. I've tried all kinds of things, but cannot seem to fix this problem.
|
||||
// We must force a re-render of the component:
|
||||
const [ignored, forceUpdate] = useReducer(x => x + 1, 0);
|
||||
const refreshLLMProviderList = () => {
|
||||
forceUpdate();
|
||||
};
|
||||
|
||||
// Selecting LLM models to prompt
|
||||
const [llmItems, setLLMItems] = useState(initLLMItems || DEFAULT_INIT_LLMS.map((i) => ({key: uuid(), settings: getDefaultModelSettings(i.base_model), ...i})));
|
||||
const [llmItemsCurrState, setLLMItemsCurrState] = useState([]);
|
||||
@ -228,7 +245,7 @@ export const LLMListContainer = forwardRef(({description, modelSelectButtonText,
|
||||
|
||||
setLLMItems(new_items);
|
||||
if (onSelectModel) onSelectModel(item, new_items);
|
||||
}, [llmItemsCurrState, onSelectModel, selectModelAction]);
|
||||
}, [llmItemsCurrState, onSelectModel, selectModelAction, AvailableLLMs]);
|
||||
|
||||
const onLLMListItemsChange = useCallback((new_items) => {
|
||||
setLLMItemsCurrState(new_items);
|
||||
@ -242,6 +259,7 @@ export const LLMListContainer = forwardRef(({description, modelSelectButtonText,
|
||||
updateProgress,
|
||||
ensureLLMItemsErrorProgress,
|
||||
getLLMListItemForKey,
|
||||
refreshLLMProviderList,
|
||||
}));
|
||||
|
||||
return (<div className="llm-list-container nowheel">
|
||||
|
174
chainforge/react-server/src/ModelSettingSchemas.js
vendored
@ -10,21 +10,18 @@
|
||||
* Descriptions of OpenAI model parameters copied from OpenAI's official chat completions documentation: https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
*/
|
||||
|
||||
import { APP_IS_RUNNING_LOCALLY } from "./backend/utils";
|
||||
import { RATE_LIMITS } from "./backend/models";
|
||||
import { filterDict } from './backend/utils';
|
||||
import useStore from "./store";
|
||||
|
||||
// Available LLMs in ChainForge, in the format expected by LLMListItems.
|
||||
export let AvailableLLMs = [
|
||||
{ name: "GPT3.5", emoji: "🤖", model: "gpt-3.5-turbo", base_model: "gpt-3.5-turbo", temp: 1.0 }, // The base_model designates what settings form will be used, and must be unique.
|
||||
{ name: "GPT4", emoji: "🥵", model: "gpt-4", base_model: "gpt-4", temp: 1.0 },
|
||||
{ name: "Claude", emoji: "📚", model: "claude-2", base_model: "claude-v1", temp: 0.5 },
|
||||
{ name: "PaLM2", emoji: "🦬", model: "chat-bison-001", base_model: "palm2-bison", temp: 0.7 },
|
||||
{ name: "Azure OpenAI", emoji: "🔷", model: "azure-openai", base_model: "azure-openai", temp: 1.0 },
|
||||
{ name: "HuggingFace", emoji: "🤗", model: "tiiuae/falcon-7b-instruct", base_model: "hf", temp: 1.0 },
|
||||
];
|
||||
if (APP_IS_RUNNING_LOCALLY()) {
|
||||
AvailableLLMs.push({ name: "Dalai (Alpaca.7B)", emoji: "🦙", model: "alpaca.7B", base_model: "dalai", temp: 0.5 });
|
||||
}
|
||||
const UI_SUBMIT_BUTTON_SPEC = {
|
||||
props: {
|
||||
disabled: false,
|
||||
className: 'mantine-UnstyledButton-root mantine-Button-root',
|
||||
},
|
||||
norender: false,
|
||||
submitText: 'Submit',
|
||||
};
|
||||
|
||||
const ChatGPTSettings = {
|
||||
fullName: "GPT-3.5+ (OpenAI)",
|
||||
@ -122,14 +119,7 @@ const ChatGPTSettings = {
|
||||
},
|
||||
|
||||
uiSchema: {
|
||||
'ui:submitButtonOptions': {
|
||||
props: {
|
||||
disabled: false,
|
||||
className: 'mantine-UnstyledButton-root mantine-Button-root',
|
||||
},
|
||||
norender: false,
|
||||
submitText: 'Submit',
|
||||
},
|
||||
'ui:submitButtonOptions': UI_SUBMIT_BUTTON_SPEC,
|
||||
"shortname": {
|
||||
"ui:autofocus": true
|
||||
},
|
||||
@ -295,14 +285,7 @@ const ClaudeSettings = {
|
||||
},
|
||||
|
||||
uiSchema: {
|
||||
'ui:submitButtonOptions': {
|
||||
props: {
|
||||
disabled: false,
|
||||
className: 'mantine-UnstyledButton-root mantine-Button-root',
|
||||
},
|
||||
norender: false,
|
||||
submitText: 'Submit',
|
||||
},
|
||||
'ui:submitButtonOptions': UI_SUBMIT_BUTTON_SPEC,
|
||||
"shortname": {
|
||||
"ui:autofocus": true
|
||||
},
|
||||
@ -403,14 +386,7 @@ const PaLM2Settings = {
|
||||
},
|
||||
|
||||
uiSchema: {
|
||||
'ui:submitButtonOptions': {
|
||||
props: {
|
||||
disabled: false,
|
||||
className: 'mantine-UnstyledButton-root mantine-Button-root',
|
||||
},
|
||||
norender: false,
|
||||
submitText: 'Submit',
|
||||
},
|
||||
'ui:submitButtonOptions': UI_SUBMIT_BUTTON_SPEC,
|
||||
"shortname": {
|
||||
"ui:autofocus": true
|
||||
},
|
||||
@ -536,14 +512,7 @@ const DalaiModelSettings = {
|
||||
},
|
||||
|
||||
uiSchema: {
|
||||
'ui:submitButtonOptions': {
|
||||
props: {
|
||||
disabled: false,
|
||||
className: 'mantine-UnstyledButton-root mantine-Button-root',
|
||||
},
|
||||
norender: false,
|
||||
submitText: 'Submit',
|
||||
},
|
||||
'ui:submitButtonOptions': UI_SUBMIT_BUTTON_SPEC,
|
||||
"shortname": {
|
||||
"ui:autofocus": true
|
||||
},
|
||||
@ -730,19 +699,12 @@ const HuggingFaceTextInferenceSettings = {
|
||||
},
|
||||
|
||||
uiSchema: {
|
||||
'ui:submitButtonOptions': {
|
||||
props: {
|
||||
disabled: false,
|
||||
className: 'mantine-UnstyledButton-root mantine-Button-root',
|
||||
},
|
||||
norender: false,
|
||||
submitText: 'Submit',
|
||||
},
|
||||
'ui:submitButtonOptions': UI_SUBMIT_BUTTON_SPEC,
|
||||
"shortname": {
|
||||
"ui:autofocus": true
|
||||
},
|
||||
"model": {
|
||||
"ui:help": "Defaults to Falcon.7B."
|
||||
"ui:help": "Defaults to Falcon.7B."
|
||||
},
|
||||
"temperature": {
|
||||
"ui:help": "Defaults to 1.0.",
|
||||
@ -781,7 +743,7 @@ const HuggingFaceTextInferenceSettings = {
|
||||
};
|
||||
|
||||
// A lookup table indexed by base_model.
|
||||
export const ModelSettings = {
|
||||
export let ModelSettings = {
|
||||
'gpt-3.5-turbo': ChatGPTSettings,
|
||||
'gpt-4': GPT4Settings,
|
||||
'claude-v1': ClaudeSettings,
|
||||
@ -791,6 +753,104 @@ export const ModelSettings = {
|
||||
'hf': HuggingFaceTextInferenceSettings,
|
||||
};
|
||||
|
||||
/**
|
||||
* Add a new model provider to the AvailableLLMs list. Also adds the respective ModelSettings schema and rate limit.
|
||||
* @param {*} name The name of the provider, to use in the dropdown menu and default name. Must be unique.
|
||||
* @param {*} emoji The emoji to use for the provider. Optional.
|
||||
* @param {*} models A list of models the user can select from this provider. Optional.
|
||||
* @param {*} rate_limit The max number of requests per minute (a number), or "sequential" for blocking calls. Optional.
|
||||
* @param {*} settings_schema A react-jsonschema-form schema ({ settings, ui }) describing the provider's settings window. Optional.
|
||||
*/
|
||||
export const setCustomProvider = (name, emoji, models, rate_limit, settings_schema, llmProviderList) => {
|
||||
if (typeof emoji === 'string' && (emoji.length === 0 || emoji.length > 2))
|
||||
throw new Error(`Emoji for a custom provider must be a single character.`)
|
||||
|
||||
let new_provider = { name };
|
||||
new_provider.emoji = emoji || '✨';
|
||||
|
||||
// Each LLM *model* must have a unique name. To avoid name collisions, for custom providers,
|
||||
// the full LLM model name is a path, __custom/<provider_name>/<submodel name>
|
||||
// If there's no submodel, it's just __custom/<provider_name>.
|
||||
const base_model = `__custom/${name}/`;
|
||||
new_provider.base_model = base_model;
|
||||
new_provider.model = base_model + ((Array.isArray(models) && models.length > 0) ? `${models[0]}` : '');
|
||||
|
||||
// Build the settings form schema for this new custom provider
|
||||
let compiled_schema = {
|
||||
fullName: `${name} (custom provider)`,
|
||||
schema: {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"shortname",
|
||||
],
|
||||
"properties": {
|
||||
"shortname": {
|
||||
"type": "string",
|
||||
"title": "Nickname",
|
||||
"description": "Unique identifier to appear in ChainForge. Keep it short.",
|
||||
"default": name,
|
||||
}
|
||||
}
|
||||
},
|
||||
uiSchema: {
|
||||
'ui:submitButtonOptions': UI_SUBMIT_BUTTON_SPEC,
|
||||
"shortname": {
|
||||
"ui:autofocus": true
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
// Add a models selector if there's multiple models
|
||||
if (Array.isArray(models) && models.length > 0) {
|
||||
compiled_schema.schema["properties"]["model"] = {
|
||||
"type": "string",
|
||||
"title": "Model",
|
||||
"description": `Select a ${name} model to query.`,
|
||||
"enum": models,
|
||||
"default": models[0],
|
||||
};
|
||||
compiled_schema.uiSchema["model"] = {
|
||||
"ui:help": `Defaults to ${models[0]}`
|
||||
};
|
||||
}
|
||||
|
||||
// Add the rest of the settings window if there's one
|
||||
if (settings_schema) {
|
||||
compiled_schema.schema["properties"] = {...compiled_schema.schema["properties"], ...settings_schema.settings};
|
||||
compiled_schema.uiSchema = {...compiled_schema.uiSchema, ...settings_schema.ui};
|
||||
}
|
||||
|
||||
// Check for a default temperature
|
||||
const default_temp = compiled_schema?.schema?.properties?.temperature?.default;
|
||||
if (default_temp !== undefined)
|
||||
new_provider.temp = default_temp;
|
||||
|
||||
// Add the built provider and its settings to the global lookups:
|
||||
let AvailableLLMs = useStore.getState().AvailableLLMs;
|
||||
const prev_provider_idx = AvailableLLMs.findIndex((d) => d.name === name);
|
||||
if (prev_provider_idx > -1)
|
||||
AvailableLLMs[prev_provider_idx] = new_provider;
|
||||
else
|
||||
AvailableLLMs.push(new_provider);
|
||||
ModelSettings[base_model] = compiled_schema;
|
||||
|
||||
// Add rate limit info, if specified
|
||||
if (rate_limit !== undefined && typeof rate_limit === 'number' && rate_limit > 0) {
|
||||
if (rate_limit >= 60)
|
||||
RATE_LIMITS[base_model] = [ Math.trunc(rate_limit/60), 1 ]; // for instance, 300 rpm means 5 every second
|
||||
else
|
||||
RATE_LIMITS[base_model] = [ 1, Math.trunc(60/rate_limit) ]; // for instance, 10 rpm means 1 every 6 seconds
|
||||
}
|
||||
|
||||
// Commit changes to LLM list
|
||||
useStore.getState().setAvailableLLMs(AvailableLLMs);
|
||||
};
|
||||
|
||||
export const setCustomProviders = (providers) => {
|
||||
for (const p of providers)
|
||||
setCustomProvider(p.name, p.emoji, p.models, p.rate_limit, p.settings_schema);
|
||||
};
|
||||
|
||||
export const getTemperatureSpecForModel = (modelName) => {
|
||||
if (modelName in ModelSettings) {
|
||||
const temperature_property = ModelSettings[modelName].schema?.properties?.temperature;
|
||||
@@ -806,7 +866,7 @@ export const postProcessFormData = (settingsSpec, formData) => {
  const skip_keys = {'model': true, 'shortname': true};

  let new_data = {};
-  let postprocessors = settingsSpec.postprocessors ? settingsSpec.postprocessors : {};
+  let postprocessors = settingsSpec?.postprocessors ? settingsSpec.postprocessors : {};

  Object.keys(formData).forEach(key => {
    if (key in skip_keys) return;
@@ -815,7 +875,7 @@ export const postProcessFormData = (settingsSpec, formData) => {
    else
      new_data[key] = formData[key];
  });

  return new_data;
};
@@ -16,7 +16,6 @@ const ModelSettingsModal = forwardRef((props, ref) => {

  const [formData, setFormData] = useState(undefined);
  const onSettingsSubmit = props.onSettingsSubmit;
  const selectedModelKey = props.model ? props.model.key : null;

  const [schema, setSchema] = useState({'type': 'object', 'description': 'No model info object was passed to settings modal.'});
  const [uiSchema, setUISchema] = useState({});
@@ -30,7 +29,7 @@ const ModelSettingsModal = forwardRef((props, ref) => {
    if (props.model && props.model.base_model) {
      setModelEmoji(props.model.emoji);
      if (!(props.model.base_model in ModelSettings)) {
-        setSchema({'type': 'object', 'description': `Did not find settings schema for base model ${props.model.base_model}.`});
+        setSchema({'type': 'object', 'description': `Did not find settings schema for base model ${props.model.base_model}. Maybe you are missing a custom provider script import?`});
        setUISchema({});
        setModelName(props.model.base_model);
        return;
chainforge/react-server/src/PromptNode.js
@@ -199,9 +199,12 @@ const PromptNode = ({ data, id, type: node_type }) => {

  // On upstream changes
  useEffect(() => {
-    if (data.refresh && data.refresh === true) {
+    if (data.refresh === true) {
      setDataPropsForNode(id, { refresh: false });
      setStatus('warning');
    } else if (data.refreshLLMList === true) {
      llmListContainer?.current?.refreshLLMProviderList();
      setDataPropsForNode(id, { refreshLLMList: false });
    }
  }, [data]);
@@ -1,7 +1,7 @@
/*
- * @jest-environment node
+ * @jest-environment jsdom
 */
-import { LLM } from '../models';
+import { NativeLLM } from '../models';
import { expect, test } from '@jest/globals';
import { queryLLM, executejs, countQueries, ResponseInfo } from '../backend';
import { StandardizedLLMResponse, Dict } from '../typing';
@@ -32,14 +32,14 @@ test('count queries required', async () => {
  };

  // Try a number of different inputs
-  await test_count_queries([LLM.OpenAI_ChatGPT, LLM.Claude_v1], 3);
+  await test_count_queries([NativeLLM.OpenAI_ChatGPT, NativeLLM.Claude_v1], 3);
  await test_count_queries([{ name: "Claude", key: 'claude-test', emoji: "📚", model: "claude-v1", base_model: "claude-v1", temp: 0.5 }], 5);
});

test('call three LLMs with a single prompt', async () => {
  // Setup params to call
  const prompt = 'What is one major difference between French and English languages? Be brief.'
-  const llms = [LLM.OpenAI_ChatGPT, LLM.Claude_v1, LLM.PaLM2_Chat_Bison];
+  const llms = [NativeLLM.OpenAI_ChatGPT, NativeLLM.Claude_v1, NativeLLM.PaLM2_Chat_Bison];
  const n = 1;
  const progress_listener = (progress: {[key: symbol]: any}) => {
    console.log(JSON.stringify(progress));
@@ -2,7 +2,7 @@
 * @jest-environment node
 */
import { PromptPipeline } from '../query';
-import { LLM } from '../models';
+import { LLM, NativeLLM } from '../models';
import { expect, test } from '@jest/globals';
import { LLMResponseError, LLMResponseObject } from '../typing';
@@ -65,13 +65,13 @@ async function prompt_model(model: LLM): Promise<void> {

test('basic prompt pipeline with chatgpt', async () => {
  // Setup a simple pipeline with a prompt template, 1 variable and 3 input values
-  await prompt_model(LLM.OpenAI_ChatGPT);
+  await prompt_model(NativeLLM.OpenAI_ChatGPT);
}, 20000);

test('basic prompt pipeline with anthropic', async () => {
-  await prompt_model(LLM.Claude_v1);
+  await prompt_model(NativeLLM.Claude_v1);
}, 40000);

test('basic prompt pipeline with google palm2', async () => {
-  await prompt_model(LLM.PaLM2_Chat_Bison);
+  await prompt_model(NativeLLM.PaLM2_Chat_Bison);
}, 40000);
@@ -1,8 +1,8 @@
/*
- * @jest-environment node
+ * @jest-environment jsdom
 */
import { call_anthropic, call_chatgpt, call_google_palm, extract_responses, merge_response_objs } from '../utils';
-import { LLM } from '../models';
+import { LLM, NativeLLM } from '../models';
import { expect, test } from '@jest/globals';
import { LLMResponseObject } from '../typing';
@@ -13,7 +13,7 @@ test('merge response objects', () => {
    raw_response: ['x', 'y', 'z'],
    prompt: 'this is a test',
    query: {},
-    llm: LLM.OpenAI_ChatGPT,
+    llm: NativeLLM.OpenAI_ChatGPT,
    info: { var1: 'value1', var2: 'value2' },
    metavars: { meta1: 'meta1' },
  };
@@ -22,7 +22,7 @@ test('merge response objects', () => {
    raw_response: {B: 'B'},
    prompt: 'this is a test 2',
    query: {},
-    llm: LLM.OpenAI_ChatGPT,
+    llm: NativeLLM.OpenAI_ChatGPT,
    info: { varB1: 'valueB1', varB2: 'valueB2' },
    metavars: { metaB1: 'metaB1' },
  };
@@ -46,62 +46,62 @@ test('merge response objects', () => {

test('openai chat completions', async () => {
  // Call ChatGPT with a basic question, and n=2
-  const [query, response] = await call_chatgpt("Who invented modern playing cards? Keep your answer brief.", LLM.OpenAI_ChatGPT, 2, 1.0);
+  const [query, response] = await call_chatgpt("Who invented modern playing cards? Keep your answer brief.", NativeLLM.OpenAI_ChatGPT, 2, 1.0);
  console.log(response.choices[0].message);
  expect(response.choices).toHaveLength(2);
  expect(query).toHaveProperty('temperature');

  // Extract responses, check their type
-  const resps = extract_responses(response, LLM.OpenAI_ChatGPT);
+  const resps = extract_responses(response, NativeLLM.OpenAI_ChatGPT);
  expect(resps).toHaveLength(2);
  expect(typeof resps[0]).toBe('string');
}, 20000);

test('openai text completions', async () => {
  // Call OpenAI template with a basic question, and n=2
-  const [query, response] = await call_chatgpt("Who invented modern playing cards? The answer is ", LLM.OpenAI_Davinci003, 2, 1.0);
+  const [query, response] = await call_chatgpt("Who invented modern playing cards? The answer is ", NativeLLM.OpenAI_Davinci003, 2, 1.0);
  console.log(response.choices[0].text);
  expect(response.choices).toHaveLength(2);
  expect(query).toHaveProperty('n');

  // Extract responses, check their type
-  const resps = extract_responses(response, LLM.OpenAI_Davinci003);
+  const resps = extract_responses(response, NativeLLM.OpenAI_Davinci003);
  expect(resps).toHaveLength(2);
  expect(typeof resps[0]).toBe('string');
}, 20000);

test('anthropic models', async () => {
  // Call Anthropic's Claude with a basic question
-  const [query, response] = await call_anthropic("Who invented modern playing cards?", LLM.Claude_v1, 1, 1.0);
+  const [query, response] = await call_anthropic("Who invented modern playing cards?", NativeLLM.Claude_v1, 1, 1.0);
  console.log(response);
  expect(response).toHaveLength(1);
  expect(query).toHaveProperty('max_tokens_to_sample');

  // Extract responses, check their type
-  const resps = extract_responses(response, LLM.Claude_v1);
+  const resps = extract_responses(response, NativeLLM.Claude_v1);
  expect(resps).toHaveLength(1);
  expect(typeof resps[0]).toBe('string');
}, 20000);

test('google palm2 models', async () => {
  // Call Google's PaLM Chat API with a basic question
-  let [query, response] = await call_google_palm("Who invented modern playing cards?", LLM.PaLM2_Chat_Bison, 3, 0.7);
+  let [query, response] = await call_google_palm("Who invented modern playing cards?", NativeLLM.PaLM2_Chat_Bison, 3, 0.7);
  expect(response.candidates).toHaveLength(3);
  expect(query).toHaveProperty('candidateCount');

  // Extract responses, check their type
-  let resps = extract_responses(response, LLM.PaLM2_Chat_Bison);
+  let resps = extract_responses(response, NativeLLM.PaLM2_Chat_Bison);
  expect(resps).toHaveLength(3);
  expect(typeof resps[0]).toBe('string');
  console.log(JSON.stringify(resps));

  // Call Google's PaLM Text Completions API with a basic question
-  [query, response] = await call_google_palm("Who invented modern playing cards? The answer ", LLM.PaLM2_Text_Bison, 3, 0.7);
+  [query, response] = await call_google_palm("Who invented modern playing cards? The answer ", NativeLLM.PaLM2_Text_Bison, 3, 0.7);
  expect(response.candidates).toHaveLength(3);
  expect(query).toHaveProperty('maxOutputTokens');

  // Extract responses, check their type
-  resps = extract_responses(response, LLM.PaLM2_Chat_Bison);
+  resps = extract_responses(response, NativeLLM.PaLM2_Chat_Bison);
  expect(resps).toHaveLength(3);
  expect(typeof resps[0]).toBe('string');
  console.log(JSON.stringify(resps));
@@ -1,7 +1,7 @@
import markdownIt from "markdown-it";

-import { Dict, StringDict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse, ChatHistory, ChatHistoryInfo, isEqualChatHistory } from "./typing";
-import { LLM, getEnumName } from "./models";
+import { Dict, StringDict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse, ChatHistoryInfo, isEqualChatHistory } from "./typing";
+import { LLM, NativeLLM, getEnumName } from "./models";
import { APP_IS_RUNNING_LOCALLY, set_api_keys, FLASK_BASE_URL, call_flask_backend } from "./utils";
import StorageCache from "./cache";
import { PromptPipeline } from "./query";
@@ -12,11 +12,6 @@ import { PromptPermutationGenerator, PromptTemplate } from "./template";
// =================
// """

-let LLM_NAME_MAP = {};
-Object.entries(LLM).forEach(([key, value]) => {
-  LLM_NAME_MAP[value] = key;
-});

enum MetricType {
  KeyValue = 0,
  KeyValue_Numeric = 1,
@@ -553,13 +548,7 @@ export async function queryLLM(id: string,
  // Ensure llm param is an array
  if (typeof llm === 'string')
    llm = [ llm ];
+  llm = llm as (Array<string> | Array<Dict>);

-  for (let i = 0; i < llm.length; i++) {
-    const llm_spec = llm[i];
-    if (!(extract_llm_name(llm_spec) in LLM_NAME_MAP))
-      return {'error': `LLM named '${llm_spec}' is not supported.`};
-  }
-  llm = llm as (Array<string> | Array<Dict>);

  await setAPIKeys(api_keys);

@@ -1120,3 +1109,57 @@ export async function fetchOpenAIEval(evalname: string): Promise<Dict> {
    .then(response => response.json())
    .then(res => ({data: res}));
}

/**
 * Passes a Python script to load a custom model provider to the Flask backend.
 *
 * @param code The Python script to pass, as a string.
 * @returns a Promise with the JSON of the response. Will include an 'error' key if an error occurred; on success,
 *          a 'providers' key with a list of all loaded custom provider callbacks, as dicts.
 */
export async function initCustomProvider(code: string): Promise<Dict> {
  // Pass the Python script to the Flask backend,
  // asking it to load the custom provider(s) the script declares:
  return fetch(`${FLASK_BASE_URL}app/initCustomProvider`, {
    method: 'POST',
    headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
    body: JSON.stringify({ code })
  }).then(function(res) {
    return res.json();
  });
}
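For context, a rough sketch of how the front end might use this endpoint (the script string and error handling are illustrative, and `setCustomProviders` is the registration helper shown earlier, not part of this function):

  // Hypothetical usage: load the providers declared in a user-supplied Python script.
  initCustomProvider(pythonScriptText).then((result) => {
    if (result.error !== undefined)
      console.error(result.error);           // e.g., the script failed to import
    else
      setCustomProviders(result.providers);  // provider dicts returned by the Flask backend
  });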
/**
 * Asks the Python backend to remove the custom provider with name 'name'.
 *
 * @param name The name of the provider to remove. The name must match the name in the `ProviderRegistry`.
 * @returns a Promise with the JSON of the response. Will include an 'error' key if an error occurred; on success,
 *          a 'success' key with a true value.
 */
export async function removeCustomProvider(name: string): Promise<Dict> {
  // Ask the Flask backend to remove the named provider from its registry:
  return fetch(`${FLASK_BASE_URL}app/removeCustomProvider`, {
    method: 'POST',
    headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
    body: JSON.stringify({ name })
  }).then(function(res) {
    return res.json();
  });
}

/**
 * Asks the Python backend to load custom provider scripts that are cached in the user's local dir.
 *
 * @returns a Promise with the JSON of the response. Will include an 'error' key if an error occurred; on success,
 *          a 'providers' key with all loaded custom providers in an array. If there were none, returns an empty array.
 */
export async function loadCachedCustomProviders(): Promise<Dict> {
  return fetch(`${FLASK_BASE_URL}app/loadCachedCustomProviders`, {
    method: 'POST',
    headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
    body: "{}"
  }).then(function(res) {
    return res.json();
  });
}
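A minimal sketch of the intended page-load flow, assuming the UI wires this call to `setCustomProviders` from the settings-schema module (the wiring shown is illustrative):

  // On page load: restore providers cached by the Flask backend, then register them in the UI.
  loadCachedCustomProviders().then((result) => {
    if (result.error === undefined && Array.isArray(result.providers))
      setCustomProviders(result.providers);
  });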
@ -1,7 +1,8 @@
|
||||
/**
|
||||
* A list of all model APIs natively supported by ChainForge.
|
||||
*/
|
||||
export enum LLM {
|
||||
export type LLM = string | NativeLLM;
|
||||
export enum NativeLLM {
|
||||
// OpenAI Chat
|
||||
OpenAI_ChatGPT = "gpt-3.5-turbo",
|
||||
OpenAI_ChatGPT_16k = "gpt-3.5-turbo-16k",
|
||||
@ -70,6 +71,7 @@ export enum LLMProvider {
|
||||
Anthropic = "anthropic",
|
||||
Google = "google",
|
||||
HuggingFace = "hf",
|
||||
Custom = "__custom",
|
||||
}
|
||||
|
||||
/**
|
||||
@ -78,7 +80,7 @@ export enum LLMProvider {
|
||||
* @returns an `LLMProvider` describing what provider hosts the model
|
||||
*/
|
||||
export function getProvider(llm: LLM): LLMProvider | undefined {
|
||||
const llm_name = getEnumName(LLM, llm.toString());
|
||||
const llm_name = getEnumName(NativeLLM, llm.toString());
|
||||
if (llm_name?.startsWith('OpenAI'))
|
||||
return LLMProvider.OpenAI;
|
||||
else if (llm_name?.startsWith('Azure'))
|
||||
@ -91,6 +93,8 @@ export function getProvider(llm: LLM): LLMProvider | undefined {
|
||||
return LLMProvider.HuggingFace;
|
||||
else if (llm.toString().startsWith('claude'))
|
||||
return LLMProvider.Anthropic;
|
||||
else if (llm.toString().startsWith('__custom/'))
|
||||
return LLMProvider.Custom;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
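A quick sketch of how provider routing behaves after this change (the custom model name below is hypothetical):

  // Illustrative checks:
  getProvider(NativeLLM.OpenAI_ChatGPT);       // -> LLMProvider.OpenAI
  getProvider("__custom/MyProvider/model-a");  // -> LLMProvider.Custom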
@@ -101,21 +105,21 @@ export function getProvider(llm: LLM): LLMProvider | undefined {
# This 'cheap' version of controlling for rate limits is to wait a few seconds between batches of requests being sent off.
# If a model is missing from below, it means we must send and receive only 1 request at a time (synchronous).
# The following is only a guideline, and a bit on the conservative side. */
-export const RATE_LIMITS: { [key in LLM]?: [number, number] } = {
-  [LLM.OpenAI_ChatGPT]: [30, 10],  // max 30 requests a batch; wait 10 seconds between
-  [LLM.OpenAI_ChatGPT_0301]: [30, 10],
-  [LLM.OpenAI_ChatGPT_0613]: [30, 10],
-  [LLM.OpenAI_ChatGPT_16k]: [30, 10],
-  [LLM.OpenAI_ChatGPT_16k_0613]: [30, 10],
-  [LLM.OpenAI_GPT4]: [4, 15],  // max 4 requests a batch; wait 15 seconds between
-  [LLM.OpenAI_GPT4_0314]: [4, 15],
-  [LLM.OpenAI_GPT4_0613]: [4, 15],
-  [LLM.OpenAI_GPT4_32k]: [4, 15],
-  [LLM.OpenAI_GPT4_32k_0314]: [4, 15],
-  [LLM.OpenAI_GPT4_32k_0613]: [4, 15],
-  [LLM.Azure_OpenAI]: [30, 10],
-  [LLM.PaLM2_Text_Bison]: [4, 10],  // max 30 requests per minute; so do 4 per batch, 10 seconds between (conservative)
-  [LLM.PaLM2_Chat_Bison]: [4, 10],
+export let RATE_LIMITS: { [key in LLM]?: [number, number] } = {
+  [NativeLLM.OpenAI_ChatGPT]: [30, 10],  // max 30 requests a batch; wait 10 seconds between
+  [NativeLLM.OpenAI_ChatGPT_0301]: [30, 10],
+  [NativeLLM.OpenAI_ChatGPT_0613]: [30, 10],
+  [NativeLLM.OpenAI_ChatGPT_16k]: [30, 10],
+  [NativeLLM.OpenAI_ChatGPT_16k_0613]: [30, 10],
+  [NativeLLM.OpenAI_GPT4]: [4, 15],  // max 4 requests a batch; wait 15 seconds between
+  [NativeLLM.OpenAI_GPT4_0314]: [4, 15],
+  [NativeLLM.OpenAI_GPT4_0613]: [4, 15],
+  [NativeLLM.OpenAI_GPT4_32k]: [4, 15],
+  [NativeLLM.OpenAI_GPT4_32k_0314]: [4, 15],
+  [NativeLLM.OpenAI_GPT4_32k_0613]: [4, 15],
+  [NativeLLM.Azure_OpenAI]: [30, 10],
+  [NativeLLM.PaLM2_Text_Bison]: [4, 10],  // max 30 requests per minute; so do 4 per batch, 10 seconds between (conservative)
+  [NativeLLM.PaLM2_Chat_Bison]: [4, 10],
};

@@ -1,6 +1,6 @@
import { PromptTemplate, PromptPermutationGenerator } from "./template";
-import { LLM, RATE_LIMITS } from './models';
-import { Dict, LLMResponseError, LLMResponseObject, ChatHistory, isEqualChatHistory, ChatHistoryInfo } from "./typing";
+import { LLM, NativeLLM, RATE_LIMITS } from './models';
+import { Dict, LLMResponseError, LLMResponseObject, isEqualChatHistory, ChatHistoryInfo } from "./typing";
import { extract_responses, merge_response_objs, call_llm, mergeDicts } from "./utils";
import StorageCache from "./cache";
@@ -164,7 +164,7 @@ export class PromptPipeline {
    }

    if (!prompt.is_concrete())
-      throw Error(`Cannot send a prompt '${prompt}' to LLM: Prompt is a template.`)
+      throw new Error(`Cannot send a prompt '${prompt}' to LLM: Prompt is a template.`)

    // Get the cache of responses with respect to this prompt, + normalize format so it's always an array (of size >= 0)
    const cache_bucket = responses[prompt_str];
@@ -191,7 +191,7 @@ export class PromptPipeline {
        "query": cached_resp["query"],
        "responses": extracted_resps.slice(0, n),
        "raw_response": cached_resp["raw_response"],
-        "llm": cached_resp["llm"] || LLM.OpenAI_ChatGPT,
+        "llm": cached_resp["llm"] || NativeLLM.OpenAI_ChatGPT,
        // We want to use the new info, since 'vars' could have changed even though
        // the prompt text is the same (e.g., "this is a tool -> this is a {x} where x='tool'")
        "info": mergeDicts(info, chat_history?.fill_history),
@@ -223,16 +223,16 @@ export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, te
 */
export async function call_azure_openai(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
  if (!AZURE_OPENAI_KEY)
-    throw Error("Could not find an Azure OpenAPI Key to use. Double-check that your key is set in Settings or in your local environment.");
+    throw new Error("Could not find an Azure OpenAPI Key to use. Double-check that your key is set in Settings or in your local environment.");
  if (!AZURE_OPENAI_ENDPOINT)
-    throw Error("Could not find an Azure OpenAI Endpoint to use. Double-check that your endpoint is set in Settings or in your local environment.");
+    throw new Error("Could not find an Azure OpenAI Endpoint to use. Double-check that your endpoint is set in Settings or in your local environment.");

  const deployment_name: string = params?.deployment_name;
  const model_type: string = params?.model_type;
  if (!deployment_name)
-    throw Error("Could not find an Azure OpenAPI deployment name. Double-check that your deployment name is set in Settings or in your local environment.");
+    throw new Error("Could not find an Azure OpenAPI deployment name. Double-check that your deployment name is set in Settings or in your local environment.");
  if (!model_type)
-    throw Error("Could not find a model type specified for an Azure OpenAI model. Double-check that your deployment name is set in Settings or in your local environment.");
+    throw new Error("Could not find a model type specified for an Azure OpenAI model. Double-check that your deployment name is set in Settings or in your local environment.");

  const client = new AzureOpenAIClient(AZURE_OPENAI_ENDPOINT, new AzureKeyCredential(AZURE_OPENAI_KEY));

@@ -296,12 +296,12 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
 */
export async function call_anthropic(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
  if (!ANTHROPIC_API_KEY)
-    throw Error("Could not find an API key for Anthropic models. Double-check that your API key is set in Settings or in your local environment.");
+    throw new Error("Could not find an API key for Anthropic models. Double-check that your API key is set in Settings or in your local environment.");

  // Wrap the prompt in the provided template, or use the default Anthropic one
  const custom_prompt_wrapper: string = params?.custom_prompt_wrapper || (ANTHROPIC_HUMAN_PROMPT + " {prompt}" + ANTHROPIC_AI_PROMPT);
  if (!custom_prompt_wrapper.includes('{prompt}'))
-    throw Error("Custom prompt wrapper is missing required {prompt} template variable.");
+    throw new Error("Custom prompt wrapper is missing required {prompt} template variable.");
  const prompt_wrapper_template = new StringTemplate(custom_prompt_wrapper);
  let wrapped_prompt = prompt_wrapper_template.safe_substitute({prompt: prompt});

@@ -389,7 +389,7 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
 */
export async function call_google_palm(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
  if (!GOOGLE_PALM_API_KEY)
-    throw Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment.");
+    throw new Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment.");

  const is_chat_model = model.toString().includes('chat');

@@ -621,6 +621,40 @@ export async function call_huggingface(prompt: string, model: LLM, n: number = 1
  return [query, responses];
}

async function call_custom_provider(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
  if (!APP_IS_RUNNING_LOCALLY())
    throw new Error("The ChainForge app does not appear to be running locally. You can only call custom model providers if you are running ChainForge on your local machine, from a Flask app.")

  // The model to call is in format:
  // __custom/<provider_name>/<submodel name>
  // It may also exclude the final tag.
  // We extract the provider name (this is the name used in the Python backend's `ProviderRegistry`) and optionally, the submodel name
  const provider_path = model.substring(9);
  const provider_name = provider_path.substring(0, provider_path.indexOf('/'));
  const submodel_name = (provider_path.length === provider_name.length-1) ? undefined : provider_path.substring(provider_path.lastIndexOf('/')+1);

  let responses = [];
  const query = { prompt, model, temperature, ...params };

  // Call the custom provider n times
  while (responses.length < n) {
    let {response, error} = await call_flask_backend('callCustomProvider',
      { 'name': provider_name,
        'params': {
          prompt, model: submodel_name, temperature, ...params
        }
      });

    // Fail if an error is encountered
    if (error !== undefined || response === undefined)
      throw new Error(error);

    responses.push(response);
  }

  return [query, responses];
}
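For reference, a short illustrative walk-through of the path parsing above, using a hypothetical model string (the provider and submodel names are examples only):

  const model = "__custom/MyProvider/model-a";                                         // hypothetical
  const provider_path = model.substring(9);                                            // "MyProvider/model-a" ("__custom/" stripped)
  const provider_name = provider_path.substring(0, provider_path.indexOf('/'));        // "MyProvider"
  const submodel_name = provider_path.substring(provider_path.lastIndexOf('/') + 1);   // "model-a"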
/**
 * Switcher that routes the request to the appropriate API call function. If call doesn't exist, throws error.
 */
@@ -644,6 +678,8 @@ export async function call_llm(llm: LLM, prompt: string, n: number, temperature:
    call_api = call_anthropic;
  else if (llm_provider === LLMProvider.HuggingFace)
    call_api = call_huggingface;
+  else if (llm_provider === LLMProvider.Custom)
+    call_api = call_custom_provider;

  return call_api(prompt, llm, n, temperature, params);
}
@@ -744,8 +780,11 @@ export function extract_responses(response: Array<string | Dict> | Dict, llm: LL
    case LLMProvider.HuggingFace:
      return _extract_huggingface_responses(response as Dict[]);
    default:
-      throw new Error(`No method defined to extract responses for LLM ${llm}.`)
+      if (Array.isArray(response) && response.length > 0 && typeof response[0] === 'string')
+        return response as string[];
+      else
+        throw new Error(`No method defined to extract responses for LLM ${llm}.`)
  }
}

/**
@@ -1,8 +1,8 @@
import { queryLLM, executejs, executepy,
  fetchExampleFlow, fetchOpenAIEval, importCache,
  exportCache, countQueries, grabResponses,
-  generatePrompts,
-  evalWithLLM} from "./backend/backend";
+  generatePrompts, initCustomProvider,
+  removeCustomProvider, evalWithLLM, loadCachedCustomProviders } from "./backend/backend";

const clone = (obj) => JSON.parse(JSON.stringify(obj));

@@ -30,6 +30,12 @@ async function _route_to_js_backend(route, params) {
      return fetchExampleFlow(params.name);
    case 'fetchOpenAIEval':
      return fetchOpenAIEval(params.name);
+    case 'initCustomProvider':
+      return initCustomProvider(params.code);
+    case 'removeCustomProvider':
+      return removeCustomProvider(params.name);
+    case 'loadCachedCustomProviders':
+      return loadCachedCustomProviders();
    default:
      throw new Error(`Could not find backend function for route named ${route}`);
  }
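For orientation, a hedged sketch of how the new routes can be dispatched through this switcher (assuming the handler is reachable at the call site; the script text and provider name are placeholders):

  // Illustrative dispatches only:
  await _route_to_js_backend('initCustomProvider', { code: '<python script text>' });
  await _route_to_js_backend('loadCachedCustomProviders', {});
  await _route_to_js_backend('removeCustomProvider', { name: 'MyProvider' });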
chainforge/react-server/src/store.js
@@ -6,6 +6,7 @@ import {
} from 'react-flow-renderer';
import { escapeBraces } from './backend/template';
import { filterDict } from './backend/utils';
+import { APP_IS_RUNNING_LOCALLY } from './backend/utils';

// Initial project settings
const initialAPIKeys = {};
@@ -27,12 +28,30 @@ export const colorPalettes = {

const refreshableOutputNodeTypes = new Set(['evaluator', 'prompt', 'inspect', 'vis', 'llmeval', 'textfields', 'chat', 'simpleval']);

export let initLLMProviders = [
  { name: "GPT3.5", emoji: "🤖", model: "gpt-3.5-turbo", base_model: "gpt-3.5-turbo", temp: 1.0 },  // The base_model designates what settings form will be used, and must be unique.
  { name: "GPT4", emoji: "🥵", model: "gpt-4", base_model: "gpt-4", temp: 1.0 },
  { name: "Claude", emoji: "📚", model: "claude-2", base_model: "claude-v1", temp: 0.5 },
  { name: "PaLM2", emoji: "🦬", model: "chat-bison-001", base_model: "palm2-bison", temp: 0.7 },
  { name: "Azure OpenAI", emoji: "🔷", model: "azure-openai", base_model: "azure-openai", temp: 1.0 },
  { name: "HuggingFace", emoji: "🤗", model: "tiiuae/falcon-7b-instruct", base_model: "hf", temp: 1.0 },
];
if (APP_IS_RUNNING_LOCALLY()) {
  initLLMProviders.push({ name: "Dalai (Alpaca.7B)", emoji: "🦙", model: "alpaca.7B", base_model: "dalai", temp: 0.5 });
}

// A global store of variables, used for maintaining state
// across ChainForge and ReactFlow components.
const useStore = create((set, get) => ({
  nodes: [],
  edges: [],

  // Available LLMs in ChainForge, in the format expected by LLMListItems.
  AvailableLLMs: [...initLLMProviders],
  setAvailableLLMs: (llmProviderList) => {
    set({AvailableLLMs: llmProviderList});
  },

  // Keeping track of LLM API keys
  apiKeys: initialAPIKeys,
  setAPIKeys: (apiKeys) => {