Update tasks.py

Saifeddine ALOUI 2024-07-29 09:48:28 +02:00 committed by GitHub
parent dbc2120b3a
commit 35b239866a

@@ -80,16 +80,16 @@ class TasksLibrary:
         self.bot_says = bot_says
         return True
-    def generate(self, prompt, max_size, temperature = None, top_k = None, top_p=None, repeat_penalty=None, repeat_last_n=None, callback=None, debug=False, show_progress=False ):
+    def generate(self, prompt, max_size= None, temperature = None, top_k = None, top_p=None, repeat_penalty=None, repeat_last_n=None, callback=None, debug=False, show_progress=False ):
         ASCIIColors.info("Text generation started: Warming up")
         self.nb_received_tokens = 0
         self.bot_says = ""
         if debug:
             self.print_prompt("gen",prompt)
+        ntokens = len(self.lollms.model.tokenize(prompt))
         self.lollms.model.generate(
             prompt,
-            max_size,
+            max_size if max_size else min(self.lollms.config.ctx_size-ntokens,self.lollms.config.max_n_predict),
             partial(self.process, callback=callback, show_progress=show_progress),
             temperature= temperature if temperature is not None else self.lollms.config.temperature if self.lollms.config.override_personality_model_parameters else self.lollms.personality.model_temperature,
             top_k= top_k if top_k is not None else self.lollms.config.top_k if self.lollms.config.override_personality_model_parameters else self.lollms.personality.model_top_k,
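In this hunk, max_size becomes optional: when the caller omits it, the budget falls back to whatever still fits in the context window after the prompt, capped by max_n_predict. A minimal standalone sketch of that sizing rule, assuming only the config fields visible in the diff (the helper name compute_generation_budget is hypothetical, not part of tasks.py):

def compute_generation_budget(prompt_tokens: int, ctx_size: int,
                              max_n_predict: int, max_size=None) -> int:
    # An explicit caller budget wins; otherwise generate as many tokens
    # as still fit in the context window, never more than max_n_predict.
    if max_size:
        return max_size
    return min(ctx_size - prompt_tokens, max_n_predict)

# Example: a 3000-token prompt with ctx_size=4096 and max_n_predict=2048
# gives min(4096 - 3000, 2048) = 1096 tokens of generation budget.
print(compute_generation_budget(3000, 4096, 2048))  # 1096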
@@ -150,10 +150,6 @@ class TasksLibrary:
         Returns:
         - str: The generated text after removing special tokens ("<s>" and "</s>") and stripping any leading/trailing whitespace.
         """
-        if max_generation_size is None:
-            prompt_size = self.lollms.model.tokenize(prompt)
-            max_generation_size = self.lollms.model.config.ctx_size - len(prompt_size)
         pr = PromptReshaper(prompt)
         prompt = pr.build(placeholders,
                           self.lollms.model.tokenize,
@@ -161,10 +157,7 @@ class TasksLibrary:
                           self.lollms.model.config.ctx_size - max_generation_size,
                           sacrifice
                           )
-        ntk = len(self.lollms.model.tokenize(prompt))
-        max_generation_size = min(self.lollms.model.config.ctx_size - ntk, max_generation_size)
         # TODO : add show progress
         gen = self.generate(prompt, max_generation_size, temperature = temperature, top_k = top_k, top_p=top_p, repeat_penalty=repeat_penalty, repeat_last_n=repeat_last_n, callback=callback, show_progress=show_progress).strip().replace("</s>", "").replace("<s>", "")
         if debug:
             self.print_prompt("prompt", prompt+gen)