upgraded personality functionalities

This commit is contained in:
Saifeddine ALOUI 2023-11-12 17:55:20 +01:00
parent aaf2a99598
commit 0631ce4b2c

View File

@ -70,6 +70,8 @@ class AIPersonality:
Raises:
ValueError: If the provided path is not a folder or does not contain a config.yaml file.
"""
self.bot_says = ""
self.lollms_paths = lollms_paths
self.model = model
self.config = config
@ -167,6 +169,88 @@ Date: {{date}}
# Open and store the personality
self.load_personality()
def fast_gen(self, prompt: str, max_generation_size: int=None, placeholders: dict = {}, sacrifice: list = ["previous_discussion"], debug: bool = False) -> str:
"""
Fast way to generate code
This method takes in a prompt, maximum generation size, optional placeholders, sacrifice list, and debug flag.
It reshapes the context before performing text generation by adjusting and cropping the number of tokens.
Parameters:
- prompt (str): The input prompt for text generation.
- max_generation_size (int): The maximum number of tokens to generate.
- placeholders (dict, optional): A dictionary of placeholders to be replaced in the prompt. Defaults to an empty dictionary.
- sacrifice (list, optional): A list of placeholders to sacrifice if the window is bigger than the context size minus the number of tokens to generate. Defaults to ["previous_discussion"].
- debug (bool, optional): Flag to enable/disable debug mode. Defaults to False.
Returns:
- str: The generated text after removing special tokens ("<s>" and "</s>") and stripping any leading/trailing whitespace.
"""
if max_generation_size is None:
prompt_size = self.model.tokenize(prompt)
max_generation_size = self.model.config.ctx_size - len(prompt_size)
pr = PromptReshaper(prompt)
prompt = pr.build(placeholders,
self.model.tokenize,
self.model.detokenize,
self.model.config.ctx_size - max_generation_size,
sacrifice
)
if debug:
self.print_prompt("prompt", prompt)
return self.generate(prompt, max_generation_size).strip().replace("</s>", "").replace("<s>", "")
def remove_text_from_string(self, string, text_to_find):
"""
Removes everything from the first occurrence of the specified text in the string (case-insensitive).
Parameters:
string (str): The original string.
text_to_find (str): The text to find in the string.
Returns:
str: The updated string.
"""
index = string.lower().find(text_to_find.lower())
if index != -1:
string = string[:index]
return string
def process(self, text:str, message_type:MSG_TYPE, callback=None):
if text is None:
return True
bot_says = self.bot_says + text
antiprompt = self.detect_antiprompt(bot_says)
if antiprompt:
self.bot_says = self.remove_text_from_string(bot_says,antiprompt)
ASCIIColors.warning(f"\nDetected hallucination with antiprompt: {antiprompt}")
return False
else:
if callback:
callback(text,MSG_TYPE.MSG_TYPE_CHUNK)
self.bot_says = bot_says
return True
def generate(self, prompt, max_size, temperature = None, top_k = None, top_p=None, repeat_penalty=None, callback=None ):
ASCIIColors.info("Text generation started: Warming up")
self.bot_says = ""
self.model.generate(
prompt,
max_size,
partial(self.process, callback=callback),
temperature=self.model_temperature if temperature is None else temperature,
top_k=self.model_top_k if top_k is None else top_k,
top_p=self.model_top_p if top_p is None else top_p,
repeat_penalty=self.model_repeat_penalty if repeat_penalty is None else repeat_penalty,
).strip()
return self.bot_says
def setCallback(self, callback: Callable[[str, MSG_TYPE, dict, list], bool]):
self.callback = callback
if self._processor:
@ -1242,53 +1326,11 @@ class APScript(StateMachine):
yaml.dump(data, file)
def remove_text_from_string(self, string, text_to_find):
"""
Removes everything from the first occurrence of the specified text in the string (case-insensitive).
Parameters:
string (str): The original string.
text_to_find (str): The text to find in the string.
Returns:
str: The updated string.
"""
index = string.lower().find(text_to_find.lower())
if index != -1:
string = string[:index]
return string
def process(self, text:str, message_type:MSG_TYPE, callback=None):
if text is None:
return True
bot_says = self.bot_says + text
antiprompt = self.personality.detect_antiprompt(bot_says)
if antiprompt:
self.bot_says = self.remove_text_from_string(bot_says,antiprompt)
ASCIIColors.warning(f"\nDetected hallucination with antiprompt: {antiprompt}")
return False
else:
if callback:
callback(text,MSG_TYPE.MSG_TYPE_CHUNK)
self.bot_says = bot_says
return True
def generate(self, prompt, max_size, temperature = None, top_k = None, top_p=None, repeat_penalty=None, callback=None ):
self.bot_says = ""
ASCIIColors.info("Text generation started: Warming up")
self.personality.model.generate(
prompt,
max_size,
partial(self.process, callback=callback),
temperature=self.personality.model_temperature if temperature is None else temperature,
top_k=self.personality.model_top_k if top_k is None else top_k,
top_p=self.personality.model_top_p if top_p is None else top_p,
repeat_penalty=self.personality.model_repeat_penalty if repeat_penalty is None else repeat_penalty,
).strip()
return self.bot_says
self.personality.generate(prompt, max_size, temperature, top_k, top_p, repeat_penalty, callback)
def run_workflow(self, prompt:str, previous_discussion_text:str="", callback: Callable[[str, MSG_TYPE, dict, list], bool]=None):
"""
@ -1656,7 +1698,7 @@ Yes or No?
ASCIIColors.yellow(prompt)
ASCIIColors.red(" *-*-*-*-*-*-*-*")
def fast_gen(self, prompt: str, max_generation_size: int, placeholders: dict = {}, sacrifice: list = ["previous_discussion"], debug: bool = False) -> str:
def fast_gen(self, prompt: str, max_generation_size: int= None, placeholders: dict = {}, sacrifice: list = ["previous_discussion"], debug: bool = False) -> str:
"""
Fast way to generate code
@ -1673,18 +1715,7 @@ Yes or No?
Returns:
- str: The generated text after removing special tokens ("<s>" and "</s>") and stripping any leading/trailing whitespace.
"""
pr = PromptReshaper(prompt)
prompt = pr.build(placeholders,
self.personality.model.tokenize,
self.personality.model.detokenize,
self.personality.model.config.ctx_size - max_generation_size,
sacrifice
)
if debug:
self.print_prompt("prompt", prompt)
return self.generate(prompt, max_generation_size).strip().replace("</s>", "").replace("<s>", "")
return self.personality.fast_gen(prompt=prompt,max_generation_size=max_generation_size,placeholders=placeholders, sacrifice=sacrifice, debug=debug)
#Helper method to convert outputs path to url