diff --git a/models/transformers/.keep b/models/transformers/.keep
new file mode 100644
index 00000000..e69de29b
diff --git a/pyGpt4All/backends/transformers.py b/pyGpt4All/backends/transformers.py
new file mode 100644
index 00000000..54932923
--- /dev/null
+++ b/pyGpt4All/backends/transformers.py
@@ -0,0 +1,65 @@
+######
+# Project : GPT4ALL-UI
+# File : transformers.py
+# Author : ParisNeo with the help of the community
+# Supported by Nomic-AI
+# Licence : Apache 2.0
+# Description :
+# This is the Hugging Face Transformers backend for GPT4All-ui.
+######
+from pathlib import Path
+from typing import Callable
+from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM
+from pyGpt4All.backends.backend import GPTBackend
+
+__author__ = "parisneo"
+__github__ = "https://github.com/nomic-ai/gpt4all-ui"
+__copyright__ = "Copyright 2023, "
+__license__ = "Apache 2.0"
+
+
+class Transformers(GPTBackend):
+    def __init__(self, config: dict) -> None:
+        """Builds a Transformers backend
+
+        Args:
+            config (dict): The configuration file
+        """
+        super().__init__(config)
+        self.config = config
+        # Load the tokenizer and model from the local model directory only (no remote download)
+        self.tokenizer = AutoTokenizer.from_pretrained(f"./models/transformers/{self.config['model']}", local_files_only=True)
+        self.model = AutoModelForCausalLM.from_pretrained(f"./models/transformers/{self.config['model']}", local_files_only=True)
+
+
+    def generate(self,
+                 prompt: str,
+                 n_predict: int = 128,
+                 new_text_callback: Callable[[str], None] = None,
+                 verbose: bool = False,
+                 **gpt_params):
+        """Generates text out of a prompt
+
+        Args:
+            prompt (str): The prompt to use for generation
+            n_predict (int, optional): Number of tokens to predict. Defaults to 128.
+            new_text_callback (Callable[[str], None], optional): A callback function that is called every time a new text element is generated. Defaults to None.
+            verbose (bool, optional): If true, the code will print detailed information about the generation process. Defaults to False.
+        """
+
+        inputs = self.tokenizer(prompt, return_tensors="pt").input_ids
+        while len(inputs