diff --git a/lollms/services/xtts/lollms_xtts.py b/lollms/services/xtts/lollms_xtts.py index fc5362b..641e693 100644 --- a/lollms/services/xtts/lollms_xtts.py +++ b/lollms/services/xtts/lollms_xtts.py @@ -53,7 +53,7 @@ class LollmsXTTS(LollmsTTS): use_streaming_mode = True ): super().__init__(app) - self.generation_threads = [] + self.generation_threads = {} self.voices_folder = voices_folder self.ready = False if xtts_base_url=="" or xtts_base_url=="http://127.0.0.1:8020": @@ -260,7 +260,7 @@ class LollmsXTTS(LollmsTTS): return False def tts_file(self, text, file_name_or_path, speaker=None, language="en")->str: - url = f"{self.xtts_base_url}/tts_file" + url = f"{self.xtts_base_url}/tts_to_file" # Define the request body payload = { @@ -315,7 +315,7 @@ class LollmsXTTS(LollmsTTS): trace_exception(ex) return {"status":False,"error":f"{ex}"} - def xtts_audio(self, text, speaker, file_name_or_path:Path|str=None, language="en", use_threading=False): + def xtts_audio(self, text, speaker, file_name_or_path:Path|str=None, language="en", use_threading=True): # Remove HTML tags text = re.sub(r'<.*?>', '', text) # Remove code blocks (assuming they're enclosed in backticks or similar markers) @@ -325,7 +325,7 @@ class LollmsXTTS(LollmsTTS): text = re.sub(r'[\{\}\[\]\(\)<>]', '', text) text = text.replace("\\","") def tts2_audio_th(thread_uid=None): - url = f"{self.xtts_base_url}/tts_audio" + url = f"{self.xtts_base_url}/tts_to_audio" # Define the request body payload = { @@ -362,7 +362,7 @@ class LollmsXTTS(LollmsTTS): self.generation_threads.pop(thread_uid, None) if use_threading: thread_uid = str(uuid.uuid4()) - thread = threading.Thread(target=tts2_audio_th, args=(thread_uid)) + thread = threading.Thread(target=tts2_audio_th, args=(thread_uid,)) self.generation_threads[thread_uid]=thread thread.start() ASCIIColors.green("Generation started") diff --git a/lollms/stt.py b/lollms/stt.py index fa1d063..ce9d266 100644 --- a/lollms/stt.py +++ b/lollms/stt.py @@ -72,6 +72,13 @@ class LollmsSTT: """ pass + def stop(self): + """ + Stops the current generation + """ + pass + + def get_models(self): return self.models @@ -114,6 +121,7 @@ class LollmsSTT: """ return LollmsSTT + def get_devices(self): devices = sd.query_devices() print(devices) diff --git a/lollms/tts.py b/lollms/tts.py index 4a95f91..68650a7 100644 --- a/lollms/tts.py +++ b/lollms/tts.py @@ -87,6 +87,11 @@ class LollmsTTS: """ pass + def stop(self): + """ + Stops the current generation + """ + pass @staticmethod def verify(app: LollmsApplication) -> bool: