This commit is contained in:
Saifeddine ALOUI 2024-05-20 18:20:26 +02:00
parent bb9bb2fa92
commit eb32509bcc
3 changed files with 18 additions and 5 deletions

View File

@ -53,7 +53,7 @@ class LollmsXTTS(LollmsTTS):
use_streaming_mode = True use_streaming_mode = True
): ):
super().__init__(app) super().__init__(app)
self.generation_threads = [] self.generation_threads = {}
self.voices_folder = voices_folder self.voices_folder = voices_folder
self.ready = False self.ready = False
if xtts_base_url=="" or xtts_base_url=="http://127.0.0.1:8020": if xtts_base_url=="" or xtts_base_url=="http://127.0.0.1:8020":
@ -260,7 +260,7 @@ class LollmsXTTS(LollmsTTS):
return False return False
def tts_file(self, text, file_name_or_path, speaker=None, language="en")->str: def tts_file(self, text, file_name_or_path, speaker=None, language="en")->str:
url = f"{self.xtts_base_url}/tts_file" url = f"{self.xtts_base_url}/tts_to_file"
# Define the request body # Define the request body
payload = { payload = {
@ -315,7 +315,7 @@ class LollmsXTTS(LollmsTTS):
trace_exception(ex) trace_exception(ex)
return {"status":False,"error":f"{ex}"} return {"status":False,"error":f"{ex}"}
def xtts_audio(self, text, speaker, file_name_or_path:Path|str=None, language="en", use_threading=False): def xtts_audio(self, text, speaker, file_name_or_path:Path|str=None, language="en", use_threading=True):
# Remove HTML tags # Remove HTML tags
text = re.sub(r'<.*?>', '', text) text = re.sub(r'<.*?>', '', text)
# Remove code blocks (assuming they're enclosed in backticks or similar markers) # Remove code blocks (assuming they're enclosed in backticks or similar markers)
@ -325,7 +325,7 @@ class LollmsXTTS(LollmsTTS):
text = re.sub(r'[\{\}\[\]\(\)<>]', '', text) text = re.sub(r'[\{\}\[\]\(\)<>]', '', text)
text = text.replace("\\","") text = text.replace("\\","")
def tts2_audio_th(thread_uid=None): def tts2_audio_th(thread_uid=None):
url = f"{self.xtts_base_url}/tts_audio" url = f"{self.xtts_base_url}/tts_to_audio"
# Define the request body # Define the request body
payload = { payload = {
@ -362,7 +362,7 @@ class LollmsXTTS(LollmsTTS):
self.generation_threads.pop(thread_uid, None) self.generation_threads.pop(thread_uid, None)
if use_threading: if use_threading:
thread_uid = str(uuid.uuid4()) thread_uid = str(uuid.uuid4())
thread = threading.Thread(target=tts2_audio_th, args=(thread_uid)) thread = threading.Thread(target=tts2_audio_th, args=(thread_uid,))
self.generation_threads[thread_uid]=thread self.generation_threads[thread_uid]=thread
thread.start() thread.start()
ASCIIColors.green("Generation started") ASCIIColors.green("Generation started")

View File

@ -72,6 +72,13 @@ class LollmsSTT:
""" """
pass pass
def stop(self):
"""
Stops the current generation
"""
pass
def get_models(self): def get_models(self):
return self.models return self.models
@ -114,6 +121,7 @@ class LollmsSTT:
""" """
return LollmsSTT return LollmsSTT
def get_devices(self): def get_devices(self):
devices = sd.query_devices() devices = sd.query_devices()
print(devices) print(devices)

View File

@ -87,6 +87,11 @@ class LollmsTTS:
""" """
pass pass
def stop(self):
"""
Stops the current generation
"""
pass
@staticmethod @staticmethod
def verify(app: LollmsApplication) -> bool: def verify(app: LollmsApplication) -> bool: