Mirror of https://github.com/ParisNeo/lollms.git, synced 2025-03-01 04:06:07 +00:00
Update tasks.py
This commit is contained in:
parent dbc2120b3a
commit 35b239866a
@@ -80,16 +80,16 @@ class TasksLibrary:
         self.bot_says = bot_says
         return True

-    def generate(self, prompt, max_size, temperature = None, top_k = None, top_p=None, repeat_penalty=None, repeat_last_n=None, callback=None, debug=False, show_progress=False ):
+    def generate(self, prompt, max_size= None, temperature = None, top_k = None, top_p=None, repeat_penalty=None, repeat_last_n=None, callback=None, debug=False, show_progress=False ):
         ASCIIColors.info("Text generation started: Warming up")
         self.nb_received_tokens = 0
         self.bot_says = ""
         if debug:
             self.print_prompt("gen",prompt)
-
+        ntokens = len(self.lollms.model.tokenize(prompt))
         self.lollms.model.generate(
                                 prompt,
-                                max_size,
+                                max_size if max_size else min(self.lollms.config.ctx_size-ntokens,self.lollms.config.max_n_predict),
                                 partial(self.process, callback=callback, show_progress=show_progress),
                                 temperature= temperature if temperature is not None else self.lollms.config.temperature if self.lollms.config.override_personality_model_parameters else self.lollms.personality.model_temperature,
                                 top_k= top_k if top_k is not None else self.lollms.config.top_k if self.lollms.config.override_personality_model_parameters else self.lollms.personality.model_top_k,
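The key change in this hunk is that max_size is now optional: when it is omitted, generate() derives a cap from the tokens left in the context window. A minimal standalone sketch of that fallback, assuming illustrative stand-ins for the lollms config fields (ctx_size, max_n_predict) and the model's tokenize callable (none of these names beyond the diff are taken from the repo):

    from typing import Callable, List, Optional

    def resolve_max_size(prompt: str,
                         tokenize: Callable[[str], List[int]],
                         ctx_size: int,
                         max_n_predict: int,
                         max_size: Optional[int] = None) -> int:
        # When max_size is not supplied, cap generation to the room left in the
        # context window after the prompt, bounded by the configured max_n_predict.
        ntokens = len(tokenize(prompt))
        return max_size if max_size else min(ctx_size - ntokens, max_n_predict)

    # Example with a whitespace tokenizer as a stand-in:
    # resolve_max_size("hello world", str.split, ctx_size=4096, max_n_predict=1024) == 1024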
@@ -150,10 +150,6 @@ class TasksLibrary:
         Returns:
         - str: The generated text after removing special tokens ("<s>" and "</s>") and stripping any leading/trailing whitespace.
         """
-        if max_generation_size is None:
-            prompt_size = self.lollms.model.tokenize(prompt)
-            max_generation_size = self.lollms.model.config.ctx_size - len(prompt_size)
-
         pr = PromptReshaper(prompt)
         prompt = pr.build(placeholders,
                           self.lollms.model.tokenize,
@@ -161,10 +157,7 @@
                           self.lollms.model.config.ctx_size - max_generation_size,
                           sacrifice
                           )
-        ntk = len(self.lollms.model.tokenize(prompt))
-        max_generation_size = min(self.lollms.model.config.ctx_size - ntk, max_generation_size)
         # TODO : add show progress
-
         gen = self.generate(prompt, max_generation_size, temperature = temperature, top_k = top_k, top_p=top_p, repeat_penalty=repeat_penalty, repeat_last_n=repeat_last_n, callback=callback, show_progress=show_progress).strip().replace("</s>", "").replace("<s>", "")
         if debug:
             self.print_prompt("prompt", prompt+gen)