diff --git a/lollms/utilities.py b/lollms/utilities.py index 05cc077..a781c4d 100644 --- a/lollms/utilities.py +++ b/lollms/utilities.py @@ -18,7 +18,44 @@ import subprocess import gc from typing import List +# Prompting tools +def detect_antiprompt(text:str, anti_prompts=["!@>"]) -> bool: + """ + Detects if any of the antiprompts in self.anti_prompts are present in the given text. + Used for the Hallucination suppression system + Args: + text (str): The text to check for antiprompts. + + Returns: + bool: True if any antiprompt is found in the text (ignoring case), False otherwise. + """ + for prompt in anti_prompts: + if prompt.lower() in text.lower(): + return prompt.lower() + return None + + +def remove_text_from_string(string, text_to_find): + """ + Removes everything from the first occurrence of the specified text in the string (case-insensitive). + + Parameters: + string (str): The original string. + text_to_find (str): The text to find in the string. + + Returns: + str: The updated string. + """ + index = string.lower().find(text_to_find.lower()) + + if index != -1: + string = string[:index] + + return string + + +# Pytorch and cuda tools def check_torch_version(min_version): import torch # Extract torch version from __version__ attribute with regular expression