diff --git a/lollms/tasks.py b/lollms/tasks.py
index 65aec85..8d4a5df 100644
--- a/lollms/tasks.py
+++ b/lollms/tasks.py
@@ -80,16 +80,16 @@ class TasksLibrary:
self.bot_says = bot_says
return True
- def generate(self, prompt, max_size, temperature = None, top_k = None, top_p=None, repeat_penalty=None, repeat_last_n=None, callback=None, debug=False, show_progress=False ):
+ def generate(self, prompt, max_size= None, temperature = None, top_k = None, top_p=None, repeat_penalty=None, repeat_last_n=None, callback=None, debug=False, show_progress=False ):
ASCIIColors.info("Text generation started: Warming up")
self.nb_received_tokens = 0
self.bot_says = ""
if debug:
self.print_prompt("gen",prompt)
-
+ ntokens = len(self.lollms.model.tokenize(prompt))
self.lollms.model.generate(
prompt,
- max_size,
+ max_size if max_size else min(self.lollms.config.ctx_size-ntokens,self.lollms.config.max_n_predict),
partial(self.process, callback=callback, show_progress=show_progress),
temperature= temperature if temperature is not None else self.lollms.config.temperature if self.lollms.config.override_personality_model_parameters else self.lollms.personality.model_temperature,
top_k= top_k if top_k is not None else self.lollms.config.top_k if self.lollms.config.override_personality_model_parameters else self.lollms.personality.model_top_k,
@@ -150,10 +150,6 @@ class TasksLibrary:
Returns:
- str: The generated text after removing special tokens ("<s>" and "</s>") and stripping any leading/trailing whitespace.
"""
- if max_generation_size is None:
- prompt_size = self.lollms.model.tokenize(prompt)
- max_generation_size = self.lollms.model.config.ctx_size - len(prompt_size)
-
pr = PromptReshaper(prompt)
prompt = pr.build(placeholders,
self.lollms.model.tokenize,
@@ -161,10 +157,7 @@ class TasksLibrary:
self.lollms.model.config.ctx_size - max_generation_size,
sacrifice
)
- ntk = len(self.lollms.model.tokenize(prompt))
- max_generation_size = min(self.lollms.model.config.ctx_size - ntk, max_generation_size)
# TODO : add show progress
-
gen = self.generate(prompt, max_generation_size, temperature = temperature, top_k = top_k, top_p=top_p, repeat_penalty=repeat_penalty, repeat_last_n=repeat_last_n, callback=callback, show_progress=show_progress).strip().replace("<s>", "").replace("</s>", "")
if debug:
self.print_prompt("prompt", prompt+gen)
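Note on the new max_size fallback (a minimal sketch, not part of the patch): with this change, a caller that omits max_size gets a generation budget computed from the room left in the context window after the prompt, capped by the configured max_n_predict. The helper below is hypothetical and only illustrates the arithmetic performed by the added line in generate(); tokenize, ctx_size and max_n_predict stand in for self.lollms.model.tokenize and the corresponding self.lollms.config fields.

def default_max_size(prompt, tokenize, ctx_size, max_n_predict, max_size=None):
    # Mirrors the fallback added to generate(): honor an explicit budget if one
    # was given; otherwise use whatever context remains after the prompt,
    # never exceeding the configured prediction limit.
    if max_size:
        return max_size
    ntokens = len(tokenize(prompt))
    return min(ctx_size - ntokens, max_n_predict)

# Example: a 4096-token context, a 3900-token prompt and max_n_predict=1024
# yield min(4096 - 3900, 1024) = 196 tokens of generation budget.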