2024-03-30 22:43:40 -04:00
|
|
|
from typing import Protocol, Optional, Dict, Dict, List, Literal, Union, Any
|
2023-08-27 15:11:42 -04:00
|
|
|
import time
|
|
|
|
|
|
|
|
"""
|
|
|
|
OpenAI chat message format typing
|
|
|
|
"""
|
2024-03-30 22:43:40 -04:00
|
|
|
class ChatMessage(Dict):
|
2023-08-27 15:11:42 -04:00
|
|
|
""" A single message, in OpenAI chat message format. """
|
|
|
|
role: str
|
|
|
|
content: str
|
|
|
|
name: Optional[str]
|
|
|
|
function_call: Optional[Dict]
|
|
|
|
|
|
|
|
ChatHistory = List[ChatMessage]
|
|
|
|
|
|
|
|
|
|
|
|
class CustomProviderProtocol(Protocol):
|
|
|
|
"""
|
|
|
|
A Callable protocol to implement for custom model provider completions.
|
|
|
|
See `__call__` for more details.
|
|
|
|
"""
|
|
|
|
def __call__(self,
|
|
|
|
prompt: str,
|
|
|
|
model: Optional[str],
|
|
|
|
chat_history: Optional[ChatHistory],
|
|
|
|
**kwargs: Any) -> str:
|
|
|
|
"""
|
|
|
|
Define a call to your custom provider.
|
|
|
|
|
|
|
|
Parameters:
|
|
|
|
- `prompt`: Text to prompt the model. (If it's a chat model, this is the new message to send.)
|
|
|
|
- `model`: The name of the particular model to use, from the CF settings window. Useful when you have multiple models for a single provider. Optional.
|
|
|
|
- `chat_history`: Providers may be passed a past chat context as a list of chat messages in OpenAI format (see `chainforge.providers.ChatHistory`).
|
|
|
|
Chat history does not include the new message to send off (which is passed instead as the `prompt` parameter).
|
|
|
|
- `kwargs`: Any other parameters to pass the provider API call, like temperature. Parameter names are the keynames in your provider's settings_schema.
|
|
|
|
Only relevant if you are defining a custom settings_schema JSON to edit provider/model settings in ChainForge.
|
|
|
|
"""
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
A registry for custom providers
|
|
|
|
"""
|
|
|
|
class _ProviderRegistry:
|
|
|
|
def __init__(self):
|
|
|
|
self._registry = {}
|
|
|
|
self._curr_script_id = '0'
|
|
|
|
self._last_updated = {}
|
|
|
|
|
|
|
|
def set_curr_script_id(self, id: str):
|
|
|
|
self._curr_script_id = id
|
|
|
|
|
|
|
|
def register(self, cls: CustomProviderProtocol, name: str, **kwargs):
|
|
|
|
if name is None or isinstance(name, str) is False or len(name) == 0:
|
|
|
|
raise Exception("Cannot register custom model provider: No name given. Name must be a string and unique.")
|
|
|
|
self._last_updated[name] = self._registry[name]["script_id"] if name in self._registry else None
|
|
|
|
self._registry[name] = { "name": name, "func": cls, "script_id": self._curr_script_id, **kwargs }
|
|
|
|
|
|
|
|
def get(self, name):
|
|
|
|
return self._registry.get(name)
|
|
|
|
|
|
|
|
def get_all(self):
|
|
|
|
return list(self._registry.values())
|
|
|
|
|
|
|
|
def has(self, name):
|
|
|
|
return name in self._registry
|
|
|
|
|
|
|
|
def remove(self, name):
|
|
|
|
if self.has(name):
|
|
|
|
del self._registry[name]
|
|
|
|
|
|
|
|
def watch_next_registered(self):
|
|
|
|
self._last_updated = {}
|
|
|
|
|
|
|
|
def last_registered(self):
|
|
|
|
return {k: v for k, v in self._last_updated.items()}
|
|
|
|
|
|
|
|
|
|
|
|
# Global instance of the registry.
|
|
|
|
ProviderRegistry = _ProviderRegistry()
|
|
|
|
|
|
|
|
def provider(name: str = 'Custom Provider',
|
|
|
|
emoji: str = '✨',
|
|
|
|
models: Optional[List[str]] = None,
|
|
|
|
rate_limit: Union[int, Literal["sequential"]] = "sequential",
|
|
|
|
settings_schema: Optional[Dict] = None):
|
|
|
|
"""
|
|
|
|
A decorator for registering custom LLM provider methods or classes (Callables)
|
|
|
|
that conform to `CustomProviderProtocol`.
|
|
|
|
|
|
|
|
Parameters:
|
|
|
|
- `name`: The name of your custom provider. Required. (Must be unique; cannot be blank.)
|
|
|
|
- `emoji`: The emoji to use as the default icon for your provider in the CF interface. Required.
|
|
|
|
- `models`: A list of models that your provider supports, that you want to be able to choose between in Settings window.
|
|
|
|
If you're just calling a single model, you can omit this.
|
|
|
|
- `rate_limit`: If an integer, the maximum number of simulatenous requests to send per minute.
|
|
|
|
To force requests to be sequential (wait until each request returns before sending another), enter "sequential". Default is sequential.
|
|
|
|
- `settings_schema`: a JSON Schema specifying the name of your provider in the ChainForge UI, the available settings, and the UI for those settings.
|
|
|
|
The settings and UI specs are in react-jsonschema-form format: https://rjsf-team.github.io/react-jsonschema-form/.
|
|
|
|
|
|
|
|
Specifically, your `settings_schema` dict should have keys:
|
|
|
|
|
|
|
|
```
|
|
|
|
{
|
|
|
|
"settings": <JSON dict of the schema properties for your settings form, in react-jsonschema-form format (https://rjsf-team.github.io/react-jsonschema-form/docs/)>,
|
|
|
|
"ui": <JSON dict of the UI Schema for your settings form, in react-jsonschema-form (see UISchema example here: https://rjsf-team.github.io/react-jsonschema-form/)
|
|
|
|
}
|
|
|
|
```
|
|
|
|
|
|
|
|
You may look to adapt an existing schema from `ModelSettingsSchemas.js` in `chainforge/react-server/src/`,
|
|
|
|
BUT with the following things to keep in mind:
|
|
|
|
- the value of "settings" should just be the value of "properties" in the full schema
|
|
|
|
- don't include the 'shortname' property; this will be added by default and set to the value of `name`
|
|
|
|
- don't include the 'model' property; this will be populated by the list you passed to `models` (if any)
|
|
|
|
- the keynames of all properties of the schema should be valid as variable names for Python keyword args; i.e., no spaces
|
|
|
|
|
|
|
|
Finally, if you want temperature to appear in the ChainForge UI, you must name your
|
|
|
|
settings schema property `temperature`, and give it `minimum` and `maximum` values.
|
|
|
|
|
|
|
|
NOTE: Only `textarea`, `range`, and enum, and text input UI widgets are properly supported from `react-jsonschema-form`;
|
|
|
|
you can try other widget types, but the CSS may not display property.
|
|
|
|
|
|
|
|
"""
|
|
|
|
def dec(cls: CustomProviderProtocol):
|
|
|
|
ProviderRegistry.register(cls, name=name, emoji=emoji, models=models, rate_limit=rate_limit, settings_schema=settings_schema)
|
|
|
|
return cls
|
|
|
|
return dec
|