Added a print_prompt helper and debug prompt printing to fast_gen

This commit is contained in:
Saifeddine ALOUI 2023-11-13 16:30:28 +01:00
parent 20ac66c5a6
commit e4ed050527

View File

@ -169,6 +169,13 @@ Date: {{date}}
# Open and store the personality
self.load_personality()
def print_prompt(self, title, prompt):
    """Display *prompt* on the console, framed by red banner lines.

    The header line shows *title* between two red separators; the prompt
    text itself is printed in yellow, followed by a red footer separator.

    Args:
        title: Label shown in the red header banner.
        prompt: The prompt text to display.
    """
    # Emit the header in pieces (separator, title, separator) so each
    # ASCIIColors.red call is identical to the original sequence.
    for piece in ("*-*-*-*-*-*-*-* ", title):
        ASCIIColors.red(piece, end="")
    ASCIIColors.red(" *-*-*-*-*-*-*-*")
    ASCIIColors.yellow(prompt)
    ASCIIColors.red(" *-*-*-*-*-*-*-*")
def fast_gen(self, prompt: str, max_generation_size: int=None, placeholders: dict = {}, sacrifice: list = ["previous_discussion"], debug: bool = False, callback=None) -> str:
    """
@ -187,6 +194,9 @@ Date: {{date}}
Returns:
- str: The generated text after removing special tokens ("<s>" and "</s>") and stripping any leading/trailing whitespace.
"""
if debug == False:
debug = self.config.debug
if max_generation_size is None:
    prompt_size = self.model.tokenize(prompt)
    max_generation_size = self.model.config.ctx_size - len(prompt_size)