ianarawjo 0134dbf59b
Custom providers (#122)
* Removed Python backend files that are no longer used (everything in `promptengine`)

* Added `providers` subpackage, with `CustomProviderProtocol`, `provider` decorator, and global singleton `ProviderRegistry` (a usage sketch follows this log)

* Added a tab for custom providers, with a dropzone, in the ChainForge global settings UI

* List custom providers in the Global Settings screen once added. 

* Added ability to remove custom providers by clicking X. 

* Make custom provider functions synchronous, but call them asynchronously.

* Add Cohere custom provider example in examples/

* Cache the custom provider scripts and load them on page load

* Rebuild react and update package version

* Bug fix: opening the settings screen on a custom provider that was just deleted
2023-08-27 15:11:42 -04:00
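
For context, here is a minimal sketch of what a dropped-in custom provider script could look like, patterned after the Cohere example mentioned above. The decorator keyword arguments (name, emoji, models, rate_limit) and the provider function signature are assumptions for illustration, not a verified API.

# Hypothetical provider script -- a sketch only; the decorator kwargs below
# are assumed, patterned after the Cohere example in examples/.
from chainforge.providers import provider

@provider(name="Echo",              # display name shown in the settings UI
          emoji="🔁",
          models=["echo-v1"],       # model names this provider exposes
          rate_limit="sequential")  # assumed: send one prompt at a time
def echo_provider(prompt: str, model: str, **kwargs) -> str:
    # Per the commit, custom provider functions are written synchronously;
    # ChainForge wraps and calls them asynchronously on its side.
    return f"[{model}] {prompt}"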


from typing import Tuple, Dict, List, Optional
import asyncio, time

DALAI_MODEL = None
DALAI_RESPONSE = None

async def call_dalai(prompt: str, model: str, server: str = "http://localhost:4000", n: int = 1, temperature: float = 0.5, **params) -> Tuple[Optional[Dict], Optional[List[str]]]:
    """
    Calls a Dalai server running LLMs (Alpaca, Llama, etc.) locally.
    Returns the raw query dict and the list of response strings, or (None, None) on failure.

    Parameters:
        - prompt: The prompt to pass to the LLM.
        - model: The LLM model, whose value is the name known by Dalai; e.g., 'alpaca.7b'
        - server: The URL of the local server where Dalai is running. By default, http://localhost:4000.
        - n: How many times to query. If n > 1, this will continue to query the LLM 'n' times and collect all responses.
        - temperature: The temperature to query at.
        - params: Any other Dalai-specific params to pass. For more info, see https://cocktailpeanut.github.io/dalai/#/?id=syntax-1

    TODO: Currently, this uses a modified dalaipy library for simplicity; however, in the future we might remove this dependency.
    """
    # Import and load upon first run
    global DALAI_MODEL, DALAI_RESPONSE
    if not server or len(server.strip()) == 0:  # In case the user passed a blank server name, revert to the default on port 4000
        server = "http://localhost:4000"
    if DALAI_MODEL is None:
        from chainforge.providers.dalaipy import Dalai
        DALAI_MODEL = Dalai(server)
    elif DALAI_MODEL.server != server:  # If the server has changed, we need to create a new model
        DALAI_MODEL = Dalai(server)

    # Make sure the server is connected
    DALAI_MODEL.connect()

    # Create a settings dict to pass to Dalai as args, overriding defaults with any supported params
    def_params = {'n_predict': 128, 'repeat_last_n': 64, 'repeat_penalty': 1.3, 'seed': -1, 'threads': 4, 'top_k': 40, 'top_p': 0.9}
    for key in params:
        if key in def_params:
            def_params[key] = params[key]
        else:
            print(f"Attempted to pass unsupported param '{key}' to Dalai. Ignoring.")

    # Create the full query to Dalai
    query = {
        'prompt': prompt,
        'model': model,
        'id': str(round(time.time() * 1000)),
        'temp': temperature,
        **def_params
    }

    # Create a spot to put the response, and a callback that sets it
    DALAI_RESPONSE = None
    def on_finish(r):
        global DALAI_RESPONSE
        DALAI_RESPONSE = r

    print(f"Calling Dalai model '{query['model']}' with prompt '{query['prompt']}' (n={n}). Please be patient...")

    # Repeat the call n times
    responses = []
    while len(responses) < n:
        # Call the Dalai model
        req = DALAI_MODEL.generate_request(**query)
        sent_req_success = DALAI_MODEL.generate(req, on_finish=on_finish)

        if not sent_req_success:
            print("Something went wrong pinging the Dalai server. Returning None.")
            return None, None

        # Blocking -- wait for the request to complete:
        while DALAI_RESPONSE is None:
            await asyncio.sleep(0.01)

        response = DALAI_RESPONSE['response']
        if response.endswith('<end>'):  # Strip the ending <end> tag, if present
            response = response[:-len('<end>')]
        if response.find('\r\n') > -1:  # Strip off the prompt, which is included in the result up to \r\n
            response = response[response.find('\r\n') + 2:]
        DALAI_RESPONSE = None

        responses.append(response)
        print(f"Response {len(responses)} of {n}:\n", response)

    # Disconnect from the server
    DALAI_MODEL.disconnect()

    return query, responses
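
For reference, a usage sketch (not part of the file above), assuming a Dalai server is reachable at the default address with the 'alpaca.7b' model loaded:

# Usage sketch -- assumes a Dalai server is running at http://localhost:4000
import asyncio

async def main():
    query, responses = await call_dalai("Name the capital of France.", model="alpaca.7b", n=1)
    if responses is not None:
        print(responses[0])

asyncio.run(main())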