Custom providers (#122)

* Removed Python backend files that are no longer used (everything in `promptengine`)

* Added `providers` subpackage, with `CustomProviderProtocol`, `provider` decorator, and global singleton `ProviderRegistry` (see the minimal sketch below)

* Added a tab for custom providers, and a script dropzone, in the ChainForge Global Settings UI

* List custom providers in the Global Settings screen once added. 

* Added ability to remove custom providers by clicking X. 

* Custom provider functions are defined synchronously, but are called asynchronously by the backend.

* Add Cohere custom provider example in examples/

* Cache the custom provider scripts and load them upon page load

* Rebuild the React front-end and update the package version

* Fixed a bug that occurred when a custom provider was deleted while the settings screen was open on that provider
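
For orientation, here is a minimal sketch of what registering a custom provider looks like with the new `@provider` decorator, modeled on the Cohere example added below. The `MyEcho` name, `echo-v1` model, and the echoed reply are placeholders, not part of this commit:

from chainforge.providers import provider

# Register a trivial provider with a single model and a sequential (blocking) rate limit.
# The completion function is written synchronously; ChainForge's backend awaits it via an executor.
@provider(name="MyEcho",
          emoji="🔁",
          models=["echo-v1"],
          rate_limit="sequential")
def EchoCompletion(prompt: str, model: str, **kwargs) -> str:
    # Whatever string is returned here is shown as the model's response in the UI.
    return f"[{model}] {prompt}"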
ianarawjo 2023-08-27 15:11:42 -04:00 committed by GitHub
parent f43861f075
commit 0134dbf59b
39 changed files with 1126 additions and 1341 deletions

View File

@@ -0,0 +1,55 @@
"""
A simple custom model provider to add to the ChainForge interface,
to support Cohere AI text completions through their Python API.
NOTE: You must have the `cohere` package installed and an API key.
"""
from chainforge.providers import provider
import cohere
# Init the Cohere client (replace with your Cohere API Key)
co = cohere.Client('<YOUR_API_KEY>')
# JSON schemas to pass to react-jsonschema-form, one for this endpoint's settings and one to describe the settings UI.
COHERE_SETTINGS_SCHEMA = {
"settings": {
"temperature": {
"type": "number",
"title": "temperature",
"description": "Controls the 'creativity' or randomness of the response.",
"default": 0.75,
"minimum": 0,
"maximum": 5.0,
"multipleOf": 0.01,
},
"max_tokens": {
"type": "integer",
"title": "max_tokens",
"description": "Maximum number of tokens to generate in the response.",
"default": 100,
"minimum": 1,
"maximum": 1024,
},
},
"ui": {
"temperature": {
"ui:help": "Defaults to 1.0.",
"ui:widget": "range"
},
"max_tokens": {
"ui:help": "Defaults to 100.",
"ui:widget": "range"
},
}
}
# Our custom model provider for Cohere's text generation API.
@provider(name="Cohere",
emoji="🖇",
models=['command', 'command-nightly', 'command-light', 'command-light-nightly'],
rate_limit="sequential", # enter "sequential" for blocking; an integer N > 0 means N is the max mumber of requests per minute.
settings_schema=COHERE_SETTINGS_SCHEMA)
def CohereCompletion(prompt: str, model: str, temperature: float = 0.75, **kwargs) -> str:
print(f"Calling Cohere model {model} with prompt '{prompt}'...")
response = co.generate(model=model, prompt=prompt, temperature=temperature, **kwargs)
return response.generations[0].text

View File

@@ -1,11 +1,12 @@
import json, os, asyncio, sys, traceback
import json, os, sys, asyncio, time
from dataclasses import dataclass
from enum import Enum
from typing import List
from statistics import mean, median, stdev
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from chainforge.promptengine.utils import LLM, call_dalai
from chainforge.providers.dalai import call_dalai
from chainforge.providers import ProviderRegistry
import requests as py_requests
""" =================
@@ -16,6 +17,7 @@ import requests as py_requests
# Setup Flask app to serve static version of React front-end
HOSTNAME = "localhost"
PORT = 8000
# SESSION_TOKEN = secrets.token_hex(32)
BUILD_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'react-server', 'build')
STATIC_DIR = os.path.join(BUILD_DIR, 'static')
app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
@@ -27,10 +29,6 @@ cors = CORS(app, resources={r"/*": {"origins": "*"}})
CACHE_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cache')
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'examples')
LLM_NAME_MAP = {}
for model in LLM:
LLM_NAME_MAP[model.value] = model
class MetricType(Enum):
KeyValue = 0
KeyValue_Numeric = 1
@@ -221,6 +219,22 @@ def run_over_responses(eval_func, responses: list, scope: str) -> list:
}
return responses
async def make_sync_call_async(sync_method, *args, **params):
"""
Makes a blocking synchronous call asynchronous, so that it can be awaited.
NOTE: This is necessary for LLM APIs that do not yet support async (e.g. Google PaLM).
"""
loop = asyncio.get_running_loop()
method = sync_method
if len(params) > 0:
def partial_sync_meth(*a):
return sync_method(*a, **params)
method = partial_sync_meth
return await loop.run_in_executor(None, method, *args)
def exclude_key(d, key_to_exclude):
return {k: v for k, v in d.items() if k != key_to_exclude}
""" ===================
FLASK SERVER ROUTES
@@ -261,7 +275,7 @@ def executepy():
data = request.get_json()
# Check that all required info is here:
if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope'}):
if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope', 'token'}):
return jsonify({'error': 'POST data is improper format.'})
if not isinstance(data['id'], str) or len(data['id']) == 0:
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
@@ -273,7 +287,7 @@ def executepy():
if (isinstance(responses, str) or not isinstance(responses, list)) or (len(responses) > 0 and any([not isinstance(r, dict) for r in responses])):
return jsonify({'error': 'POST data responses is improper format.'})
# add the path to any scripts to the path:
# Add the path to any scripts to the path:
try:
if 'script_paths' in data:
for script_path in data['script_paths']:
@@ -367,7 +381,7 @@ def fetchOpenAIEval():
POST'd data should be in form:
{
name: <str> # The name of the eval to grab (without .cforge extension)
'name': <str> # The name of the eval to grab (without .cforge extension)
}
"""
# Verify post'd data
@@ -404,9 +418,8 @@ def fetchOpenAIEval():
return jsonify({'error': f"Error creating a new directory 'oaievals' at filepath {oaievals_cache_dir}: {str(e)}"})
# Download the preconverted OpenAI eval from the GitHub main branch for ChainForge
import requests
_url = f"https://raw.githubusercontent.com/ianarawjo/ChainForge/main/chainforge/oaievals/{evalname}.cforge"
response = requests.get(_url)
response = py_requests.get(_url)
# Check if the request was successful (status code 200)
if response.status_code == 200:
@@ -449,9 +462,9 @@ def makeFetchCall():
POST'd data should be in form:
{
url: <str> # the url to fetch from
headers: <dict> # a JSON object of the headers
body: <dict> # the request payload, as JSON
'url': <str> # the url to fetch from
'headers': <dict> # a JSON object of the headers
'body': <dict> # the request payload, as JSON
}
"""
# Verify post'd data
@@ -486,8 +499,6 @@ async def callDalai():
if not set(data.keys()).issuperset({'prompt', 'model', 'server', 'n', 'temperature'}):
return jsonify({'error': 'POST data is improper format.'})
data['model'] = LLM_NAME_MAP[data['model']]
try:
query, response = await call_dalai(**data)
except Exception as e:
@@ -498,6 +509,189 @@ async def callDalai():
return ret
@app.route('/app/initCustomProvider', methods=['POST'])
def initCustomProvider():
"""
    Initializes custom model provider(s) defined in a Python script,
and returns specs for the front-end UI provider dropdown and the providers' settings window.
POST'd data should be in form:
{
'code': <str> # the Python script to save + execute,
}
"""
# Verify post'd data
data = request.get_json()
if 'code' not in data:
return jsonify({'error': 'POST data is improper format.'})
# Sanity check that the code actually registers a provider
if '@provider' not in data['code']:
return jsonify({'error': """Did not detect a @provider decorator. Custom provider scripts should register at least one @provider.
Do `from chainforge.providers import provider` and decorate your provider completion function with @provider."""})
# Establish the custom provider script cache directory
provider_scripts_dir = os.path.join(CACHE_DIR, "provider_scripts")
if not os.path.isdir(provider_scripts_dir):
# Create the directory
try:
os.mkdir(provider_scripts_dir)
except Exception as e:
return jsonify({'error': f"Error creating a new directory 'provider_scripts' at filepath {provider_scripts_dir}: {str(e)}"})
# For keeping track of what script registered providers came from
script_id = str(round(time.time()*1000))
ProviderRegistry.set_curr_script_id(script_id)
ProviderRegistry.watch_next_registered()
# Attempt to run the Python script, in context
try:
exec(data['code'], globals(), None)
# This should have registered one or more new CustomModelProviders.
except Exception as e:
return jsonify({'error': f'Error while executing custom provider code:\n{str(e)}'})
# Check whether anything was updated, and what
new_registries = ProviderRegistry.last_registered()
if len(new_registries) == 0: # Determine whether there's at least one custom provider.
return jsonify({'error': 'Did not detect any custom providers added to the registry. Make sure you are registering your provider with @provider correctly.'})
# At least one provider was registered; detect if it had a past script id and remove those file(s) from the cache
if any((v is not None for v in new_registries.values())):
# For every registered provider that was overwritten, remove the cache'd script(s) associated with it:
past_script_ids = [v for v in new_registries.values() if v is not None]
for sid in past_script_ids:
past_script_path = os.path.join(provider_scripts_dir, f"{sid}.py")
try:
if os.path.isfile(past_script_path):
os.remove(past_script_path)
except Exception as e:
return jsonify({'error': f"Error removing cache'd custom provider script at filepath {past_script_path}: {str(e)}"})
# Get the names and specs of all currently registered CustomModelProviders,
# and pass that info to the front-end (excluding the func):
registered_providers = [exclude_key(d, 'func') for d in ProviderRegistry.get_all()]
# Copy the passed Python script to a local file in the package directory
try:
with open(os.path.join(provider_scripts_dir, f"{script_id}.py"), 'w') as f:
f.write(data['code'])
except Exception as e:
return jsonify({'error': f"Error saving script 'provider_scripts' at filepath {provider_scripts_dir}: {str(e)}"})
# Return all loaded providers
return jsonify({'providers': registered_providers})
@app.route('/app/loadCachedCustomProviders', methods=['POST'])
def loadCachedCustomProviders():
"""
    Initializes all custom model provider(s) cached in the local provider_scripts directory.
"""
provider_scripts_dir = os.path.join(CACHE_DIR, "provider_scripts")
if not os.path.isdir(provider_scripts_dir):
# No providers to load.
return jsonify({'providers': []})
try:
for file_name in os.listdir(provider_scripts_dir):
file_path = os.path.join(provider_scripts_dir, file_name)
if os.path.isfile(file_path) and os.path.splitext(file_path)[1] == '.py':
# For keeping track of what script registered providers came from
ProviderRegistry.set_curr_script_id(os.path.splitext(file_name)[0])
# Read the Python script
with open(file_path, 'r') as f:
code = f.read()
# Try to execute it in the global context
try:
exec(code, globals(), None)
except Exception as code_exc:
# Remove the script file associated w the failed execution
os.remove(file_path)
raise code_exc
except Exception as e:
return jsonify({'error': f'Error while loading custom providers from cache: \n{str(e)}'})
# Get the names and specs of all currently registered CustomModelProviders,
# and pass that info to the front-end (excluding the func):
registered_providers = [exclude_key(d, 'func') for d in ProviderRegistry.get_all()]
return jsonify({'providers': registered_providers})
@app.route('/app/removeCustomProvider', methods=['POST'])
def removeCustomProvider():
"""
    Removes a registered custom model provider from the `ProviderRegistry`,
    and deletes its cached script file (if any).
POST'd data should be in form:
{
'name': <str> # a name that refers to the registered custom provider in the `ProviderRegistry`
}
"""
# Verify post'd data
data = request.get_json()
name = data.get('name')
if name is None:
return jsonify({'error': 'POST data is improper format.'})
if not ProviderRegistry.has(name):
return jsonify({'error': f'Could not find a custom provider named "{name}"'})
# Get the script id associated with the provider we're about to remove
script_id = ProviderRegistry.get(name).get('script_id')
# Remove the custom provider from the registry
ProviderRegistry.remove(name)
# Attempt to delete associated script from cache
if script_id:
script_path = os.path.join(CACHE_DIR, "provider_scripts", f"{script_id}.py")
if os.path.isfile(script_path):
os.remove(script_path)
return jsonify({'success': True})
@app.route('/app/callCustomProvider', methods=['POST'])
async def callCustomProvider():
"""
Calls a custom model provider and returns the response.
POST'd data should be in form:
{
'name': <str> # the name of the provider in the `ProviderRegistry`
'params': <dict> # the params (prompt, model, etc) to pass to the provider function.
}
"""
# Verify post'd data
data = request.get_json()
if not set(data.keys()).issuperset({'name', 'params'}):
return jsonify({'error': 'POST data is improper format.'})
# Load the name of the provider
name = data['name']
params = data['params']
# Double-check that the custom provider exists in the registry, and (if passed) a model with that name exists
provider_spec = ProviderRegistry.get(name)
if provider_spec is None:
return jsonify({'error': f'Could not find provider named {name}. Perhaps you need to import a custom provider script?'})
# Call + await the custom provider function, passing in the JSON payload as kwargs
try:
response = await make_sync_call_async(provider_spec.get('func'), **params)
except Exception as e:
return jsonify({'error': f'Error encountered while calling custom provider function: {str(e)}'})
# Return the response
return jsonify({'response': response})
def run_server(host="", port=8000, cmd_args=None):
global HOSTNAME, PORT
HOSTNAME = host
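
For reference, a rough sketch of exercising the two new provider routes from a client, following the POST formats documented in their docstrings above. The script path, provider name, and printed shapes are illustrative; the Flask server is assumed to be running locally on its default port 8000:

import requests

BASE = "http://localhost:8000/app"

# Register custom provider(s) by POSTing the source code of a provider script
# (e.g., the Cohere example from examples/; the path below is a placeholder).
with open("path/to/cohere_provider_example.py") as f:
    resp = requests.post(f"{BASE}/initCustomProvider", json={"code": f.read()})
print(resp.json())  # e.g. {'providers': [{'name': 'Cohere', ...}]} or {'error': ...}

# Call a registered provider by name through the generic route.
resp = requests.post(f"{BASE}/callCustomProvider", json={
    "name": "Cohere",
    "params": {"prompt": "Say hello.", "model": "command", "temperature": 0.75},
})
print(resp.json())  # e.g. {'response': '...'} or {'error': ...}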

View File

@@ -1 +0,0 @@

View File

@@ -1,82 +0,0 @@
"""
A list of all model APIs natively supported by ChainForge.
"""
from enum import Enum
class LLM(str, Enum):
""" OpenAI Chat """
OpenAI_ChatGPT = "gpt-3.5-turbo"
OpenAI_ChatGPT_16k = "gpt-3.5-turbo-16k"
OpenAI_ChatGPT_16k_0613 = "gpt-3.5-turbo-16k-0613"
OpenAI_ChatGPT_0301 = "gpt-3.5-turbo-0301"
OpenAI_ChatGPT_0613 = "gpt-3.5-turbo-0613"
OpenAI_GPT4 = "gpt-4"
OpenAI_GPT4_0314 = "gpt-4-0314"
OpenAI_GPT4_0613 = "gpt-4-0613"
OpenAI_GPT4_32k = "gpt-4-32k"
OpenAI_GPT4_32k_0314 = "gpt-4-32k-0314"
OpenAI_GPT4_32k_0613 = "gpt-4-32k-0613"
""" OpenAI Text Completions """
OpenAI_Davinci003 = "text-davinci-003"
OpenAI_Davinci002 = "text-davinci-002"
""" Azure OpenAI Endpoints """
Azure_OpenAI = "azure-openai"
""" Dalai-served models (Alpaca and Llama) """
Dalai_Alpaca_7B = "alpaca.7B"
Dalai_Alpaca_13B = "alpaca.13B"
Dalai_Llama_7B = "llama.7B"
Dalai_Llama_13B = "llama.13B"
Dalai_Llama_30B = "llama.30B"
Dalai_Llama_65B = "llama.65B"
""" Anthropic """
# Our largest model, ideal for a wide range of more complex tasks. Using this model name
# will automatically switch you to newer versions of claude-v1 as they are released.
Claude_v1 = "claude-v1"
# An earlier version of claude-v1
Claude_v1_0 = "claude-v1.0"
# An improved version of claude-v1. It is slightly improved at general helpfulness,
# instruction following, coding, and other tasks. It is also considerably better with
# non-English languages. This model also has the ability to role play (in harmless ways)
# more consistently, and it defaults to writing somewhat longer and more thorough responses.
Claude_v1_2 = "claude-v1.2"
# A significantly improved version of claude-v1. Compared to claude-v1.2, it's more robust
# against red-team inputs, better at precise instruction-following, better at code, and better
# and non-English dialogue and writing.
Claude_v1_3 = "claude-v1.3"
# A smaller model with far lower latency, sampling at roughly 40 words/sec! Its output quality
# is somewhat lower than claude-v1 models, particularly for complex tasks. However, it is much
# less expensive and blazing fast. We believe that this model provides more than adequate performance
# on a range of tasks including text classification, summarization, and lightweight chat applications,
# as well as search result summarization. Using this model name will automatically switch you to newer
# versions of claude-instant-v1 as they are released.
Claude_v1_instant = "claude-instant-v1"
""" Google models """
PaLM2_Text_Bison = "text-bison-001" # it's really models/text-bison-001, but that's confusing
PaLM2_Chat_Bison = "chat-bison-001"
# LLM APIs often have rate limits, which control number of requests. E.g., OpenAI: https://platform.openai.com/account/rate-limits
# For a basic organization in OpenAI, GPT3.5 is currently 3500 and GPT4 is 200 RPM (requests per minute).
# For Anthropic evaluaton preview of Claude, can only send 1 request at a time (synchronously).
# This 'cheap' version of controlling for rate limits is to wait a few seconds between batches of requests being sent off.
# If a model is missing from below, it means we must send and receive only 1 request at a time (synchronous).
# The following is only a guideline, and a bit on the conservative side.
RATE_LIMITS = {
LLM.OpenAI_ChatGPT: (30, 10), # max 30 requests a batch; wait 10 seconds between
LLM.OpenAI_ChatGPT_0301: (30, 10),
LLM.OpenAI_GPT4: (4, 15), # max 4 requests a batch; wait 15 seconds between
LLM.OpenAI_GPT4_0314: (4, 15),
LLM.OpenAI_GPT4_32k: (4, 15),
LLM.OpenAI_GPT4_32k_0314: (4, 15),
LLM.PaLM2_Text_Bison: (4, 10), # max 30 requests per minute; so do 4 per batch, 10 seconds between (conservative)
LLM.PaLM2_Chat_Bison: (4, 10),
}

View File

@@ -1,276 +0,0 @@
from abc import abstractmethod
from typing import List, Dict, Tuple, Iterator, Union, Optional
import json, os, asyncio, random, string
from chainforge.promptengine.utils import call_chatgpt, call_dalai, call_anthropic, call_google_palm, call_azure_openai, is_valid_filepath, is_valid_json, extract_responses, merge_response_objs
from chainforge.promptengine.models import LLM, RATE_LIMITS
from chainforge.promptengine.template import PromptTemplate, PromptPermutationGenerator
class LLMResponseException(Exception):
""" Raised when there is an error generating a single response from an LLM """
pass
"""
Abstract class that captures a generic querying interface to prompt LLMs
"""
class PromptPipeline:
def __init__(self, storageFile: str):
if not is_valid_filepath(storageFile):
raise IOError(f"Filepath {storageFile} is invalid, or you do not have write access.")
self._filepath = storageFile
@abstractmethod
def gen_prompts(self, properties) -> Iterator[PromptTemplate]:
raise NotImplementedError("Please Implement the gen_prompts method")
async def gen_responses(self, properties, llm: LLM, n: int = 1, temperature: float = 1.0, **llm_params) -> Iterator[Union[Dict, LLMResponseException]]:
"""
Calls LLM 'llm' with all prompts, and yields responses as dicts in format {prompt, query, response, llm, info}.
Queries are sent off asynchronously (if possible).
Yields responses as they come in. All LLM calls that yield errors (e.g., 'rate limit' error)
will yield an individual LLMResponseException, so downstream tasks must check for this exception type.
By default, for each response successfully collected, this also saves reponses to disk as JSON at the filepath given during init.
(Very useful for saving money in case something goes awry!)
To clear the cached responses, call clear_cached_responses().
NOTE: The reason we collect, rather than raise, LLMResponseExceptions is because some API calls
may still succeed, even if some fail. We don't want to stop listening to pending API calls,
because we may lose money. Instead, we fail selectively.
Do not override this function.
"""
# Double-check that properties is the correct type (JSON dict):
if not is_valid_json(properties):
raise ValueError("Properties argument is not valid JSON.")
# Load any cache'd responses
responses = self._load_cached_responses()
# Query LLM with each prompt, yield + cache the responses
tasks = []
max_req, wait_secs = RATE_LIMITS[llm] if llm in RATE_LIMITS else (1, 0)
num_queries_sent = -1
for prompt in self.gen_prompts(properties):
if isinstance(prompt, PromptTemplate) and not prompt.is_concrete():
raise Exception(f"Cannot send a prompt '{prompt}' to LLM: Prompt is a template.")
prompt_str = str(prompt)
info = prompt.fill_history
metavars = prompt.metavars
cached_resp = responses[prompt_str] if prompt_str in responses else None
extracted_resps = cached_resp["responses"] if cached_resp is not None else []
# First check if there is already a response for this item under these settings. If so, we can save an LLM call:
if cached_resp and len(extracted_resps) >= n:
print(f" - Found cache'd response for prompt {prompt_str}. Using...")
yield {
"prompt": prompt_str,
"query": cached_resp["query"],
"responses": extracted_resps[:n],
"raw_response": cached_resp["raw_response"],
"llm": cached_resp["llm"] if "llm" in cached_resp else LLM.OpenAI_ChatGPT.value,
# We want to use the new info, since 'vars' could have changed even though
# the prompt text is the same (e.g., "this is a tool -> this is a {x} where x='tool'")
"info": info,
"metavars": metavars
}
continue
num_queries_sent += 1
if max_req > 1:
# Call the LLM asynchronously to generate a response, sending off
# requests in batches of size 'max_req' separated by seconds 'wait_secs' to avoid hitting rate limit
tasks.append(self._prompt_llm(llm=llm,
prompt=prompt,
n=n,
temperature=temperature,
past_resp_obj=cached_resp,
query_number=num_queries_sent,
rate_limit_batch_size=max_req,
rate_limit_wait_secs=wait_secs,
**llm_params))
else:
# Block. Await + yield a single LLM call.
_, query, response, past_resp_obj = await self._prompt_llm(llm=llm,
prompt=prompt,
n=n,
temperature=temperature,
past_resp_obj=cached_resp,
**llm_params)
# Check for selective failure
if query is None and isinstance(response, LLMResponseException):
yield response # yield the LLMResponseException
continue
# Create a response obj to represent the response
resp_obj = {
"prompt": str(prompt),
"query": query,
"responses": extract_responses(response, llm),
"raw_response": response,
"llm": llm.value,
"info": info,
"metavars": metavars
}
# Merge the response obj with the past one, if necessary
if past_resp_obj is not None:
resp_obj = merge_response_objs(resp_obj, past_resp_obj)
# Save the current state of cache'd responses to a JSON file
responses[resp_obj["prompt"]] = resp_obj
self._cache_responses(responses)
print(f" - collected response from {llm.value} for prompt:", resp_obj['prompt'])
# Yield the response
yield resp_obj
# Yield responses as they come in
for task in asyncio.as_completed(tasks):
# Collect the response from the earliest completed task
prompt, query, response, past_resp_obj = await task
# Check for selective failure
if query is None and isinstance(response, LLMResponseException):
yield response # yield the LLMResponseException
continue
# Each prompt has a history of what was filled in from its base template.
# This data --like, "class", "language", "library" etc --can be useful when parsing responses.
info = prompt.fill_history
metavars = prompt.metavars
# Create a response obj to represent the response
resp_obj = {
"prompt": str(prompt),
"query": query,
"responses": extract_responses(response, llm),
"raw_response": response,
"llm": llm.value,
"info": info,
"metavars": metavars,
}
# Merge the response obj with the past one, if necessary
if past_resp_obj is not None:
resp_obj = merge_response_objs(resp_obj, past_resp_obj)
# Save the current state of cache'd responses to a JSON file
# NOTE: We do this to save money --in case something breaks between calls, can ensure we got the data!
responses[resp_obj["prompt"]] = resp_obj
self._cache_responses(responses)
print(f" - collected response from {llm.value} for prompt:", resp_obj['prompt'])
# Yield the response
yield resp_obj
def _load_cached_responses(self) -> Dict:
"""
Loads saved responses of JSON at self._filepath.
Useful for continuing if computation was interrupted halfway through.
"""
if os.path.isfile(self._filepath):
with open(self._filepath, encoding="utf-8") as f:
responses = json.load(f)
return responses
else:
return {}
def _cache_responses(self, responses) -> None:
with open(self._filepath, "w", encoding='utf-8') as f:
json.dump(responses, f)
def clear_cached_responses(self) -> None:
self._cache_responses({})
async def _prompt_llm(self,
llm: LLM,
prompt: PromptTemplate,
n: int = 1,
temperature: float = 1.0,
past_resp_obj: Optional[Dict] = None,
query_number: Optional[int] = None,
rate_limit_batch_size: Optional[int] = None,
rate_limit_wait_secs: Optional[float] = None,
**llm_params) -> Tuple[str, Dict, Union[List, Dict], Union[Dict, None]]:
# Detect how many responses we have already (from cache obj past_resp_obj)
if past_resp_obj is not None:
# How many *new* queries we need to send:
# NOTE: The check n > len(past_resp_obj["responses"]) should occur prior to calling this function.
n = n - len(past_resp_obj["responses"])
# Block asynchronously when we exceed rate limits
if query_number is not None and rate_limit_batch_size is not None and rate_limit_wait_secs is not None and rate_limit_batch_size >= 1 and rate_limit_wait_secs > 0:
batch_num = int(query_number / rate_limit_batch_size)
if batch_num > 0:
# We've exceeded the estimated batch rate limit and need to wait the appropriate seconds before sending off new API calls:
wait_secs = rate_limit_wait_secs * batch_num
if query_number % rate_limit_batch_size == 0: # Print when we start blocking, for each batch
print(f"Batch rate limit of {rate_limit_batch_size} reached for LLM {llm}. Waiting {wait_secs} seconds until sending request batch #{batch_num}...")
await asyncio.sleep(wait_secs)
# Get the correct API call for the given LLM:
call_llm = None
if llm.name[:6] == 'OpenAI':
call_llm = call_chatgpt
elif llm.name[:5] == 'Azure':
call_llm = call_azure_openai
elif llm.name[:5] == 'PaLM2':
call_llm = call_google_palm
elif llm.name[:5] == 'Dalai':
call_llm = call_dalai
elif llm.value[:6] == 'claude':
call_llm = call_anthropic
else:
raise Exception(f"Language model {llm} is not supported.")
# Now try to call the API. If it fails for whatever reason, 'soft fail' by returning
# an LLMResponseException object as the 'response'.
try:
query, response = await call_llm(prompt=str(prompt), model=llm, n=n, temperature=temperature, **llm_params)
except Exception as e:
return prompt, None, LLMResponseException(str(e)), None
return prompt, query, response, past_resp_obj
"""
Most basic prompt pipeline: given a prompt (and any variables, if it's a template),
query the LLM with all prompt permutations, and cache responses.
"""
class PromptLLM(PromptPipeline):
def __init__(self, template: str, storageFile: str):
self._template = PromptTemplate(template)
super().__init__(storageFile)
def gen_prompts(self, properties: dict) -> Iterator[PromptTemplate]:
gen_prompts = PromptPermutationGenerator(self._template)
return gen_prompts(properties)
"""
A dummy class that spoofs LLM responses. Used for testing.
"""
class PromptLLMDummy(PromptLLM):
def __init__(self, template: str, storageFile: str):
# Hijack the 'extract_responses' method so that for whichever 'llm' parameter,
# it will just return the response verbatim (since dummy responses will always be strings)
global extract_responses
extract_responses = lambda response, llm: response
super().__init__(template, storageFile)
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0, past_resp_obj: Optional[Dict] = None, **params) -> Tuple[Dict, Dict]:
# Wait a random amount of time, to simulate wait times from real queries
await asyncio.sleep(random.uniform(0.1, 3))
if random.random() > 0.2:
# Return a random string of characters of random length (within a predefined range)
return prompt, {'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(random.randint(25, 80))) for _ in range(n)], past_resp_obj
else:
# Return a mock 'error' making the API request
return prompt, None, LLMResponseException('Dummy error'), past_resp_obj

View File

@@ -1,257 +0,0 @@
import re
from string import Template
from typing import Dict, List, Union
def escape_dollar_signs(s: str) -> str:
pattern = r'\$(?![{])'
replaced_string = re.sub(pattern, '$$', s)
return replaced_string
class PromptTemplate:
"""
Wrapper around string.Template. Use to generate prompts fast.
Example usage:
prompt_temp = PromptTemplate('Can you list all the cities in the country ${country} by the cheapest ${domain} prices?')
concrete_prompt = prompt_temp.fill({
"country": "France",
"domain": "rent"
});
print(concrete_prompt)
# Fill can also fill the prompt only partially, which gives us a new prompt template:
partial_prompt = prompt_temp.fill({
"domain": "rent"
});
print(partial_prompt)
"""
def __init__(self, templateStr):
"""
Initialize a PromptTemplate with a string in string.Template format.
(See https://docs.python.org/3/library/string.html#template-strings for more details.)
"""
# NOTE: ChainForge only supports placeholders with braces {}
# We detect any $ without { to the right of them, and insert a '$' before it to escape the $.
templateStr = escape_dollar_signs(templateStr)
try:
Template(templateStr)
except Exception:
raise Exception("Invalid template formatting for string:", templateStr)
self.template = templateStr
self.fill_history = {}
self.metavars = {}
def __str__(self) -> str:
return self.template
def __repr__(self) -> str:
return self.__str__()
def has_var(self, varname) -> bool:
""" Returns True if the template has a variable with the given name.
"""
subbed_str = Template(self.template).safe_substitute({varname: '_'})
return subbed_str != self.template # if the strings differ, a replacement occurred
def is_concrete(self) -> bool:
""" Returns True if no template variables are left in template string.
"""
try:
Template(self.template).substitute({})
return True # no exception raised means there was nothing to substitute...
except Exception:
return False
def fill(self, paramDict: Dict[str, Union[str, Dict[str, str]]]) -> 'PromptTemplate':
"""
Formats the template string with the given parameters, returning a new PromptTemplate.
Can return a partial completion.
NOTE: paramDict values can be in a special form: {text: <str>, fill_history: {varname: <str>}}
in order to bundle in any past fill history that is lost in the current text.
Example usage:
prompt = prompt_template.fill({
"className": className,
"library": "Kivy",
"PL": "Python"
});
"""
# Check for special 'past fill history' format:
past_fill_history = {}
past_metavars = {}
if len(paramDict) > 0 and isinstance(next(iter(paramDict.values())), dict):
for obj in paramDict.values():
if "fill_history" in obj:
past_fill_history = {**obj['fill_history'], **past_fill_history}
if "metavars" in obj:
past_metavars = {**obj['metavars'], **past_metavars}
paramDict = {param: obj['text'] for param, obj in paramDict.items()}
filled_pt = PromptTemplate(
Template(self.template).safe_substitute(paramDict)
)
# Deep copy prior fill history of this PromptTemplate from this version over to new one
filled_pt.fill_history = { key: val for (key, val) in self.fill_history.items() }
filled_pt.metavars = { key: val for (key, val) in self.metavars.items() }
# Append any past history passed as vars:
for key, val in past_fill_history.items():
if key in filled_pt.fill_history:
print(f"Warning: PromptTemplate already has fill history for key {key}.")
filled_pt.fill_history[key] = val
# Append any metavars, overwriting existing ones with the same key
for key, val in past_metavars.items():
filled_pt.metavars[key] = val
# Add the new fill history using the passed parameters that we just filled in
for key, val in paramDict.items():
if key in filled_pt.fill_history:
print(f"Warning: PromptTemplate already has fill history for key {key}.")
filled_pt.fill_history[key] = val
return filled_pt
class PromptPermutationGenerator:
"""
Given a PromptTemplate and a parameter dict that includes arrays of items,
generate all the permutations of the prompt for all permutations of the items.
NOTE: Items can be in a special form: {text: <str>, fill_history: {varname: <str>}}
in order to bundle in any past fill history that is lost in the current text.
Example usage:
prompt_gen = PromptPermutationGenerator('Can you list all the cities in the country ${country} by the cheapest ${domain} prices?')
for prompt in prompt_gen({"country":["Canada", "South Africa", "China"],
"domain": ["rent", "food", "energy"]}):
print(prompt)
"""
def __init__(self, template: Union[PromptTemplate, str]):
if isinstance(template, str):
template = PromptTemplate(template)
self.template = template
def _gen_perm(self, template, params_to_fill, paramDict):
if len(params_to_fill) == 0: return []
# Extract the first param that occurs in the current template
param = None
params_left = params_to_fill
for p in params_to_fill:
if template.has_var(p):
param = p
params_left = [_p for _p in params_to_fill if _p != p]
break
if param is None:
return [template]
# Generate new prompts by filling in its value(s) into the PromptTemplate
val = paramDict[param]
if isinstance(val, list):
new_prompt_temps = []
for v in val:
param_fill_dict = {param: v}
# If this var has an "associate_id", then it wants to "carry with"
# values of other prompt parameters with the same id.
# We have to find any parameters with values of the same id,
# and fill them in alongside the initial parameter v:
if isinstance(v, dict) and "associate_id" in v:
v_associate_id = v["associate_id"]
for other_param in params_left[:]:
if template.has_var(other_param) and isinstance(paramDict[other_param], list):
other_vals = paramDict[other_param]
for ov in other_vals:
if isinstance(ov, dict) and ov.get("associate_id") == v_associate_id:
# This is a match. We should add the val to our param_fill_dict:
param_fill_dict[other_param] = ov
break
# Fill the template with the param values and append it to the list
new_prompt_temps.append(template.fill(param_fill_dict))
elif isinstance(val, str):
new_prompt_temps = [template.fill({param: val})]
else:
raise ValueError("Value of prompt template parameter is not a list or a string, but of type " + str(type(val)))
# Recurse
if len(params_left) == 0:
return new_prompt_temps
else:
res = []
for p in new_prompt_temps:
res.extend(self._gen_perm(p, params_left, paramDict))
return res
def __call__(self, paramDict: Dict[str, Union[str, List[str], Dict[str, str]]]):
if len(paramDict) == 0:
yield self.template
return
for p in self._gen_perm(self.template, list(paramDict.keys()), paramDict):
yield p
# Test cases
if __name__ == '__main__':
# Dollar sign escape works
tests = ["What is $2 + $2?", "If I have $4 and I want ${dollars} then how many do I have?", "$4 is equal to ${dollars}?", "${what} is the $400?"]
escaped_tests = [escape_dollar_signs(t) for t in tests]
print(escaped_tests)
assert escaped_tests[0] == "What is $$2 + $$2?"
assert escaped_tests[1] == "If I have $$4 and I want ${dollars} then how many do I have?"
assert escaped_tests[2] == "$$4 is equal to ${dollars}?"
assert escaped_tests[3] == "${what} is the $$400?"
# Single template
gen = PromptPermutationGenerator('What is the ${timeframe} when ${person} was born?')
res = [r for r in gen({'timeframe': ['year', 'decade', 'century'], 'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding']})]
for r in res:
print(r)
assert len(res) == 9
# Nested templates
gen = PromptPermutationGenerator('${prefix}... ${suffix}')
res = [r for r in gen({
'prefix': ['Who invented ${tool}?', 'When was ${tool} invented?', 'What can you do with ${tool}?'],
'suffix': ['Phrase your answer in the form of a ${response_type}', 'Respond with a ${response_type}'],
'tool': ['the flashlight', 'CRISPR', 'rubber'],
'response_type': ['question', 'poem', 'nightmare']
})]
for r in res:
print(r)
assert len(res) == (3*3)*(2*3)
# 'Carry together' vars with 'metavar' data attached
# NOTE: This feature may be used when passing rows of a table, so that vars that have associated values,
# like 'inventor' with 'tool', 'carry together' when being filled into the prompt template.
# In addition, 'metavars' may be attached which are, commonly, the values of other columns for that row, but
# columns which weren't used to fill in the prompt template explcitly.
gen = PromptPermutationGenerator('What ${timeframe} did ${inventor} invent the ${tool}?')
res = [r for r in gen({
'inventor': [
{'text': "Thomas Edison", "fill_history": {}, "associate_id": "A", "metavars": { "year": 1879 }},
{'text': "Alexander Fleming", "fill_history": {}, "associate_id": "B", "metavars": { "year": 1928 }},
{'text': "William Shockley", "fill_history": {}, "associate_id": "C", "metavars": { "year": 1947 }},
],
'tool': [
{'text': "lightbulb", "fill_history": {}, "associate_id": "A"},
{'text': "penicillin", "fill_history": {}, "associate_id": "B"},
{'text': "transistor", "fill_history": {}, "associate_id": "C"},
],
'timeframe': [ "year", "decade", "century" ]
})]
for r in res:
r_str = str(r)
print(r_str, r.metavars)
assert "year" in r.metavars
if "Edison" in r_str:
assert "lightbulb" in r_str
elif "Fleming" in r_str:
assert "penicillin" in r_str
elif "Shockley" in r_str:
assert "transistor" in r_str
assert len(res) == 3*3

View File

@@ -1,505 +0,0 @@
from typing import Dict, Tuple, List, Union, Optional
import json, os, time, asyncio
from string import Template
from chainforge.promptengine.models import LLM
DALAI_MODEL = None
DALAI_RESPONSE = None
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")
GOOGLE_PALM_API_KEY = os.environ.get("PALM_API_KEY")
AZURE_OPENAI_KEY = os.environ.get("AZURE_OPENAI_KEY")
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
def set_api_keys(api_keys):
"""
Sets the local API keys for the revelant LLM API(s).
Currently only supports 'OpenAI', 'Anthropic'.
"""
global ANTHROPIC_API_KEY, GOOGLE_PALM_API_KEY, AZURE_OPENAI_KEY, AZURE_OPENAI_ENDPOINT
def key_is_present(name):
return name in api_keys and len(api_keys[name].strip()) > 0
if key_is_present('OpenAI'):
import openai
openai.api_key = api_keys['OpenAI']
if key_is_present('Anthropic'):
ANTHROPIC_API_KEY = api_keys['Anthropic']
if key_is_present('Google'):
GOOGLE_PALM_API_KEY = api_keys['Google']
if key_is_present('Azure_OpenAI'):
AZURE_OPENAI_KEY = api_keys['Azure_OpenAI']
if key_is_present('Azure_OpenAI_Endpoint'):
AZURE_OPENAI_ENDPOINT = api_keys['Azure_OpenAI_Endpoint']
# Soft fail for non-present keys
async def make_sync_call_async(sync_method, *args, **params):
"""
Makes a blocking synchronous call asynchronous, so that it can be awaited.
NOTE: This is necessary for LLM APIs that do not yet support async (e.g. Google PaLM).
"""
loop = asyncio.get_running_loop()
method = sync_method
if len(params) > 0:
def partial_sync_meth(*a):
return sync_method(*a, **params)
method = partial_sync_meth
return await loop.run_in_executor(None, method, *args)
async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float= 1.0,
system_msg: Optional[str]=None,
**params) -> Tuple[Dict, Dict]:
"""
Calls GPT3.5 via OpenAI's API.
Returns raw query and response JSON dicts.
NOTE: It is recommended to set an environment variable OPENAI_API_KEY with your OpenAI API key
"""
import openai
if not openai.api_key:
openai.api_key = os.environ.get("OPENAI_API_KEY")
model = model.value
if 'stop' in params and (not isinstance(params['stop'], list) or len(params['stop']) == 0):
del params['stop']
if 'functions' in params and (not isinstance(params['functions'], list) or len(params['functions']) == 0):
del params['functions']
if 'function_call' in params and (not isinstance(params['function_call'], str) or len(params['function_call'].strip()) == 0):
del params['function_call']
print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...")
system_msg = "You are a helpful assistant." if system_msg is None else system_msg
query = {
"model": model,
"n": n,
"temperature": temperature,
**params, # 'the rest' of the settings, passed from a front-end app
}
if 'davinci' in model: # text completions model
openai_call = openai.Completion.acreate
query['prompt'] = prompt
else: # chat model
openai_call = openai.ChatCompletion.acreate
query['messages'] = [
{"role": "system", "content": system_msg},
{"role": "user", "content": prompt},
]
try:
response = await openai_call(**query)
except Exception as e:
if (isinstance(e, openai.error.AuthenticationError)):
raise Exception("Could not authenticate to OpenAI. Double-check that your API key is set in Settings or in your local Python environment.")
raise e
return query, response
async def call_azure_openai(prompt: str, model: LLM, n: int = 1, temperature: float= 1.0,
deployment_name: str = 'gpt-35-turbo',
model_type: str = "chat-completion",
api_version: str = "2023-05-15",
system_msg: Optional[str]=None,
**params) -> Tuple[Dict, Dict]:
"""
Calls an OpenAI chat model GPT3.5 or GPT4 via Microsoft Azure services.
Returns raw query and response JSON dicts.
NOTE: It is recommended to set an environment variables AZURE_OPENAI_KEY and AZURE_OPENAI_ENDPOINT
"""
global AZURE_OPENAI_KEY, AZURE_OPENAI_ENDPOINT
if AZURE_OPENAI_KEY is None:
raise Exception("Could not find an Azure OpenAPI Key to use. Double-check that your key is set in Settings or in your local Python environment.")
if AZURE_OPENAI_ENDPOINT is None:
raise Exception("Could not find an Azure OpenAI Endpoint to use. Double-check that your endpoint is set in Settings or in your local Python environment.")
import openai
openai.api_type = "azure"
openai.api_version = api_version
openai.api_key = AZURE_OPENAI_KEY
openai.api_base = AZURE_OPENAI_ENDPOINT
if 'stop' in params and not isinstance(params['stop'], list) or len(params['stop']) == 0:
del params['stop']
if 'functions' in params and not isinstance(params['functions'], list) or len(params['functions']) == 0:
del params['functions']
if 'function_call' in params and not isinstance(params['function_call'], str) or len(params['function_call'].strip()) == 0:
del params['function_call']
print(f"Querying Azure OpenAI deployed model '{deployment_name}' at endpoint '{AZURE_OPENAI_ENDPOINT}' with prompt '{prompt}'...")
system_msg = "You are a helpful assistant." if system_msg is None else system_msg
query = {
"engine": deployment_name, # this differs from a basic OpenAI call
"n": n,
"temperature": temperature,
**params, # 'the rest' of the settings, passed from a front-end app
}
if model_type == 'text-completion':
openai_call = openai.Completion.acreate
query['prompt'] = prompt
else:
openai_call = openai.ChatCompletion.acreate
query['messages'] = [
{"role": "system", "content": system_msg},
{"role": "user", "content": prompt},
]
try:
response = await openai_call(**query)
except Exception as e:
if (isinstance(e, openai.error.AuthenticationError)):
raise Exception("Could not authenticate to OpenAI. Double-check that your API key is set in Settings or in your local Python environment.")
raise e
return query, response
async def call_anthropic(prompt: str, model: LLM, n: int = 1, temperature: float= 1.0,
max_tokens_to_sample=1024,
async_mode=False,
custom_prompt_wrapper: Optional[str]=None,
stop_sequences: Optional[List[str]]=["\n\nHuman:"],
**params) -> Tuple[Dict, Dict]:
"""
Calls Anthropic API with the given model, passing in params.
Returns raw query and response JSON dicts.
Unique parameters:
- custom_prompt_wrapper: Anthropic models expect prompts in form "\n\nHuman: ${prompt}\n\nAssistant". If you wish to
explore custom prompt wrappers that deviate, write a python Template that maps from 'prompt' to custom wrapper.
If set to None, defaults to Anthropic's suggested prompt wrapper.
- max_tokens_to_sample: A maximum number of tokens to generate before stopping.
- stop_sequences: A list of strings upon which to stop generating. Defaults to ["\n\nHuman:"], the cue for the next turn in the dialog agent.
- async_mode: Evaluation access to Claude limits calls to 1 at a time, meaning we can't take advantage of async.
If you want to send all 'n' requests at once, you can set async_mode to True.
NOTE: It is recommended to set an environment variable ANTHROPIC_API_KEY with your Anthropic API key
"""
if ANTHROPIC_API_KEY is None:
raise Exception("Could not find an API key for Anthropic models. Double-check that your API key is set in Settings or in your local Python environment.")
import anthropic
client = anthropic.Client(ANTHROPIC_API_KEY)
# Wrap the prompt in the provided template, or use the default Anthropic one
if custom_prompt_wrapper is None or '${prompt}' not in custom_prompt_wrapper:
custom_prompt_wrapper = anthropic.HUMAN_PROMPT + " ${prompt}" + anthropic.AI_PROMPT
prompt_wrapper_template = Template(custom_prompt_wrapper)
wrapped_prompt = prompt_wrapper_template.substitute(prompt=prompt)
# Format query
query = {
'model': model.value,
'prompt': wrapped_prompt,
'max_tokens_to_sample': max_tokens_to_sample,
'stop_sequences': stop_sequences,
'temperature': temperature,
**params
}
print(f"Calling Anthropic model '{model.value}' with prompt '{prompt}' (n={n}). Please be patient...")
# Request responses using the passed async_mode
responses = []
if async_mode:
# Gather n responses by firing off all API requests at once
tasks = [client.acompletion(**query) for _ in range(n)]
responses = await asyncio.gather(*tasks)
else:
# Repeat call n times, waiting for each response to come in:
while len(responses) < n:
resp = await client.acompletion(**query)
responses.append(resp)
print(f'{model.value} response {len(responses)} of {n}:\n', resp)
return query, responses
async def call_google_palm(prompt: str, model: LLM, n: int = 1, temperature: float= 0.7,
max_output_tokens=800,
async_mode=False,
**params) -> Tuple[Dict, Dict]:
"""
Calls a Google PaLM model.
Returns raw query and response JSON dicts.
"""
if GOOGLE_PALM_API_KEY is None:
raise Exception("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local Python environment.")
import google.generativeai as palm
palm.configure(api_key=GOOGLE_PALM_API_KEY)
is_chat_model = 'chat' in model.value
query = {
'model': f"models/{model.value}",
'prompt': prompt,
'candidate_count': n,
'temperature': temperature,
'max_output_tokens': max_output_tokens,
**params,
}
# Remove erroneous parameters for text and chat models
if 'top_k' in query and query['top_k'] <= 0:
del query['top_k']
if 'top_p' in query and query['top_p'] <= 0:
del query['top_p']
if is_chat_model and 'max_output_tokens' in query:
del query['max_output_tokens']
if is_chat_model and 'stop_sequences' in query:
del query['stop_sequences']
# Get the correct model's completions call
palm_call = palm.chat if is_chat_model else palm.generate_text
# Google PaLM's python API does not currently support async calls.
# To make one, we need to wrap it in an asynchronous executor:
completion = await make_sync_call_async(palm_call, **query)
completion_dict = completion.to_dict()
# Google PaLM, unlike other chat models, will output empty
# responses for any response it deems unsafe (blocks). Although the text completions
# API has a (relatively undocumented) 'safety_settings' parameter,
# the current chat completions API provides users no control over the blocking.
# We need to detect this and fill the response with the safety reasoning:
if len(completion.filters) > 0:
# Request was blocked. Output why in the response text,
# repairing the candidate dict to mock up 'n' responses
block_error_msg = f'[[BLOCKED_REQUEST]] Request was blocked because it triggered safety filters: {str(completion.filters)}'
completion_dict['candidates'] = [{'author': 1, 'content':block_error_msg}] * n
# Weirdly, google ignores candidate_count if temperature is 0.
# We have to check for this and manually append the n-1 responses:
if n > 1 and temperature == 0 and len(completion_dict['candidates']) == 1:
copied_candidates = [completion_dict['candidates'][0]] * n
completion_dict['candidates'] = copied_candidates
return query, completion_dict
async def call_dalai(prompt: str, model: LLM, server: str="http://localhost:4000", n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
"""
Calls a Dalai server running LLMs Alpaca, Llama, etc locally.
Returns the raw query and response JSON dicts.
Parameters:
- model: The LLM model, whose value is the name known byt Dalai; e.g. 'alpaca.7b'
- port: The port of the local server where Dalai is running. By default 4000.
- prompt: The prompt to pass to the LLM.
- n: How many times to query. If n > 1, this will continue to query the LLM 'n' times and collect all responses.
- temperature: The temperature to query at
- params: Any other Dalai-specific params to pass. For more info, see below or https://cocktailpeanut.github.io/dalai/#/?id=syntax-1
TODO: Currently, this uses a modified dalaipy library for simplicity; however, in the future we might remove this dependency.
"""
# Import and load upon first run
global DALAI_MODEL, DALAI_RESPONSE
if not server or len(server.strip()) == 0: # In case user passed a blank server name, revert to default on port 4000
server = "http://localhost:4000"
if DALAI_MODEL is None:
from chainforge.promptengine.dalaipy import Dalai
DALAI_MODEL = Dalai(server)
elif DALAI_MODEL.server != server: # if the port has changed, we need to create a new model
DALAI_MODEL = Dalai(server)
# Make sure server is connected
DALAI_MODEL.connect()
# Create settings dict to pass to Dalai as args
def_params = {'n_predict':128, 'repeat_last_n':64, 'repeat_penalty':1.3, 'seed':-1, 'threads':4, 'top_k':40, 'top_p':0.9}
for key in params:
if key in def_params:
def_params[key] = params[key]
else:
print(f"Attempted to pass unsupported param '{key}' to Dalai. Ignoring.")
# Create full query to Dalai
query = {
'prompt': prompt,
'model': model.value,
'id': str(round(time.time()*1000)),
'temp': temperature,
**def_params
}
# Create spot to put response and a callback that sets it
DALAI_RESPONSE = None
def on_finish(r):
global DALAI_RESPONSE
DALAI_RESPONSE = r
print(f"Calling Dalai model '{query['model']}' with prompt '{query['prompt']}' (n={n}). Please be patient...")
# Repeat call n times
responses = []
while len(responses) < n:
# Call the Dalai model
req = DALAI_MODEL.generate_request(**query)
sent_req_success = DALAI_MODEL.generate(req, on_finish=on_finish)
if not sent_req_success:
print("Something went wrong pinging the Dalai server. Returning None.")
return None, None
# Blocking --wait for request to complete:
while DALAI_RESPONSE is None:
await asyncio.sleep(0.01)
response = DALAI_RESPONSE['response']
if response[-5:] == '<end>': # strip ending <end> tag, if present
response = response[:-5]
if response.index('\r\n') > -1: # strip off the prompt, which is included in the result up to \r\n:
response = response[(response.index('\r\n')+2):]
DALAI_RESPONSE = None
responses.append(response)
print(f'Response {len(responses)} of {n}:\n', response)
# Disconnect from the server
DALAI_MODEL.disconnect()
return query, responses
def _extract_openai_chat_choice_content(choice: dict) -> str:
"""
Extracts the relevant portion of a OpenAI chat response.
Note that chat choice objects can now include 'function_call' and a blank 'content' response.
This method detects a 'function_call's presence, prepends [[FUNCTION]] and converts the function call into Python format.
"""
if choice['finish_reason'] == 'function_call' or choice["message"]["content"] is None or \
('function_call' in choice['message'] and len(choice['message']['function_call']) > 0):
func = choice['message']['function_call']
return '[[FUNCTION]] ' + func['name'] + str(func['arguments'])
else:
return choice["message"]["content"]
def _extract_chatgpt_responses(response: dict) -> List[str]:
"""
Extracts the text part of a response JSON from ChatGPT. If there is more
than 1 response (e.g., asking the LLM to generate multiple responses),
this produces a list of all returned responses.
"""
choices = response["choices"]
return [
_extract_openai_chat_choice_content(c)
for c in choices
]
def _extract_openai_completion_responses(response: dict) -> List[str]:
"""
Extracts the text part of a response JSON from OpenAI completions models like Davinci. If there are more
than 1 response (e.g., asking the LLM to generate multiple responses),
this produces a list of all returned responses.
"""
choices = response["choices"]
return [
c["text"].strip()
for c in choices
]
def _extract_openai_responses(response: dict) -> List[str]:
"""
Deduces the format of an OpenAI model response (completion or chat)
and extracts the response text using the appropriate method.
"""
if len(response["choices"]) == 0: return []
first_choice = response["choices"][0]
if "message" in first_choice:
return _extract_chatgpt_responses(response)
else:
return _extract_openai_completion_responses(response)
def _extract_palm_responses(completion) -> List[str]:
"""
Extracts the text part of a 'Completion' object from Google PaLM2 `generate_text` or `chat`
NOTE: The candidate object for `generate_text` has a key 'output' which contains the response,
while the `chat` API uses a key 'content'. This checks for either.
"""
return [
c['output'] if 'output' in c else c['content']
for c in completion['candidates']
]
def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List[str]:
"""
Given a LLM and a response object from its API, extract the
text response(s) part of the response object.
"""
llm_str = llm.name if isinstance(llm, LLM) else llm
if llm_str[:6] == 'OpenAI':
if 'davinci' in llm_str.lower():
return _extract_openai_completion_responses(response)
else:
return _extract_chatgpt_responses(response)
elif llm_str[:5] == 'Azure':
return _extract_openai_responses(response)
elif llm_str[:5] == 'PaLM2':
return _extract_palm_responses(response)
elif llm_str[:5] == 'Dalai':
return response
elif llm_str[:6] == 'Claude':
return [r["completion"] for r in response]
else:
raise ValueError(f"LLM {llm_str} is unsupported.")
def merge_response_objs(resp_obj_A: Union[dict, None], resp_obj_B: Union[dict, None]) -> dict:
if resp_obj_B is None:
return resp_obj_A
elif resp_obj_A is None:
return resp_obj_B
raw_resp_A = resp_obj_A["raw_response"]
raw_resp_B = resp_obj_B["raw_response"]
if not isinstance(raw_resp_A, list):
raw_resp_A = [ raw_resp_A ]
if not isinstance(raw_resp_B, list):
raw_resp_B = [ raw_resp_B ]
C = {
"responses": resp_obj_A["responses"] + resp_obj_B["responses"],
"raw_response": raw_resp_A + raw_resp_B,
}
return {
**C,
"prompt": resp_obj_B['prompt'],
"query": resp_obj_B['query'],
"llm": resp_obj_B['llm'],
"info": resp_obj_B['info'],
"metavars": resp_obj_B['metavars'],
}
def create_dir_if_not_exists(path: str) -> None:
if not os.path.exists(path):
os.makedirs(path)
def is_valid_filepath(filepath: str) -> bool:
try:
with open(filepath, 'r', encoding='utf-8'):
pass
except IOError:
try:
# Create the file if it doesn't exist, and write an empty json string to it
with open(filepath, 'w+', encoding='utf-8') as f:
f.write("{}")
pass
except IOError:
return False
return True
def is_valid_json(json_dict: dict) -> bool:
if isinstance(json_dict, dict):
try:
json.dumps(json_dict)
return True
except Exception:
pass
return False
def get_files_at_dir(path: str) -> list:
f = []
for (dirpath, dirnames, filenames) in os.walk(path):
f = filenames
break
return f

View File

@@ -0,0 +1,2 @@
__all__ = ['CustomProviderProtocol', 'provider', 'ProviderRegistry']
from .protocol import CustomProviderProtocol, provider, ProviderRegistry

View File

@@ -0,0 +1,89 @@
from typing import Tuple, Dict
import asyncio, time
DALAI_MODEL = None
DALAI_RESPONSE = None
async def call_dalai(prompt: str, model: str, server: str="http://localhost:4000", n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
"""
Calls a Dalai server running LLMs Alpaca, Llama, etc locally.
Returns the raw query and response JSON dicts.
Parameters:
    - model: The LLM model, whose value is the name known by Dalai; e.g. 'alpaca.7b'
- port: The port of the local server where Dalai is running. By default 4000.
- prompt: The prompt to pass to the LLM.
- n: How many times to query. If n > 1, this will continue to query the LLM 'n' times and collect all responses.
- temperature: The temperature to query at
- params: Any other Dalai-specific params to pass. For more info, see below or https://cocktailpeanut.github.io/dalai/#/?id=syntax-1
TODO: Currently, this uses a modified dalaipy library for simplicity; however, in the future we might remove this dependency.
"""
# Import and load upon first run
global DALAI_MODEL, DALAI_RESPONSE
if not server or len(server.strip()) == 0: # In case user passed a blank server name, revert to default on port 4000
server = "http://localhost:4000"
if DALAI_MODEL is None:
from chainforge.providers.dalaipy import Dalai
DALAI_MODEL = Dalai(server)
elif DALAI_MODEL.server != server: # if the server address has changed, we need to reconnect with a new client
DALAI_MODEL = Dalai(server)
# Make sure server is connected
DALAI_MODEL.connect()
# Create settings dict to pass to Dalai as args
def_params = {'n_predict':128, 'repeat_last_n':64, 'repeat_penalty':1.3, 'seed':-1, 'threads':4, 'top_k':40, 'top_p':0.9}
for key in params:
if key in def_params:
def_params[key] = params[key]
else:
print(f"Attempted to pass unsupported param '{key}' to Dalai. Ignoring.")
# Create full query to Dalai
query = {
'prompt': prompt,
'model': model,
'id': str(round(time.time()*1000)),
'temp': temperature,
**def_params
}
# Create spot to put response and a callback that sets it
DALAI_RESPONSE = None
def on_finish(r):
global DALAI_RESPONSE
DALAI_RESPONSE = r
print(f"Calling Dalai model '{query['model']}' with prompt '{query['prompt']}' (n={n}). Please be patient...")
# Repeat call n times
responses = []
while len(responses) < n:
# Call the Dalai model
req = DALAI_MODEL.generate_request(**query)
sent_req_success = DALAI_MODEL.generate(req, on_finish=on_finish)
if not sent_req_success:
print("Something went wrong pinging the Dalai server. Returning None.")
return None, None
# Blocking --wait for request to complete:
while DALAI_RESPONSE is None:
await asyncio.sleep(0.01)
response = DALAI_RESPONSE['response']
if response[-5:] == '<end>': # strip ending <end> tag, if present
response = response[:-5]
if response.find('\r\n') > -1: # strip off the prompt, which is included in the result up to \r\n:
response = response[(response.index('\r\n')+2):]
DALAI_RESPONSE = None
responses.append(response)
print(f'Response {len(responses)} of {n}:\n', response)
# Disconnect from the server
DALAI_MODEL.disconnect()
return query, responses
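# A minimal, hedged usage sketch (assumes a Dalai server at the default port
# serving an 'alpaca.7b' model; must be awaited inside an asyncio event loop):
async def _call_dalai_example():
    query, responses = await call_dalai(
        prompt="Write a haiku about llamas.",
        model="alpaca.7b",
        n=1,
        temperature=0.5,
        n_predict=64,  # a Dalai-specific param forwarded via **params
    )
    if responses is not None:
        print(responses)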

View File

@ -0,0 +1,127 @@
from typing import Protocol, Optional, TypedDict, Dict, List, Literal, Union, Any
import time
"""
OpenAI chat message format typing
"""
class ChatMessage(TypedDict):
""" A single message, in OpenAI chat message format. """
role: str
content: str
name: Optional[str]
function_call: Optional[Dict]
ChatHistory = List[ChatMessage]
class CustomProviderProtocol(Protocol):
"""
A Callable protocol to implement for custom model provider completions.
See `__call__` for more details.
"""
def __call__(self,
prompt: str,
model: Optional[str],
chat_history: Optional[ChatHistory],
**kwargs: Any) -> str:
"""
Define a call to your custom provider.
Parameters:
- `prompt`: Text to prompt the model. (If it's a chat model, this is the new message to send.)
- `model`: The name of the particular model to use, from the CF settings window. Useful when you have multiple models for a single provider. Optional.
- `chat_history`: Providers may be passed a past chat context as a list of chat messages in OpenAI format (see `chainforge.providers.ChatHistory`).
Chat history does not include the new message to send off (which is passed instead as the `prompt` parameter).
- `kwargs`: Any other parameters to pass to the provider API call, such as temperature. Parameter names are the key names in your provider's settings_schema.
Only relevant if you are defining a custom settings_schema JSON to edit provider/model settings in ChainForge.
"""
pass
"""
A registry for custom providers
"""
class _ProviderRegistry:
def __init__(self):
self._registry = {}
self._curr_script_id = '0'
self._last_updated = {}
def set_curr_script_id(self, id: str):
self._curr_script_id = id
def register(self, cls: CustomProviderProtocol, name: str, **kwargs):
if name is None or isinstance(name, str) is False or len(name) == 0:
raise Exception("Cannot register custom model provider: No name given. Name must be a string and unique.")
self._last_updated[name] = self._registry[name]["script_id"] if name in self._registry else None
self._registry[name] = { "name": name, "func": cls, "script_id": self._curr_script_id, **kwargs }
def get(self, name):
return self._registry.get(name)
def get_all(self):
return list(self._registry.values())
def has(self, name):
return name in self._registry
def remove(self, name):
if self.has(name):
del self._registry[name]
def watch_next_registered(self):
self._last_updated = {}
def last_registered(self):
return {k: v for k, v in self._last_updated.items()}
# Global instance of the registry.
ProviderRegistry = _ProviderRegistry()
def provider(name: str = 'Custom Provider',
emoji: str = '',
models: Optional[List[str]] = None,
rate_limit: Union[int, Literal["sequential"]] = "sequential",
settings_schema: Optional[Dict] = None):
"""
A decorator for registering custom LLM provider methods or classes (Callables)
that conform to `CustomProviderProtocol`.
Parameters:
- `name`: The name of your custom provider. Required. (Must be unique; cannot be blank.)
- `emoji`: The emoji to use as the default icon for your provider in the CF interface. Required.
- `models`: A list of models that your provider supports, that you want to be able to choose between in Settings window.
If you're just calling a single model, you can omit this.
- `rate_limit`: If an integer, the maximum number of requests to send per minute.
To force requests to be sequential (wait until each request returns before sending another), enter "sequential". Default is sequential.
- `settings_schema`: a JSON Schema specifying the name of your provider in the ChainForge UI, the available settings, and the UI for those settings.
The settings and UI specs are in react-jsonschema-form format: https://rjsf-team.github.io/react-jsonschema-form/.
Specifically, your `settings_schema` dict should have keys:
```
{
"settings": <JSON dict of the schema properties for your settings form, in react-jsonschema-form format (https://rjsf-team.github.io/react-jsonschema-form/docs/)>,
"ui": <JSON dict of the UI Schema for your settings form, in react-jsonschema-form (see UISchema example here: https://rjsf-team.github.io/react-jsonschema-form/)
}
```
You may look to adapt an existing schema from `ModelSettingsSchemas.js` in `chainforge/react-server/src/`,
BUT with the following things to keep in mind:
- the value of "settings" should just be the value of "properties" in the full schema
- don't include the 'shortname' property; this will be added by default and set to the value of `name`
- don't include the 'model' property; this will be populated by the list you passed to `models` (if any)
- the keynames of all properties of the schema should be valid as variable names for Python keyword args; i.e., no spaces
Finally, if you want temperature to appear in the ChainForge UI, you must name your
settings schema property `temperature`, and give it `minimum` and `maximum` values.
NOTE: Only `textarea`, `range`, enum, and text input UI widgets are properly supported from `react-jsonschema-form`;
you can try other widget types, but the CSS may not display properly.
"""
def dec(cls: CustomProviderProtocol):
ProviderRegistry.register(cls, name=name, emoji=emoji, models=models, rate_limit=rate_limit, settings_schema=settings_schema)
return cls
return dec
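# A hedged, minimal sketch of how the decorator and registry above fit together
# (the "Echo" provider is illustrative only and not part of ChainForge):
@provider(name="Echo", emoji="🔁", rate_limit="sequential")
def EchoCompletion(prompt: str, **kwargs) -> str:
    # Simply echoes the prompt text back as the "completion"
    return prompt

# Once registered via the decorator, the provider can be looked up in the global registry:
assert ProviderRegistry.has("Echo")
assert ProviderRegistry.get("Echo")["func"] is EchoCompletion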

View File

@ -1,4 +1,4 @@
# Getting Started with Create React App
# ChainForge React Server
This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app).

View File

@ -1,15 +1,15 @@
{
"files": {
"main.css": "/static/css/main.37690c8a.css",
"main.js": "/static/js/main.c2ab48ed.js",
"main.js": "/static/js/main.1642966f.js",
"static/js/787.4c72bb55.chunk.js": "/static/js/787.4c72bb55.chunk.js",
"index.html": "/index.html",
"main.37690c8a.css.map": "/static/css/main.37690c8a.css.map",
"main.c2ab48ed.js.map": "/static/js/main.c2ab48ed.js.map",
"main.1642966f.js.map": "/static/js/main.1642966f.js.map",
"787.4c72bb55.chunk.js.map": "/static/js/787.4c72bb55.chunk.js.map"
},
"entrypoints": [
"static/css/main.37690c8a.css",
"static/js/main.c2ab48ed.js"
"static/js/main.1642966f.js"
]
}

View File

@ -1 +1 @@
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.c2ab48ed.js"></script><link href="/static/css/main.37690c8a.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.1642966f.js"></script><link href="/static/css/main.37690c8a.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>

View File

@ -17,6 +17,7 @@
"@google-ai/generativelanguage": "^0.2.0",
"@mantine/core": "^6.0.9",
"@mantine/dates": "^6.0.13",
"@mantine/dropzone": "^6.0.19",
"@mantine/form": "^6.0.11",
"@mantine/prism": "^6.0.15",
"@reactflow/background": "^11.2.0",
@ -4013,6 +4014,29 @@
"react": ">=16.8.0"
}
},
"node_modules/@mantine/dropzone": {
"version": "6.0.19",
"resolved": "https://registry.npmjs.org/@mantine/dropzone/-/dropzone-6.0.19.tgz",
"integrity": "sha512-riftzhsXSe84oUHIzMsANtJSIINT08N8ycUuGFxGStf+ytVUZn7TommITnifEEokY6iu4yxjv27FK8maaB1BHA==",
"dependencies": {
"@mantine/utils": "6.0.19",
"react-dropzone": "14.2.3"
},
"peerDependencies": {
"@mantine/core": "6.0.19",
"@mantine/hooks": "6.0.19",
"react": ">=16.8.0",
"react-dom": ">=16.8.0"
}
},
"node_modules/@mantine/dropzone/node_modules/@mantine/utils": {
"version": "6.0.19",
"resolved": "https://registry.npmjs.org/@mantine/utils/-/utils-6.0.19.tgz",
"integrity": "sha512-duvtnaW1gDR2gnvUqnWhl6DMW7sN0HEWqS8Z/BbwaMi75U+Xp17Q72R9JtiIrxQbzsq+KvH9L9B/pxMVwbLirg==",
"peerDependencies": {
"react": ">=16.8.0"
}
},
"node_modules/@mantine/form": {
"version": "6.0.11",
"resolved": "https://registry.npmjs.org/@mantine/form/-/form-6.0.11.tgz",
@ -7080,6 +7104,14 @@
"resolved": "https://registry.npmjs.org/atob-lite/-/atob-lite-2.0.0.tgz",
"integrity": "sha512-LEeSAWeh2Gfa2FtlQE1shxQ8zi5F9GHarrGKz08TMdODD5T4eH6BMsvtnhbWZ+XQn+Gb6om/917ucvRu7l7ukw=="
},
"node_modules/attr-accept": {
"version": "2.2.2",
"resolved": "https://registry.npmjs.org/attr-accept/-/attr-accept-2.2.2.tgz",
"integrity": "sha512-7prDjvt9HmqiZ0cl5CRjtS84sEyhsHP2coDkaZKRKVfCDo9s7iw7ChVmar78Gu9pC4SoR/28wFu/G5JJhTnqEg==",
"engines": {
"node": ">=4"
}
},
"node_modules/autoprefixer": {
"version": "10.4.14",
"resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.14.tgz",
@ -11515,6 +11547,17 @@
"webpack": "^4.0.0 || ^5.0.0"
}
},
"node_modules/file-selector": {
"version": "0.6.0",
"resolved": "https://registry.npmjs.org/file-selector/-/file-selector-0.6.0.tgz",
"integrity": "sha512-QlZ5yJC0VxHxQQsQhXvBaC7VRJ2uaxTf+Tfpu4Z/OcVQJVpZO+DGU0rkoVW5ce2SccxugvpBJoMvUs59iILYdw==",
"dependencies": {
"tslib": "^2.4.0"
},
"engines": {
"node": ">= 12"
}
},
"node_modules/filelist": {
"version": "1.0.4",
"resolved": "https://registry.npmjs.org/filelist/-/filelist-1.0.4.tgz",
@ -20265,6 +20308,22 @@
"react": "^18.2.0"
}
},
"node_modules/react-dropzone": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/react-dropzone/-/react-dropzone-14.2.3.tgz",
"integrity": "sha512-O3om8I+PkFKbxCukfIR3QAGftYXDZfOE2N1mr/7qebQJHs7U+/RSL/9xomJNpRg9kM5h9soQSdf0Gc7OHF5Fug==",
"dependencies": {
"attr-accept": "^2.2.2",
"file-selector": "^0.6.0",
"prop-types": "^15.8.1"
},
"engines": {
"node": ">= 10.13"
},
"peerDependencies": {
"react": ">= 16.8 || 18.0.0"
}
},
"node_modules/react-edit-text": {
"version": "5.1.0",
"resolved": "https://registry.npmjs.org/react-edit-text/-/react-edit-text-5.1.0.tgz",

View File

@ -12,6 +12,7 @@
"@google-ai/generativelanguage": "^0.2.0",
"@mantine/core": "^6.0.9",
"@mantine/dates": "^6.0.13",
"@mantine/dropzone": "^6.0.19",
"@mantine/form": "^6.0.11",
"@mantine/prism": "^6.0.15",
"@reactflow/background": "^11.2.0",

View File

@ -23,7 +23,7 @@ import GlobalSettingsModal from './GlobalSettingsModal';
import ExampleFlowsModal from './ExampleFlowsModal';
import AreYouSureModal from './AreYouSureModal';
import LLMEvaluatorNode from './LLMEvalNode';
import { getDefaultModelFormData, getDefaultModelSettings } from './ModelSettingSchemas';
import { getDefaultModelFormData, getDefaultModelSettings, setCustomProviders } from './ModelSettingSchemas';
import { v4 as uuid } from 'uuid';
import LZString from 'lz-string';
import { EXAMPLEFLOW_1 } from './example_flows';
@ -122,7 +122,8 @@ const snapGrid = [16, 16];
const App = () => {
// Get nodes, edges, etc. state from the Zustand store:
const { nodes, edges, onNodesChange, onEdgesChange, onConnect, addNode, setNodes, setEdges, resetLLMColors } = useStore(selector, shallow);
const { nodes, edges, onNodesChange, onEdgesChange,
onConnect, addNode, setNodes, setEdges, resetLLMColors } = useStore(selector, shallow);
// For saving / loading
const [rfInstance, setRfInstance] = useState(null);
@ -672,7 +673,7 @@ const App = () => {
}
else return (
<div>
<GlobalSettingsModal ref={settingsModal} />
<GlobalSettingsModal ref={settingsModal} alertModal={alertModal} />
<AlertModal ref={alertModal} />
<LoadingOverlay visible={isLoading} overlayBlur={1} />
<ExampleFlowsModal ref={examplesModal} onSelect={onSelectExampleFlow} />

View File

@ -1,14 +1,160 @@
import React, { useState, forwardRef, useImperativeHandle } from 'react';
import { TextInput, Button, Group, Box, Modal, Divider, Text } from '@mantine/core';
import React, { useState, forwardRef, useImperativeHandle, useCallback, useEffect } from 'react';
import { TextInput, Button, Group, Box, Modal, Divider, Text, Tabs, useMantineTheme, rem, Flex, Center, Badge, Card } from '@mantine/core';
import { useDisclosure } from '@mantine/hooks';
import { useForm } from '@mantine/form';
import useStore from './store';
import { IconUpload, IconBrandPython, IconX } from '@tabler/icons-react';
import { Dropzone, DropzoneProps } from '@mantine/dropzone';
import useStore, { initLLMProviders } from './store';
import { APP_IS_RUNNING_LOCALLY } from './backend/utils';
import fetch_from_backend from './fetch_from_backend';
import { setCustomProviders } from './ModelSettingSchemas';
const _LINK_STYLE = {color: '#1E90FF', textDecoration: 'none'};
// To ensure we only call the backend to load custom providers once, upon initialization
let LOADED_CUSTOM_PROVIDERS = false;
// Read a file as text and pass the text to a cb (callback) function
const read_file = (file, cb) => {
const reader = new FileReader();
reader.onload = function(event) {
const fileContent = event.target.result;
cb(fileContent);
};
reader.onerror = function(event) {
console.error("Error reading file:", event);
};
reader.readAsText(file);
};
/** A Dropzone to load a Python `.py` script that registers one or more custom model providers in the Flask backend.
* If successful, the list of custom model providers in the ChainForge UI dropdown is updated.
* */
const CustomProviderScriptDropzone = ({onError, onSetProviders}) => {
const theme = useMantineTheme();
const [isLoading, setIsLoading] = useState(false);
return (<Dropzone
loading={isLoading}
onDrop={(files) => {
if (files.length === 1) {
setIsLoading(true);
read_file(files[0], (content) => {
// Read the file into text and then send it to backend
fetch_from_backend('initCustomProvider', {
code: content
}).then((response) => {
setIsLoading(false);
if (response.error || !response.providers) {
onError(response.error);
return;
}
// Successfully loaded custom providers in backend,
// now load them into the ChainForge UI:
console.log(response.providers);
setCustomProviders(response.providers);
onSetProviders(response.providers);
}).catch((err) => {
setIsLoading(false);
onError(err.message);
});
});
} else {
console.error('Too many files dropped. Only drop one file at a time.')
}
}}
onReject={(files) => console.log('rejected files', files)}
maxSize={3 * 1024 ** 2}
accept={{'text/x-python-script': []}}
>
<Flex pos="center" spacing="md" style={{ minHeight: rem(80), pointerEvents: 'none' }}>
<Center>
<Dropzone.Accept>
<IconUpload
size="4.2rem"
stroke={1.5}
color={theme.colors[theme.primaryColor][theme.colorScheme === 'dark' ? 4 : 6]}
/>
</Dropzone.Accept>
<Dropzone.Reject>
<IconX size="4.2rem"
stroke={1.5}
color={theme.colors.red[theme.colorScheme === 'dark' ? 4 : 6]} />
</Dropzone.Reject>
<Dropzone.Idle>
<IconBrandPython size="4.2rem" stroke={1.5} />
</Dropzone.Idle>
<Box ml='md'>
<Text size="md" lh={1.2} inline>
Drag a Python script for your custom model provider here
</Text>
<Text size="sm" color="dimmed" inline mt={7}>
Each script should contain one or more registered @provider callables
</Text>
</Box>
</Center>
</Flex>
</Dropzone>);
};
const GlobalSettingsModal = forwardRef((props, ref) => {
const [opened, { open, close }] = useDisclosure(false);
const setAPIKeys = useStore((state) => state.setAPIKeys);
const AvailableLLMs = useStore((state) => state.AvailableLLMs);
const setAvailableLLMs = useStore((state) => state.setAvailableLLMs);
const nodes = useStore((state) => state.nodes);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const alertModal = props?.alertModal;
const handleError = useCallback((msg) => {
if (alertModal && alertModal.current)
alertModal.current.trigger(msg);
}, [alertModal]);
const [customProviders, setLocalCustomProviders] = useState([]);
const refreshLLMProviderLists = useCallback(() => {
// We unfortunately have to force all prompt/chat nodes to refresh their LLM lists, because
// the update to the AvailableLLMs list is not immediately propagated to them.
const prompt_nodes = nodes.filter(n => n.type === 'prompt' || n.type === 'chat');
prompt_nodes.forEach(n => setDataPropsForNode(n.id, { refreshLLMList: true }));
}, [nodes, setDataPropsForNode]);
const removeCustomProvider = useCallback((name) => {
fetch_from_backend('removeCustomProvider', {
name: name,
}).then((response) => {
if (response.error || !response.success) {
handleError(response.error);
return;
}
// Successfully deleted the custom provider from the backend;
// now update the front-end UI to reflect this:
setAvailableLLMs(AvailableLLMs.filter((p) => p.name !== name));
setLocalCustomProviders(customProviders.filter((p) => p.name !== name));
refreshLLMProviderLists();
}).catch((err) => handleError(err.message));
}, [customProviders, handleError, AvailableLLMs, refreshLLMProviderLists]);
// On init
useEffect(() => {
if (APP_IS_RUNNING_LOCALLY() && !LOADED_CUSTOM_PROVIDERS) {
LOADED_CUSTOM_PROVIDERS = true;
// Is running locally; try to load any custom providers.
// Soft fails if it encounters error:
fetch_from_backend('loadCachedCustomProviders', {}, console.error).then((json) => {
if (json?.error || json?.providers === undefined) {
console.error(json?.error || "Could not load custom provider scripts: Error contacting backend.");
return;
}
// Success; pass custom providers list to store:
setCustomProviders(json.providers);
setLocalCustomProviders(json.providers);
});
}
}, []);
const form = useForm({
initialValues: {
@ -25,6 +171,7 @@ const GlobalSettingsModal = forwardRef((props, ref) => {
},
});
// When the API settings form is submitted
const onSubmit = (values) => {
setAPIKeys(values);
close();
@ -39,63 +186,102 @@ const GlobalSettingsModal = forwardRef((props, ref) => {
}));
return (
<Modal opened={opened} onClose={close} title="ChainForge Settings" closeOnClickOutside={false} style={{position: 'relative', 'left': '-100px'}}>
<Modal keepMounted opened={opened} onClose={close} title="ChainForge Settings" closeOnClickOutside={false} style={{position: 'relative', 'left': '-5%'}}>
<Box maw={380} mx="auto">
<Text mb="md" fz="xs" lh={1.15} color='dimmed'>
Note: <b>We do not store your API keys</b> &mdash;not in a cookie, localStorage, or server.
Because of this, <b>you must set your API keys every time you load ChainForge.</b> If you prefer not to worry about it,
we recommend <a href="https://github.com/ianarawjo/ChainForge" target="_blank" style={_LINK_STYLE}>installing ChainForge locally</a> and
<a href="https://github.com/ianarawjo/ChainForge/blob/main/INSTALL_GUIDE.md#2-set-api-keys-openai-anthropic-google-palm" target="_blank" style={_LINK_STYLE}> setting your API keys as environment variables.</a>
</Text>
<form onSubmit={form.onSubmit(onSubmit)}>
<TextInput
label="OpenAI API Key"
placeholder="Paste your OpenAI API key here"
{...form.getInputProps('OpenAI')}
/>
<Tabs defaultValue="api-keys">
<br />
<TextInput
label="HuggingFace API Key"
placeholder="Paste your HuggingFace API key here"
{...form.getInputProps('HuggingFace')}
/>
<Tabs.List>
<Tabs.Tab value="api-keys" >API Keys</Tabs.Tab>
<Tabs.Tab value="custom-providers" >Custom Model Providers</Tabs.Tab>
</Tabs.List>
<br />
<TextInput
label="Anthropic API Key"
placeholder="Paste your Anthropic API key here"
{...form.getInputProps('Anthropic')}
/>
<br />
<TextInput
label="Google PaLM API Key"
placeholder="Paste your Google PaLM API key here"
{...form.getInputProps('Google')}
/>
<br />
<Tabs.Panel value="api-keys" pt="xs">
<Text mb="md" fz="xs" lh={1.15} color='dimmed'>
Note: <b>We do not store your API keys</b> &mdash;not in a cookie, localStorage, or server.
Because of this, <b>you must set your API keys every time you load ChainForge.</b> If you prefer not to worry about it,
we recommend <a href="https://github.com/ianarawjo/ChainForge" target="_blank" style={_LINK_STYLE}>installing ChainForge locally</a> and
<a href="https://github.com/ianarawjo/ChainForge/blob/main/INSTALL_GUIDE.md#2-set-api-keys-openai-anthropic-google-palm" target="_blank" style={_LINK_STYLE}> setting your API keys as environment variables.</a>
</Text>
<form onSubmit={form.onSubmit(onSubmit)}>
<TextInput
label="OpenAI API Key"
placeholder="Paste your OpenAI API key here"
{...form.getInputProps('OpenAI')}
/>
<Divider my="xs" label="Microsoft Azure" labelPosition="center" />
<TextInput
label="Azure OpenAI Key"
description={<span>For more details on Azure OpenAI, see <a href="https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal" target="_blank" style={{color: '#1E90FF', textDecoration:'none'}}>Microsoft Learn.</a> Note that you will have to set the Deployment Name in the Settings of any Azure OpenAI model you add to a Prompt Node.</span>}
placeholder="Paste your Azure OpenAI Key here"
{...form.getInputProps('Azure_OpenAI')}
style={{marginBottom: '8pt'}}
/>
<br />
<TextInput
label="HuggingFace API Key"
placeholder="Paste your HuggingFace API key here"
{...form.getInputProps('HuggingFace')}
/>
<TextInput
label="Azure OpenAI Endpoint"
placeholder="Paste your Azure OpenAI Endpoint here"
{...form.getInputProps('Azure_OpenAI_Endpoint')}
/>
<br />
<TextInput
label="Anthropic API Key"
placeholder="Paste your Anthropic API key here"
{...form.getInputProps('Anthropic')}
/>
<Group position="right" mt="md">
<Button type="submit">Submit</Button>
</Group>
</form>
<br />
<TextInput
label="Google PaLM API Key"
placeholder="Paste your Google PaLM API key here"
{...form.getInputProps('Google')}
/>
<br />
<Divider my="xs" label="Microsoft Azure" labelPosition="center" />
<TextInput
label="Azure OpenAI Key"
description={<span>For more details on Azure OpenAI, see <a href="https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal" target="_blank" style={{color: '#1E90FF', textDecoration:'none'}}>Microsoft Learn.</a> Note that you will have to set the Deployment Name in the Settings of any Azure OpenAI model you add to a Prompt Node.</span>}
placeholder="Paste your Azure OpenAI Key here"
{...form.getInputProps('Azure_OpenAI')}
style={{marginBottom: '8pt'}}
/>
<TextInput
label="Azure OpenAI Endpoint"
placeholder="Paste your Azure OpenAI Endpoint here"
{...form.getInputProps('Azure_OpenAI_Endpoint')}
/>
<Group position="right" mt="md">
<Button type="submit">Submit</Button>
</Group>
</form>
</Tabs.Panel>
{APP_IS_RUNNING_LOCALLY() ?
<Tabs.Panel value="custom-providers" pt="md">
<Text mb="md" fz="sm" lh={1.3}>
You can add model providers to ChainForge by writing custom completion functions as Python scripts. (You can even make your own settings screen!)
To learn more, <a href="https://chainforge.ai/docs/custom_providers/" target="_blank" style={_LINK_STYLE}>see the documentation.</a>
</Text>
{ customProviders.map(p => (
<Card key={p.name} shadow='sm' radius='sm' pt='0px' pb='4px' mb='md' withBorder>
<Group position="apart">
<Group position="left" mt="md" mb="xs">
<Text w='10px'>{p.emoji}</Text>
<Text weight={500}>{p.name}</Text>
{ p.settings_schema ?
<Badge color="blue" variant="light">has settings</Badge>
: <></> }
</Group>
<Button onClick={() => removeCustomProvider(p.name)} color='red' p='0px' mt='4px' variant='subtle'><IconX /></Button>
</Group>
</Card>
)) }
<CustomProviderScriptDropzone onError={handleError} onSetProviders={(ps) => {
refreshLLMProviderLists();
setLocalCustomProviders(ps);
}} />
</Tabs.Panel>
: <></>}
</Tabs>
</Box>
</Modal>
);

View File

@ -1,23 +1,24 @@
import React, { useState, useCallback, useRef, useEffect } from 'react';
import { Handle } from 'react-flow-renderer';
import { Button, Alert, Progress, Textarea } from '@mantine/core';
import { Alert, Progress, Textarea } from '@mantine/core';
import { IconAlertTriangle, IconRobot, IconSearch } from "@tabler/icons-react";
import { v4 as uuid } from 'uuid';
import useStore from './store';
import NodeLabel from './NodeLabelComponent';
import fetch_from_backend from './fetch_from_backend';
import { AvailableLLMs, getDefaultModelSettings } from './ModelSettingSchemas';
import { getDefaultModelSettings } from './ModelSettingSchemas';
import { LLMListContainer } from './LLMListComponent';
import LLMResponseInspectorModal from './LLMResponseInspectorModal';
import InspectFooter from './InspectFooter';
import { initLLMProviders } from './store';
// The default prompt shown in gray highlights to give people a good example of an evaluation prompt.
const PLACEHOLDER_PROMPT = "Respond with 'true' if the text below has a positive sentiment, and 'false' if not. Do not reply with anything else.";
// The default LLM annotator is GPT-4 at temperature 0.
const DEFAULT_LLM_ITEM = (() => {
let item = [AvailableLLMs.find(i => i.base_model === 'gpt-4')]
.map((i) => ({key: uuid(), settings: getDefaultModelSettings(i.base_model), ...i}))[0];
let item = [initLLMProviders.find(i => i.base_model === 'gpt-4')]
.map((i) => ({key: uuid(), settings: getDefaultModelSettings(i.base_model), ...i}))[0];
item.settings.temperature = 0.0;
return item;
})();

View File

@ -1,15 +1,16 @@
import { useState, useEffect, useCallback, useRef, forwardRef, useImperativeHandle } from "react";
import { useState, useEffect, useCallback, useRef, forwardRef, useImperativeHandle, useReducer } from "react";
import { DragDropContext, Draggable } from "react-beautiful-dnd";
import { Menu } from "@mantine/core";
import { v4 as uuid } from 'uuid';
import LLMListItem, { LLMListItemClone } from "./LLMListItem";
import { StrictModeDroppable } from './StrictModeDroppable';
import ModelSettingsModal from "./ModelSettingsModal";
import { getDefaultModelSettings, AvailableLLMs } from './ModelSettingSchemas';
import { getDefaultModelSettings } from './ModelSettingSchemas';
import useStore, { initLLMProviders } from "./store";
// The LLM(s) to include by default on a PromptNode whenever one is created.
// Defaults to ChatGPT (GPT3.5) when running locally, and HF-hosted falcon-7b for online version since it's free.
const DEFAULT_INIT_LLMS = [AvailableLLMs[0]];
const DEFAULT_INIT_LLMS = [initLLMProviders[0]];
// Helper funcs
// Ensure that a name is 'unique'; if not, return an amended version with a count tacked on (e.g. "GPT-4 (2)")
@ -68,8 +69,13 @@ export function LLMList({llms, onItemsChange}) {
updated_item.formData = {...formData};
updated_item.settings = {...settingsData};
if ('model' in formData) // Update the name of the specific model to call
updated_item.model = formData['model'];
if ('model' in formData) { // Update the name of the specific model to call
if (item.base_model.startsWith('__custom'))
// Custom models must always have their base name, to avoid name collisions
updated_item.model = item.base_model + '/' + formData['model'];
else
updated_item.model = formData['model'];
}
if ('shortname' in formData) {
// Change the name, amending any name that isn't unique to ensure it is unique:
const unique_name = ensureUniqueName(formData['shortname'], prev_names);
@ -158,6 +164,17 @@ export function LLMList({llms, onItemsChange}) {
export const LLMListContainer = forwardRef(({description, modelSelectButtonText, initLLMItems, onSelectModel, selectModelAction, onItemsChange}, ref) => {
// All available LLM providers, for the dropdown list
const AvailableLLMs = useStore((state) => state.AvailableLLMs);
// For some reason, when the AvailableLLMs list is updated in the store, it is not
// immediately updated here. I've tried all kinds of things, but cannot seem to fix this problem.
// We must force a re-render of the component:
const [ignored, forceUpdate] = useReducer(x => x + 1, 0);
const refreshLLMProviderList = () => {
forceUpdate();
};
// Selecting LLM models to prompt
const [llmItems, setLLMItems] = useState(initLLMItems || DEFAULT_INIT_LLMS.map((i) => ({key: uuid(), settings: getDefaultModelSettings(i.base_model), ...i})));
const [llmItemsCurrState, setLLMItemsCurrState] = useState([]);
@ -228,7 +245,7 @@ export const LLMListContainer = forwardRef(({description, modelSelectButtonText,
setLLMItems(new_items);
if (onSelectModel) onSelectModel(item, new_items);
}, [llmItemsCurrState, onSelectModel, selectModelAction]);
}, [llmItemsCurrState, onSelectModel, selectModelAction, AvailableLLMs]);
const onLLMListItemsChange = useCallback((new_items) => {
setLLMItemsCurrState(new_items);
@ -242,6 +259,7 @@ export const LLMListContainer = forwardRef(({description, modelSelectButtonText,
updateProgress,
ensureLLMItemsErrorProgress,
getLLMListItemForKey,
refreshLLMProviderList,
}));
return (<div className="llm-list-container nowheel">

View File

@ -10,21 +10,18 @@
* Descriptions of OpenAI model parameters copied from OpenAI's official chat completions documentation: https://platform.openai.com/docs/models/model-endpoint-compatibility
*/
import { APP_IS_RUNNING_LOCALLY } from "./backend/utils";
import { RATE_LIMITS } from "./backend/models";
import { filterDict } from './backend/utils';
import useStore from "./store";
// Available LLMs in ChainForge, in the format expected by LLMListItems.
export let AvailableLLMs = [
{ name: "GPT3.5", emoji: "🤖", model: "gpt-3.5-turbo", base_model: "gpt-3.5-turbo", temp: 1.0 }, // The base_model designates what settings form will be used, and must be unique.
{ name: "GPT4", emoji: "🥵", model: "gpt-4", base_model: "gpt-4", temp: 1.0 },
{ name: "Claude", emoji: "📚", model: "claude-2", base_model: "claude-v1", temp: 0.5 },
{ name: "PaLM2", emoji: "🦬", model: "chat-bison-001", base_model: "palm2-bison", temp: 0.7 },
{ name: "Azure OpenAI", emoji: "🔷", model: "azure-openai", base_model: "azure-openai", temp: 1.0 },
{ name: "HuggingFace", emoji: "🤗", model: "tiiuae/falcon-7b-instruct", base_model: "hf", temp: 1.0 },
];
if (APP_IS_RUNNING_LOCALLY()) {
AvailableLLMs.push({ name: "Dalai (Alpaca.7B)", emoji: "🦙", model: "alpaca.7B", base_model: "dalai", temp: 0.5 });
}
const UI_SUBMIT_BUTTON_SPEC = {
props: {
disabled: false,
className: 'mantine-UnstyledButton-root mantine-Button-root',
},
norender: false,
submitText: 'Submit',
};
const ChatGPTSettings = {
fullName: "GPT-3.5+ (OpenAI)",
@ -122,14 +119,7 @@ const ChatGPTSettings = {
},
uiSchema: {
'ui:submitButtonOptions': {
props: {
disabled: false,
className: 'mantine-UnstyledButton-root mantine-Button-root',
},
norender: false,
submitText: 'Submit',
},
'ui:submitButtonOptions': UI_SUBMIT_BUTTON_SPEC,
"shortname": {
"ui:autofocus": true
},
@ -295,14 +285,7 @@ const ClaudeSettings = {
},
uiSchema: {
'ui:submitButtonOptions': {
props: {
disabled: false,
className: 'mantine-UnstyledButton-root mantine-Button-root',
},
norender: false,
submitText: 'Submit',
},
'ui:submitButtonOptions': UI_SUBMIT_BUTTON_SPEC,
"shortname": {
"ui:autofocus": true
},
@ -403,14 +386,7 @@ const PaLM2Settings = {
},
uiSchema: {
'ui:submitButtonOptions': {
props: {
disabled: false,
className: 'mantine-UnstyledButton-root mantine-Button-root',
},
norender: false,
submitText: 'Submit',
},
'ui:submitButtonOptions': UI_SUBMIT_BUTTON_SPEC,
"shortname": {
"ui:autofocus": true
},
@ -536,14 +512,7 @@ const DalaiModelSettings = {
},
uiSchema: {
'ui:submitButtonOptions': {
props: {
disabled: false,
className: 'mantine-UnstyledButton-root mantine-Button-root',
},
norender: false,
submitText: 'Submit',
},
'ui:submitButtonOptions': UI_SUBMIT_BUTTON_SPEC,
"shortname": {
"ui:autofocus": true
},
@ -730,19 +699,12 @@ const HuggingFaceTextInferenceSettings = {
},
uiSchema: {
'ui:submitButtonOptions': {
props: {
disabled: false,
className: 'mantine-UnstyledButton-root mantine-Button-root',
},
norender: false,
submitText: 'Submit',
},
'ui:submitButtonOptions': UI_SUBMIT_BUTTON_SPEC,
"shortname": {
"ui:autofocus": true
},
"model": {
"ui:help": "Defaults to Falcon.7B."
"ui:help": "Defaults to Falcon.7B."
},
"temperature": {
"ui:help": "Defaults to 1.0.",
@ -781,7 +743,7 @@ const HuggingFaceTextInferenceSettings = {
};
// A lookup table indexed by base_model.
export const ModelSettings = {
export let ModelSettings = {
'gpt-3.5-turbo': ChatGPTSettings,
'gpt-4': GPT4Settings,
'claude-v1': ClaudeSettings,
@ -791,6 +753,104 @@ export const ModelSettings = {
'hf': HuggingFaceTextInferenceSettings,
};
/**
* Add new model provider to the AvailableLLMs list. Also adds the respective ModelSettings schema and rate limit.
* @param {*} name The name of the provider, to use in the dropdown menu and default name. Must be unique.
* @param {*} emoji The emoji to use for the provider. Optional.
* @param {*} models A list of models the user can select from this provider. Optional.
* @param {*} rate_limit The maximum number of requests per minute (an integer), or "sequential" to wait for each request to return before sending another. Optional.
* @param {*} settings_schema A dict with "settings" and "ui" keys, in react-jsonschema-form format, describing the provider's settings form. Optional.
*/
export const setCustomProvider = (name, emoji, models, rate_limit, settings_schema, llmProviderList) => {
if (typeof emoji === 'string' && (emoji.length === 0 || emoji.length > 2))
throw new Error(`Emoji for a custom provider must be a single character.`)
let new_provider = { name };
new_provider.emoji = emoji || '✨';
// Each LLM *model* must have a unique name. To avoid name collisions, for custom providers,
// the full LLM model name is a path, __custom/<provider_name>/<submodel name>
// If there's no submodel, it's just __custom/<provider_name>.
const base_model = `__custom/${name}/`;
new_provider.base_model = base_model;
new_provider.model = base_model + ((Array.isArray(models) && models.length > 0) ? `${models[0]}` : '');
// Build the settings form schema for this new custom provider
let compiled_schema = {
fullName: `${name} (custom provider)`,
schema: {
"type": "object",
"required": [
"shortname",
],
"properties": {
"shortname": {
"type": "string",
"title": "Nickname",
"description": "Unique identifier to appear in ChainForge. Keep it short.",
"default": name,
}
}
},
uiSchema: {
'ui:submitButtonOptions': UI_SUBMIT_BUTTON_SPEC,
"shortname": {
"ui:autofocus": true
},
}
};
// Add a models selector if there's multiple models
if (Array.isArray(models) && models.length > 0) {
compiled_schema.schema["properties"]["model"] = {
"type": "string",
"title": "Model",
"description": `Select a ${name} model to query.`,
"enum": models,
"default": models[0],
};
compiled_schema.uiSchema["model"] = {
"ui:help": `Defaults to ${models[0]}`
};
}
// Add the rest of the settings window if there's one
if (settings_schema) {
compiled_schema.schema["properties"] = {...compiled_schema.schema["properties"], ...settings_schema.settings};
compiled_schema.uiSchema = {...compiled_schema.uiSchema, ...settings_schema.ui};
}
// Check for a default temperature
const default_temp = compiled_schema?.schema?.properties?.temperature?.default;
if (default_temp !== undefined)
new_provider.temp = default_temp;
// Add the built provider and its settings to the global lookups:
let AvailableLLMs = useStore.getState().AvailableLLMs;
const prev_provider_idx = AvailableLLMs.findIndex((d) => d.name === name);
if (prev_provider_idx > -1)
AvailableLLMs[prev_provider_idx] = new_provider;
else
AvailableLLMs.push(new_provider);
ModelSettings[base_model] = compiled_schema;
// Add rate limit info, if specified
if (rate_limit !== undefined && typeof rate_limit === 'number' && rate_limit > 0) {
if (rate_limit >= 60)
RATE_LIMITS[base_model] = [ Math.trunc(rate_limit/60), 1 ]; // for instance, 300 rpm means 5 every second
else
RATE_LIMITS[base_model] = [ 1, Math.trunc(60/rate_limit) ]; // for instance, 10 rpm means 1 every 6 seconds
}
// Commit changes to LLM list
useStore.getState().setAvailableLLMs(AvailableLLMs);
};
export const setCustomProviders = (providers) => {
for (const p of providers)
setCustomProvider(p.name, p.emoji, p.models, p.rate_limit, p.settings_schema);
};
export const getTemperatureSpecForModel = (modelName) => {
if (modelName in ModelSettings) {
const temperature_property = ModelSettings[modelName].schema?.properties?.temperature;
@ -806,7 +866,7 @@ export const postProcessFormData = (settingsSpec, formData) => {
const skip_keys = {'model': true, 'shortname': true};
let new_data = {};
let postprocessors = settingsSpec.postprocessors ? settingsSpec.postprocessors : {};
let postprocessors = settingsSpec?.postprocessors ? settingsSpec.postprocessors : {};
Object.keys(formData).forEach(key => {
if (key in skip_keys) return;
@ -815,7 +875,7 @@ export const postProcessFormData = (settingsSpec, formData) => {
else
new_data[key] = formData[key];
});
return new_data;
};

View File

@ -16,7 +16,6 @@ const ModelSettingsModal = forwardRef((props, ref) => {
const [formData, setFormData] = useState(undefined);
const onSettingsSubmit = props.onSettingsSubmit;
const selectedModelKey = props.model ? props.model.key : null;
const [schema, setSchema] = useState({'type': 'object', 'description': 'No model info object was passed to settings modal.'});
const [uiSchema, setUISchema] = useState({});
@ -30,7 +29,7 @@ const ModelSettingsModal = forwardRef((props, ref) => {
if (props.model && props.model.base_model) {
setModelEmoji(props.model.emoji);
if (!(props.model.base_model in ModelSettings)) {
setSchema({'type': 'object', 'description': `Did not find settings schema for base model ${props.model.base_model}.`});
setSchema({'type': 'object', 'description': `Did not find settings schema for base model ${props.model.base_model}. Maybe you are missing importing a custom provider script?`});
setUISchema({});
setModelName(props.model.base_model);
return;

View File

@ -199,9 +199,12 @@ const PromptNode = ({ data, id, type: node_type }) => {
// On upstream changes
useEffect(() => {
if (data.refresh && data.refresh === true) {
if (data.refresh === true) {
setDataPropsForNode(id, { refresh: false });
setStatus('warning');
} else if (data.refreshLLMList === true) {
llmListContainer?.current?.refreshLLMProviderList();
setDataPropsForNode(id, { refreshLLMList: false });
}
}, [data]);

View File

@ -1,7 +1,7 @@
/*
* @jest-environment node
* @jest-environment jsdom
*/
import { LLM } from '../models';
import { NativeLLM } from '../models';
import { expect, test } from '@jest/globals';
import { queryLLM, executejs, countQueries, ResponseInfo } from '../backend';
import { StandardizedLLMResponse, Dict } from '../typing';
@ -32,14 +32,14 @@ test('count queries required', async () => {
};
// Try a number of different inputs
await test_count_queries([LLM.OpenAI_ChatGPT, LLM.Claude_v1], 3);
await test_count_queries([NativeLLM.OpenAI_ChatGPT, NativeLLM.Claude_v1], 3);
await test_count_queries([{ name: "Claude", key: 'claude-test', emoji: "📚", model: "claude-v1", base_model: "claude-v1", temp: 0.5 }], 5);
});
test('call three LLMs with a single prompt', async () => {
// Setup params to call
const prompt = 'What is one major difference between French and English languages? Be brief.'
const llms = [LLM.OpenAI_ChatGPT, LLM.Claude_v1, LLM.PaLM2_Chat_Bison];
const llms = [NativeLLM.OpenAI_ChatGPT, NativeLLM.Claude_v1, NativeLLM.PaLM2_Chat_Bison];
const n = 1;
const progress_listener = (progress: {[key: symbol]: any}) => {
console.log(JSON.stringify(progress));

View File

@ -2,7 +2,7 @@
* @jest-environment node
*/
import { PromptPipeline } from '../query';
import { LLM } from '../models';
import { LLM, NativeLLM } from '../models';
import { expect, test } from '@jest/globals';
import { LLMResponseError, LLMResponseObject } from '../typing';
@ -65,13 +65,13 @@ async function prompt_model(model: LLM): Promise<void> {
test('basic prompt pipeline with chatgpt', async () => {
// Setup a simple pipeline with a prompt template, 1 variable and 3 input values
await prompt_model(LLM.OpenAI_ChatGPT);
await prompt_model(NativeLLM.OpenAI_ChatGPT);
}, 20000);
test('basic prompt pipeline with anthropic', async () => {
await prompt_model(LLM.Claude_v1);
await prompt_model(NativeLLM.Claude_v1);
}, 40000);
test('basic prompt pipeline with google palm2', async () => {
await prompt_model(LLM.PaLM2_Chat_Bison);
await prompt_model(NativeLLM.PaLM2_Chat_Bison);
}, 40000);

View File

@ -1,8 +1,8 @@
/*
* @jest-environment node
* @jest-environment jsdom
*/
import { call_anthropic, call_chatgpt, call_google_palm, extract_responses, merge_response_objs } from '../utils';
import { LLM } from '../models';
import { LLM, NativeLLM } from '../models';
import { expect, test } from '@jest/globals';
import { LLMResponseObject } from '../typing';
@ -13,7 +13,7 @@ test('merge response objects', () => {
raw_response: ['x', 'y', 'z'],
prompt: 'this is a test',
query: {},
llm: LLM.OpenAI_ChatGPT,
llm: NativeLLM.OpenAI_ChatGPT,
info: { var1: 'value1', var2: 'value2' },
metavars: { meta1: 'meta1' },
};
@ -22,7 +22,7 @@ test('merge response objects', () => {
raw_response: {B: 'B'},
prompt: 'this is a test 2',
query: {},
llm: LLM.OpenAI_ChatGPT,
llm: NativeLLM.OpenAI_ChatGPT,
info: { varB1: 'valueB1', varB2: 'valueB2' },
metavars: { metaB1: 'metaB1' },
};
@ -46,62 +46,62 @@ test('merge response objects', () => {
test('openai chat completions', async () => {
// Call ChatGPT with a basic question, and n=2
const [query, response] = await call_chatgpt("Who invented modern playing cards? Keep your answer brief.", LLM.OpenAI_ChatGPT, 2, 1.0);
const [query, response] = await call_chatgpt("Who invented modern playing cards? Keep your answer brief.", NativeLLM.OpenAI_ChatGPT, 2, 1.0);
console.log(response.choices[0].message);
expect(response.choices).toHaveLength(2);
expect(query).toHaveProperty('temperature');
// Extract responses, check their type
const resps = extract_responses(response, LLM.OpenAI_ChatGPT);
const resps = extract_responses(response, NativeLLM.OpenAI_ChatGPT);
expect(resps).toHaveLength(2);
expect(typeof resps[0]).toBe('string');
}, 20000);
test('openai text completions', async () => {
// Call OpenAI template with a basic question, and n=2
const [query, response] = await call_chatgpt("Who invented modern playing cards? The answer is ", LLM.OpenAI_Davinci003, 2, 1.0);
const [query, response] = await call_chatgpt("Who invented modern playing cards? The answer is ", NativeLLM.OpenAI_Davinci003, 2, 1.0);
console.log(response.choices[0].text);
expect(response.choices).toHaveLength(2);
expect(query).toHaveProperty('n');
// Extract responses, check their type
const resps = extract_responses(response, LLM.OpenAI_Davinci003);
const resps = extract_responses(response, NativeLLM.OpenAI_Davinci003);
expect(resps).toHaveLength(2);
expect(typeof resps[0]).toBe('string');
}, 20000);
test('anthropic models', async () => {
// Call Anthropic's Claude with a basic question
const [query, response] = await call_anthropic("Who invented modern playing cards?", LLM.Claude_v1, 1, 1.0);
const [query, response] = await call_anthropic("Who invented modern playing cards?", NativeLLM.Claude_v1, 1, 1.0);
console.log(response);
expect(response).toHaveLength(1);
expect(query).toHaveProperty('max_tokens_to_sample');
// Extract responses, check their type
const resps = extract_responses(response, LLM.Claude_v1);
const resps = extract_responses(response, NativeLLM.Claude_v1);
expect(resps).toHaveLength(1);
expect(typeof resps[0]).toBe('string');
}, 20000);
test('google palm2 models', async () => {
// Call Google's PaLM Chat API with a basic question
let [query, response] = await call_google_palm("Who invented modern playing cards?", LLM.PaLM2_Chat_Bison, 3, 0.7);
let [query, response] = await call_google_palm("Who invented modern playing cards?", NativeLLM.PaLM2_Chat_Bison, 3, 0.7);
expect(response.candidates).toHaveLength(3);
expect(query).toHaveProperty('candidateCount');
// Extract responses, check their type
let resps = extract_responses(response, LLM.PaLM2_Chat_Bison);
let resps = extract_responses(response, NativeLLM.PaLM2_Chat_Bison);
expect(resps).toHaveLength(3);
expect(typeof resps[0]).toBe('string');
console.log(JSON.stringify(resps));
// Call Google's PaLM Text Completions API with a basic question
[query, response] = await call_google_palm("Who invented modern playing cards? The answer ", LLM.PaLM2_Text_Bison, 3, 0.7);
[query, response] = await call_google_palm("Who invented modern playing cards? The answer ", NativeLLM.PaLM2_Text_Bison, 3, 0.7);
expect(response.candidates).toHaveLength(3);
expect(query).toHaveProperty('maxOutputTokens');
// Extract responses, check their type
resps = extract_responses(response, LLM.PaLM2_Chat_Bison);
resps = extract_responses(response, NativeLLM.PaLM2_Chat_Bison);
expect(resps).toHaveLength(3);
expect(typeof resps[0]).toBe('string');
console.log(JSON.stringify(resps));

View File

@ -1,7 +1,7 @@
import markdownIt from "markdown-it";
import { Dict, StringDict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse, ChatHistory, ChatHistoryInfo, isEqualChatHistory } from "./typing";
import { LLM, getEnumName } from "./models";
import { Dict, StringDict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse, ChatHistoryInfo, isEqualChatHistory } from "./typing";
import { LLM, NativeLLM, getEnumName } from "./models";
import { APP_IS_RUNNING_LOCALLY, set_api_keys, FLASK_BASE_URL, call_flask_backend } from "./utils";
import StorageCache from "./cache";
import { PromptPipeline } from "./query";
@ -12,11 +12,6 @@ import { PromptPermutationGenerator, PromptTemplate } from "./template";
// =================
// """
let LLM_NAME_MAP = {};
Object.entries(LLM).forEach(([key, value]) => {
LLM_NAME_MAP[value] = key;
});
enum MetricType {
KeyValue = 0,
KeyValue_Numeric = 1,
@ -553,13 +548,7 @@ export async function queryLLM(id: string,
// Ensure llm param is an array
if (typeof llm === 'string')
llm = [ llm ];
llm = llm as (Array<string> | Array<Dict>);
for (let i = 0; i < llm.length; i++) {
const llm_spec = llm[i];
if (!(extract_llm_name(llm_spec) in LLM_NAME_MAP))
return {'error': `LLM named '${llm_spec}' is not supported.`};
}
llm = llm as (Array<string> | Array<Dict>);
await setAPIKeys(api_keys);
@ -1120,3 +1109,57 @@ export async function fetchOpenAIEval(evalname: string): Promise<Dict> {
.then(response => response.json())
.then(res => ({data: res}));
}
/**
* Passes a Python script to load a custom model provider to the Flask backend.
* @param code The Python script to pass, as a string.
* @returns a Promise with the JSON of the response. Will include 'error' key if error'd; if success,
* a 'providers' key with a list of all loaded custom provider callbacks, as dicts.
*/
export async function initCustomProvider(code: string): Promise<Dict> {
// Send the Python script to the Flask backend, which will attempt to
// register any custom providers it contains:
return fetch(`${FLASK_BASE_URL}app/initCustomProvider`, {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({ code })
}).then(function(res) {
return res.json();
});
}
/**
* Asks Python script to remove a custom provider with name 'name'.
* @param name The name of the provider to remove. The name must match the name in the `ProviderRegistry`.
* @returns a Promise with the JSON of the response. Will include 'error' key if error'd; if success,
* a 'success' key with a true value.
*/
export async function removeCustomProvider(name: string): Promise<Dict> {
// Ask the Flask backend to remove the named custom provider
// from its ProviderRegistry:
return fetch(`${FLASK_BASE_URL}app/removeCustomProvider`, {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({ name })
}).then(function(res) {
return res.json();
});
}
/**
* Asks the Python backend to load custom provider scripts that are cached in the user's local directory.
*
* @returns a Promise with the JSON of the response. Will include 'error' key if error'd; if success,
* a 'providers' key with all loaded custom providers in an array. If there were none, returns empty array.
*/
export async function loadCachedCustomProviders(): Promise<Dict> {
return fetch(`${FLASK_BASE_URL}app/loadCachedCustomProviders`, {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: "{}"
}).then(function(res) {
return res.json();
});
}

View File

@ -1,7 +1,8 @@
/**
* A list of all model APIs natively supported by ChainForge.
*/
export enum LLM {
export type LLM = string | NativeLLM;
export enum NativeLLM {
// OpenAI Chat
OpenAI_ChatGPT = "gpt-3.5-turbo",
OpenAI_ChatGPT_16k = "gpt-3.5-turbo-16k",
@ -70,6 +71,7 @@ export enum LLMProvider {
Anthropic = "anthropic",
Google = "google",
HuggingFace = "hf",
Custom = "__custom",
}
/**
@ -78,7 +80,7 @@ export enum LLMProvider {
* @returns an `LLMProvider` describing what provider hosts the model
*/
export function getProvider(llm: LLM): LLMProvider | undefined {
const llm_name = getEnumName(LLM, llm.toString());
const llm_name = getEnumName(NativeLLM, llm.toString());
if (llm_name?.startsWith('OpenAI'))
return LLMProvider.OpenAI;
else if (llm_name?.startsWith('Azure'))
@ -91,6 +93,8 @@ export function getProvider(llm: LLM): LLMProvider | undefined {
return LLMProvider.HuggingFace;
else if (llm.toString().startsWith('claude'))
return LLMProvider.Anthropic;
else if (llm.toString().startsWith('__custom/'))
return LLMProvider.Custom;
return undefined;
}
@ -101,21 +105,21 @@ export function getProvider(llm: LLM): LLMProvider | undefined {
# This 'cheap' version of controlling for rate limits is to wait a few seconds between batches of requests being sent off.
# If a model is missing from below, it means we must send and receive only 1 request at a time (synchronous).
# The following is only a guideline, and a bit on the conservative side. */
export const RATE_LIMITS: { [key in LLM]?: [number, number] } = {
[LLM.OpenAI_ChatGPT]: [30, 10], // max 30 requests a batch; wait 10 seconds between
[LLM.OpenAI_ChatGPT_0301]: [30, 10],
[LLM.OpenAI_ChatGPT_0613]: [30, 10],
[LLM.OpenAI_ChatGPT_16k]: [30, 10],
[LLM.OpenAI_ChatGPT_16k_0613]: [30, 10],
[LLM.OpenAI_GPT4]: [4, 15], // max 4 requests a batch; wait 15 seconds between
[LLM.OpenAI_GPT4_0314]: [4, 15],
[LLM.OpenAI_GPT4_0613]: [4, 15],
[LLM.OpenAI_GPT4_32k]: [4, 15],
[LLM.OpenAI_GPT4_32k_0314]: [4, 15],
[LLM.OpenAI_GPT4_32k_0613]: [4, 15],
[LLM.Azure_OpenAI]: [30, 10],
[LLM.PaLM2_Text_Bison]: [4, 10], // max 30 requests per minute; so do 4 per batch, 10 seconds between (conservative)
[LLM.PaLM2_Chat_Bison]: [4, 10],
export let RATE_LIMITS: { [key in LLM]?: [number, number] } = {
[NativeLLM.OpenAI_ChatGPT]: [30, 10], // max 30 requests a batch; wait 10 seconds between
[NativeLLM.OpenAI_ChatGPT_0301]: [30, 10],
[NativeLLM.OpenAI_ChatGPT_0613]: [30, 10],
[NativeLLM.OpenAI_ChatGPT_16k]: [30, 10],
[NativeLLM.OpenAI_ChatGPT_16k_0613]: [30, 10],
[NativeLLM.OpenAI_GPT4]: [4, 15], // max 4 requests a batch; wait 15 seconds between
[NativeLLM.OpenAI_GPT4_0314]: [4, 15],
[NativeLLM.OpenAI_GPT4_0613]: [4, 15],
[NativeLLM.OpenAI_GPT4_32k]: [4, 15],
[NativeLLM.OpenAI_GPT4_32k_0314]: [4, 15],
[NativeLLM.OpenAI_GPT4_32k_0613]: [4, 15],
[NativeLLM.Azure_OpenAI]: [30, 10],
[NativeLLM.PaLM2_Text_Bison]: [4, 10], // max 30 requests per minute; so do 4 per batch, 10 seconds between (conservative)
[NativeLLM.PaLM2_Chat_Bison]: [4, 10],
};

View File

@ -1,6 +1,6 @@
import { PromptTemplate, PromptPermutationGenerator } from "./template";
import { LLM, RATE_LIMITS } from './models';
import { Dict, LLMResponseError, LLMResponseObject, ChatHistory, isEqualChatHistory, ChatHistoryInfo } from "./typing";
import { LLM, NativeLLM, RATE_LIMITS } from './models';
import { Dict, LLMResponseError, LLMResponseObject, isEqualChatHistory, ChatHistoryInfo } from "./typing";
import { extract_responses, merge_response_objs, call_llm, mergeDicts } from "./utils";
import StorageCache from "./cache";
@ -164,7 +164,7 @@ export class PromptPipeline {
}
if (!prompt.is_concrete())
throw Error(`Cannot send a prompt '${prompt}' to LLM: Prompt is a template.`)
throw new Error(`Cannot send a prompt '${prompt}' to LLM: Prompt is a template.`)
// Get the cache of responses with respect to this prompt, + normalize format so it's always an array (of size >= 0)
const cache_bucket = responses[prompt_str];
@ -191,7 +191,7 @@ export class PromptPipeline {
"query": cached_resp["query"],
"responses": extracted_resps.slice(0, n),
"raw_response": cached_resp["raw_response"],
"llm": cached_resp["llm"] || LLM.OpenAI_ChatGPT,
"llm": cached_resp["llm"] || NativeLLM.OpenAI_ChatGPT,
// We want to use the new info, since 'vars' could have changed even though
// the prompt text is the same (e.g., "this is a tool -> this is a {x} where x='tool'")
"info": mergeDicts(info, chat_history?.fill_history),

View File

@ -223,16 +223,16 @@ export async function call_chatgpt(prompt: string, model: LLM, n: number = 1, te
*/
export async function call_azure_openai(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
if (!AZURE_OPENAI_KEY)
throw Error("Could not find an Azure OpenAPI Key to use. Double-check that your key is set in Settings or in your local environment.");
throw new Error("Could not find an Azure OpenAPI Key to use. Double-check that your key is set in Settings or in your local environment.");
if (!AZURE_OPENAI_ENDPOINT)
throw Error("Could not find an Azure OpenAI Endpoint to use. Double-check that your endpoint is set in Settings or in your local environment.");
throw new Error("Could not find an Azure OpenAI Endpoint to use. Double-check that your endpoint is set in Settings or in your local environment.");
const deployment_name: string = params?.deployment_name;
const model_type: string = params?.model_type;
if (!deployment_name)
throw Error("Could not find an Azure OpenAPI deployment name. Double-check that your deployment name is set in Settings or in your local environment.");
throw new Error("Could not find an Azure OpenAPI deployment name. Double-check that your deployment name is set in Settings or in your local environment.");
if (!model_type)
throw Error("Could not find a model type specified for an Azure OpenAI model. Double-check that your deployment name is set in Settings or in your local environment.");
throw new Error("Could not find a model type specified for an Azure OpenAI model. Double-check that your deployment name is set in Settings or in your local environment.");
const client = new AzureOpenAIClient(AZURE_OPENAI_ENDPOINT, new AzureKeyCredential(AZURE_OPENAI_KEY));
@ -296,12 +296,12 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
*/
export async function call_anthropic(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
if (!ANTHROPIC_API_KEY)
throw Error("Could not find an API key for Anthropic models. Double-check that your API key is set in Settings or in your local environment.");
throw new Error("Could not find an API key for Anthropic models. Double-check that your API key is set in Settings or in your local environment.");
// Wrap the prompt in the provided template, or use the default Anthropic one
const custom_prompt_wrapper: string = params?.custom_prompt_wrapper || (ANTHROPIC_HUMAN_PROMPT + " {prompt}" + ANTHROPIC_AI_PROMPT);
if (!custom_prompt_wrapper.includes('{prompt}'))
throw Error("Custom prompt wrapper is missing required {prompt} template variable.");
throw new Error("Custom prompt wrapper is missing required {prompt} template variable.");
const prompt_wrapper_template = new StringTemplate(custom_prompt_wrapper);
let wrapped_prompt = prompt_wrapper_template.safe_substitute({prompt: prompt});
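A small sketch of a valid `custom_prompt_wrapper`: it must contain the literal `{prompt}` placeholder that `safe_substitute` fills in. The wrapper text itself is only an example.
const example_wrapper = "\n\nHuman: Answer briefly. {prompt}\n\nAssistant:";
// safe_substitute({prompt: "What is 2+2?"}) would then yield:
// "\n\nHuman: Answer briefly. What is 2+2?\n\nAssistant:"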
@@ -389,7 +389,7 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
*/
export async function call_google_palm(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
if (!GOOGLE_PALM_API_KEY)
throw Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment.");
throw new Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment.");
const is_chat_model = model.toString().includes('chat');
@@ -621,6 +621,40 @@ export async function call_huggingface(prompt: string, model: LLM, n: number = 1
return [query, responses];
}
async function call_custom_provider(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
if (!APP_IS_RUNNING_LOCALLY())
throw new Error("The ChainForge app does not appear to be running locally. You can only call custom model providers if you are running ChainForge on your local machine, from a Flask app.")
// The model to call is in the format:
//   __custom/<provider_name>/<submodel name>
// where the trailing <submodel name> is optional.
// We extract the provider name (the name registered in the Python backend's `ProviderRegistry`) and, optionally, the submodel name.
const provider_path = model.substring(9);
const provider_name = provider_path.substring(0, provider_path.indexOf('/'));
const submodel_name = (provider_path.length === provider_name.length-1) ? undefined : provider_path.substring(provider_path.lastIndexOf('/')+1);
let responses = [];
const query = { prompt, model, temperature, ...params };
// Call the custom provider n times
while (responses.length < n) {
let {response, error} = await call_flask_backend('callCustomProvider',
{ 'name': provider_name,
'params': {
prompt, model: submodel_name, temperature, ...params
}
});
// Fail if an error is encountered
if (error !== undefined || response === undefined)
throw new Error(error);
responses.push(response);
}
return [query, responses];
}
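To make the path format above concrete, here is a hedged walkthrough of how a custom model string is expected to split. The provider and submodel names are illustrative (borrowed from the Cohere example script), not values guaranteed by the registry.
// "__custom/" is 9 characters, so substring(9) strips the prefix.
const example_model = "__custom/Cohere/command";
const path = example_model.substring(9);                  // "Cohere/command"
const name = path.substring(0, path.indexOf('/'));        // "Cohere"  -> looked up in the Python ProviderRegistry
const sub  = path.substring(path.lastIndexOf('/') + 1);   // "command" -> passed to the provider function as `model`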
/**
* Switcher that routes the request to the appropriate API call function. If no call function exists for the given provider, throws an error.
*/
@@ -644,6 +678,8 @@ export async function call_llm(llm: LLM, prompt: string, n: number, temperature:
call_api = call_anthropic;
else if (llm_provider === LLMProvider.HuggingFace)
call_api = call_huggingface;
else if (llm_provider === LLMProvider.Custom)
call_api = call_custom_provider;
return call_api(prompt, llm, n, temperature, params);
}
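A hypothetical call through the new branch, assuming (as elsewhere in this change) that `LLM` admits arbitrary `__custom/...` strings; provider and submodel names are illustrative.
// Inside an async context: routes to call_custom_provider, which hits the Flask 'callCustomProvider' endpoint.
const [query, responses] = await call_llm("__custom/Cohere/command", "Write a haiku about llamas.", 1, 0.75);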
@@ -744,8 +780,11 @@ export function extract_responses(response: Array<string | Dict> | Dict, llm: LL
case LLMProvider.HuggingFace:
return _extract_huggingface_responses(response as Dict[]);
default:
throw new Error(`No method defined to extract responses for LLM ${llm}.`)
}
if (Array.isArray(response) && response.length > 0 && typeof response[0] === 'string')
return response as string[];
else
throw new Error(`No method defined to extract responses for LLM ${llm}.`)
}
}
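In other words, a custom provider that returns plain strings passes through the new default branch unchanged; a rough sketch (the provider name is illustrative):
extract_responses(["First completion", "Second completion"], "__custom/Cohere");
// -> ["First completion", "Second completion"]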
/**

View File

@@ -1,8 +1,8 @@
import { queryLLM, executejs, executepy,
fetchExampleFlow, fetchOpenAIEval, importCache,
exportCache, countQueries, grabResponses,
generatePrompts,
evalWithLLM} from "./backend/backend";
generatePrompts, initCustomProvider,
removeCustomProvider, evalWithLLM, loadCachedCustomProviders } from "./backend/backend";
const clone = (obj) => JSON.parse(JSON.stringify(obj));
@@ -30,6 +30,12 @@ async function _route_to_js_backend(route, params) {
return fetchExampleFlow(params.name);
case 'fetchOpenAIEval':
return fetchOpenAIEval(params.name);
case 'initCustomProvider':
return initCustomProvider(params.code);
case 'removeCustomProvider':
return removeCustomProvider(params.name);
case 'loadCachedCustomProviders':
return loadCachedCustomProviders();
default:
throw new Error(`Could not find backend function for route named ${route}`);
}
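For context, hypothetical call sites for the three new routes; the `params` shapes are inferred from the switch cases above and from the backend function names, so treat them as assumptions.
// Inside an async context:
await _route_to_js_backend('initCustomProvider', { code: customProviderPythonScript });  // `code` = the provider .py file's text
await _route_to_js_backend('removeCustomProvider', { name: 'Cohere' });                  // name as registered in ProviderRegistry
await _route_to_js_backend('loadCachedCustomProviders', {});                             // reloads providers cached by the Flask app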

View File

@@ -6,6 +6,7 @@ import {
} from 'react-flow-renderer';
import { escapeBraces } from './backend/template';
import { filterDict } from './backend/utils';
import { APP_IS_RUNNING_LOCALLY } from './backend/utils';
// Initial project settings
const initialAPIKeys = {};
@@ -27,12 +28,30 @@ export const colorPalettes = {
const refreshableOutputNodeTypes = new Set(['evaluator', 'prompt', 'inspect', 'vis', 'llmeval', 'textfields', 'chat', 'simpleval']);
export let initLLMProviders = [
{ name: "GPT3.5", emoji: "🤖", model: "gpt-3.5-turbo", base_model: "gpt-3.5-turbo", temp: 1.0 }, // The base_model designates what settings form will be used, and must be unique.
{ name: "GPT4", emoji: "🥵", model: "gpt-4", base_model: "gpt-4", temp: 1.0 },
{ name: "Claude", emoji: "📚", model: "claude-2", base_model: "claude-v1", temp: 0.5 },
{ name: "PaLM2", emoji: "🦬", model: "chat-bison-001", base_model: "palm2-bison", temp: 0.7 },
{ name: "Azure OpenAI", emoji: "🔷", model: "azure-openai", base_model: "azure-openai", temp: 1.0 },
{ name: "HuggingFace", emoji: "🤗", model: "tiiuae/falcon-7b-instruct", base_model: "hf", temp: 1.0 },
];
if (APP_IS_RUNNING_LOCALLY()) {
initLLMProviders.push({ name: "Dalai (Alpaca.7B)", emoji: "🦙", model: "alpaca.7B", base_model: "dalai", temp: 0.5 });
}
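When a custom provider script is loaded, the LLM list expects an entry in the same shape as `initLLMProviders`; a hedged sketch follows. The exact `model` and `base_model` strings ChainForge generates for custom providers are assumptions here, not confirmed by this diff.
// Illustrative only -- mirrors the Cohere example provider; fields follow the entries above.
const exampleCustomProvider = {
  name: "Cohere", emoji: "🖇",
  model: "__custom/Cohere/command",    // assumed custom-model identifier format
  base_model: "__custom/Cohere",       // assumed; base_model must be unique (see comment above)
  temp: 1.0,
};
// e.g. useStore.getState().setAvailableLLMs([...initLLMProviders, exampleCustomProvider]);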
// A global store of variables, used for maintaining state
// across ChainForge and ReactFlow components.
const useStore = create((set, get) => ({
nodes: [],
edges: [],
// Available LLMs in ChainForge, in the format expected by LLMListItems.
AvailableLLMs: [...initLLMProviders],
setAvailableLLMs: (llmProviderList) => {
set({AvailableLLMs: llmProviderList});
},
// Keeping track of LLM API keys
apiKeys: initialAPIKeys,
setAPIKeys: (apiKeys) => {

View File

@@ -6,7 +6,7 @@ def readme():
setup(
name='chainforge',
version='0.2.5.4',
version='0.2.6',
packages=find_packages(),
author="Ian Arawjo",
description="A Visual Programming Environment for Prompt Engineering",