From 35b239866ac5d942b9fbe8280043b21b8c085df3 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 29 Jul 2024 09:48:28 +0200 Subject: [PATCH] Update tasks.py --- lollms/tasks.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/lollms/tasks.py b/lollms/tasks.py index 65aec85..8d4a5df 100644 --- a/lollms/tasks.py +++ b/lollms/tasks.py @@ -80,16 +80,16 @@ class TasksLibrary: self.bot_says = bot_says return True - def generate(self, prompt, max_size, temperature = None, top_k = None, top_p=None, repeat_penalty=None, repeat_last_n=None, callback=None, debug=False, show_progress=False ): + def generate(self, prompt, max_size= None, temperature = None, top_k = None, top_p=None, repeat_penalty=None, repeat_last_n=None, callback=None, debug=False, show_progress=False ): ASCIIColors.info("Text generation started: Warming up") self.nb_received_tokens = 0 self.bot_says = "" if debug: self.print_prompt("gen",prompt) - + ntokens = len(self.lollms.model.tokenize(prompt)) self.lollms.model.generate( prompt, - max_size, + max_size if max_size else min(self.lollms.config.ctx_size-ntokens,self.lollms.config.max_n_predict), partial(self.process, callback=callback, show_progress=show_progress), temperature= temperature if temperature is not None else self.lollms.config.temperature if self.lollms.config.override_personality_model_parameters else self.lollms.personality.model_temperature, top_k= top_k if top_k is not None else self.lollms.config.top_k if self.lollms.config.override_personality_model_parameters else self.lollms.personality.model_top_k, @@ -150,10 +150,6 @@ class TasksLibrary: Returns: - str: The generated text after removing special tokens ("" and "") and stripping any leading/trailing whitespace. """ - if max_generation_size is None: - prompt_size = self.lollms.model.tokenize(prompt) - max_generation_size = self.lollms.model.config.ctx_size - len(prompt_size) - pr = PromptReshaper(prompt) prompt = pr.build(placeholders, self.lollms.model.tokenize, @@ -161,10 +157,7 @@ class TasksLibrary: self.lollms.model.config.ctx_size - max_generation_size, sacrifice ) - ntk = len(self.lollms.model.tokenize(prompt)) - max_generation_size = min(self.lollms.model.config.ctx_size - ntk, max_generation_size) # TODO : add show progress - gen = self.generate(prompt, max_generation_size, temperature = temperature, top_k = top_k, top_p=top_p, repeat_penalty=repeat_penalty, repeat_last_n=repeat_last_n, callback=callback, show_progress=show_progress).strip().replace("", "").replace("", "") if debug: self.print_prompt("prompt", prompt+gen)