ianarawjo 583ea6506f
Convert entire front-end to TypeScript. Add image model support. (And more!) (#252)
* Refactor: modularize response boxes into a separate component

* Type store.js. Change info to vars. NOTE: This may break backwards compat.

* Refactor addNodes in App.tsx to be simpler.

* Turn AlertModal into a Provider with useContext

* Remove fetch_from_backend.

* Add build/ to gitignore

* Add support for image models and add Dall-E models.

* Better rate limiting with Bottleneck

* Fix new Chrome bug with file import readers not appearing as arrays; and fix bug with exportCache

* Add ability to add custom right-click context menu items per node

* Convert to/from TF and Items nodes

* Add lazyloader for images

* Add compression to images by default before storing in cache

* Add image compression toggle in Global Settings

* Move Alert Provider to top level of index.js
2024-03-30 22:43:40 -04:00

127 lines
6.0 KiB
Python

from typing import Protocol, Optional, Dict, Dict, List, Literal, Union, Any
import time
"""
OpenAI chat message format typing
"""
class ChatMessage(Dict):
    """A single message, in OpenAI chat message format.

    Instances are plain dictionaries. The annotations below only declare the
    expected entries; because this is a regular ``Dict`` subclass, they are
    neither enforced nor populated at runtime.
    """
    # Declared as a string; presumably the speaker role -- confirm against callers.
    role: str
    # Declared as a string; presumably the text of the message.
    content: str
    # Optional name entry.
    name: Optional[str]
    # Optional function-call payload, declared as a dictionary.
    function_call: Optional[Dict]
# Type alias: a chat history is a list of ChatMessage dictionaries.
ChatHistory = List[ChatMessage]
class CustomProviderProtocol(Protocol):
    """
    A Callable protocol to implement for custom model provider completions.

    Any function, or any object whose class defines a compatible `__call__`,
    satisfies this protocol. See `__call__` for the expected signature.
    """

    def __call__(self,
                 prompt: str,
                 model: Optional[str],
                 chat_history: Optional[ChatHistory],
                 **kwargs: Any) -> str:
        """
        Define a call to your custom provider.

        Parameters:
         - `prompt`: Text to prompt the model. (If it's a chat model, this is the new message to send.)
         - `model`: The name of the particular model to use, from the CF settings window.
                    Useful when you have multiple models for a single provider. Optional.
         - `chat_history`: Providers may be passed a past chat context as a list of chat messages
                    in OpenAI format (see `chainforge.providers.ChatHistory`).
                    Chat history does not include the new message to send off
                    (which is passed instead as the `prompt` parameter).
         - `kwargs`: Any other parameters to pass the provider API call, like temperature.
                    Parameter names are the keynames in your provider's settings_schema.
                    Only relevant if you are defining a custom settings_schema JSON
                    to edit provider/model settings in ChainForge.

        Returns:
            The provider's response, as a string.
        """
        ...
"""
A registry for custom providers
"""
class _ProviderRegistry:
def __init__(self):
self._registry = {}
self._curr_script_id = '0'
self._last_updated = {}
def set_curr_script_id(self, id: str):
self._curr_script_id = id
def register(self, cls: CustomProviderProtocol, name: str, **kwargs):
if name is None or isinstance(name, str) is False or len(name) == 0:
raise Exception("Cannot register custom model provider: No name given. Name must be a string and unique.")
self._last_updated[name] = self._registry[name]["script_id"] if name in self._registry else None
self._registry[name] = { "name": name, "func": cls, "script_id": self._curr_script_id, **kwargs }
def get(self, name):
return self._registry.get(name)
def get_all(self):
return list(self._registry.values())
def has(self, name):
return name in self._registry
def remove(self, name):
if self.has(name):
del self._registry[name]
def watch_next_registered(self):
self._last_updated = {}
def last_registered(self):
return {k: v for k, v in self._last_updated.items()}
# Global instance of the registry, created once at import time.
# The `provider` decorator below registers callables into this instance.
ProviderRegistry = _ProviderRegistry()
def provider(name: str = 'Custom Provider',
             emoji: str = '',
             models: Optional[List[str]] = None,
             rate_limit: Union[int, Literal["sequential"]] = "sequential",
             settings_schema: Optional[Dict] = None):
    """
    A decorator for registering custom LLM provider methods or classes (Callables)
    that conform to `CustomProviderProtocol`.

    Parameters:
     - `name`: The name of your custom provider. Cannot be blank: registration raises if it is
            not a non-empty string. Should be unique, since registering a name a second time
            replaces the earlier entry. Defaults to 'Custom Provider'.
     - `emoji`: The emoji to use as the default icon for your provider in the CF interface.
            Defaults to an empty string.
     - `models`: A list of models that your provider supports, that you want to be able to choose
            between in the Settings window. If you're just calling a single model, you can omit this.
     - `rate_limit`: If an integer, the maximum number of simultaneous requests to send per minute.
            To force requests to be sequential (wait until each request returns before sending
            another), enter "sequential". Default is sequential.
     - `settings_schema`: a JSON Schema specifying the name of your provider in the ChainForge UI,
            the available settings, and the UI for those settings.
            The settings and UI specs are in react-jsonschema-form format:
            https://rjsf-team.github.io/react-jsonschema-form/.

            Specifically, your `settings_schema` dict should have keys:

            ```
            {
                "settings": <JSON dict of the schema properties for your settings form, in react-jsonschema-form format (https://rjsf-team.github.io/react-jsonschema-form/docs/)>,
                "ui": <JSON dict of the UI Schema for your settings form, in react-jsonschema-form (see UISchema example here: https://rjsf-team.github.io/react-jsonschema-form/)
            }
            ```

            You may look to adapt an existing schema from `ModelSettingsSchemas.js` in
            `chainforge/react-server/src/`, BUT with the following things to keep in mind:
             - the value of "settings" should just be the value of "properties" in the full schema
             - don't include the 'shortname' property; this will be added by default and set to the value of `name`
             - don't include the 'model' property; this will be populated by the list you passed to `models` (if any)
             - the keynames of all properties of the schema should be valid as variable names for
               Python keyword args; i.e., no spaces

            Finally, if you want temperature to appear in the ChainForge UI, you must name your
            settings schema property `temperature`, and give it `minimum` and `maximum` values.

            NOTE: Only `textarea`, `range`, enum, and text input UI widgets are properly supported
                  from `react-jsonschema-form`; you can try other widget types, but the CSS may
                  not display properly.

    Returns:
        A decorator that registers the decorated callable in the global
        `ProviderRegistry` and returns that callable unchanged.
    """
    def dec(cls: CustomProviderProtocol) -> CustomProviderProtocol:
        # Store the callable together with the decorator arguments, then hand
        # it back untouched so the decorated name still refers to the original.
        ProviderRegistry.register(
            cls,
            name=name,
            emoji=emoji,
            models=models,
            rate_limit=rate_limit,
            settings_schema=settings_schema,
        )
        return cls

    return dec