From 2a2aa9f8bbf4975e68900ec7afbf0f43b0df8944 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Thu, 16 Nov 2023 00:48:12 +0100 Subject: [PATCH] added halucination suppression system to utilities --- lollms/utilities.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/lollms/utilities.py b/lollms/utilities.py index 05cc077..a781c4d 100644 --- a/lollms/utilities.py +++ b/lollms/utilities.py @@ -18,7 +18,44 @@ import subprocess import gc from typing import List +# Prompting tools +def detect_antiprompt(text:str, anti_prompts=["!@>"]) -> bool: + """ + Detects if any of the antiprompts in self.anti_prompts are present in the given text. + Used for the Hallucination suppression system + Args: + text (str): The text to check for antiprompts. + + Returns: + bool: True if any antiprompt is found in the text (ignoring case), False otherwise. + """ + for prompt in anti_prompts: + if prompt.lower() in text.lower(): + return prompt.lower() + return None + + +def remove_text_from_string(string, text_to_find): + """ + Removes everything from the first occurrence of the specified text in the string (case-insensitive). + + Parameters: + string (str): The original string. + text_to_find (str): The text to find in the string. + + Returns: + str: The updated string. + """ + index = string.lower().find(text_to_find.lower()) + + if index != -1: + string = string[:index] + + return string + + +# Pytorch and cuda tools def check_torch_version(min_version): import torch # Extract torch version from __version__ attribute with regular expression