feat(transformers): add support for OuteTTS (#4622)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Ettore Di Giacinto 2025-01-17 19:33:25 +01:00 committed by GitHub
parent a761e01944
commit ee7904f170
7 changed files with 82 additions and 10 deletions
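The new branch in the Python backend is selected when a load request arrives with Type set to "OuteTTS". Tokenizer, interface version, and speaker are passed as plain "key:value" entries in the request's Options list, and anything not supplied falls back to the defaults hard-coded in the diff; if the request also carries an AudioPath, a speaker profile is cloned from that reference audio instead of loading a named default speaker. A purely illustrative Options payload (the values shown are the diff's own defaults, not required settings):

    # Hypothetical Options values for Type == "OuteTTS"; keys other than
    # tokenizer/version/speaker are ignored by the parsing loop below.
    options = [
        "tokenizer:OuteAI/OuteTTS-0.3-1B",  # tokenizer repo (default: OuteAI/OuteTTS-0.3-1B)
        "version:0.3",                      # outetts interface version
        "speaker:en_male_1",                # built-in speaker, used when no AudioPath is given
    ]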

View File

@@ -24,7 +24,7 @@ XPU=os.environ.get("XPU", "0") == "1"
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from scipy.io import wavfile
import outetts
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -87,6 +87,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.CUDA = torch.cuda.is_available()
self.OV=False
self.OuteTTS=False
device_map="cpu"
@@ -195,7 +196,45 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.OV = True
elif request.Type == "MusicgenForConditionalGeneration":
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
elif request.Type == "OuteTTS":
options = request.Options
MODELNAME = "OuteAI/OuteTTS-0.3-1B"
TOKENIZER = "OuteAI/OuteTTS-0.3-1B"
VERSION = "0.3"
SPEAKER = "en_male_1"
for opt in options:
    # Each option is a "key:value" string; split only on the first colon so
    # values containing ':' (e.g. model paths) survive intact. No break here,
    # otherwise only the first matching option would ever take effect.
    if opt.startswith("tokenizer:"):
        TOKENIZER = opt.split(":", 1)[1]
    elif opt.startswith("version:"):
        VERSION = opt.split(":", 1)[1]
    elif opt.startswith("speaker:"):
        SPEAKER = opt.split(":", 1)[1]
if model_name != "":
MODELNAME = model_name
# Configure the model
model_config = outetts.HFModelConfig_v2(
model_path=MODELNAME,
tokenizer_path=TOKENIZER
)
# Initialize the interface
self.interface = outetts.InterfaceHF(model_version=VERSION, cfg=model_config)
self.OuteTTS = True
self.interface.print_default_speakers()
if request.AudioPath:
if os.path.isabs(request.AudioPath):
self.AudioPath = request.AudioPath
else:
self.AudioPath = os.path.join(request.ModelPath, request.AudioPath)
self.speaker = self.interface.create_speaker(audio_path=self.AudioPath)
else:
self.speaker = self.interface.load_default_speaker(name=SPEAKER)
else:
print("Automodel", file=sys.stderr)
self.model = AutoModel.from_pretrained(model_name,
@@ -206,7 +245,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
torch_dtype=compute)
if request.ContextSize > 0:
self.max_tokens = request.ContextSize
elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
self.max_tokens = self.model.config.max_position_embeddings
else:
self.max_tokens = 512
@@ -445,9 +484,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(success=True)
# Internal helper used by the TTS endpoint below; it gets a distinct name so it
# is not shadowed by the self.OuteTTS boolean flag set during model load.
def _outetts_tts(self, request, context):
try:
print("[OuteTTS] generating TTS", file=sys.stderr)
gen_cfg = outetts.GenerationConfig(
text="Speech synthesis is the artificial production of human speech.",
temperature=0.1,
repetition_penalty=1.1,
max_length=self.max_tokens,
speaker=self.speaker,
# voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"
)
output = self.interface.generate(config=gen_cfg)
print("[OuteTTS] Generated TTS", file=sys.stderr)
output.save(request.dst)
print("[OuteTTS] TTS done", file=sys.stderr)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(success=True)
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
def TTS(self, request, context):
if self.OuteTTS:
return self._outetts_tts(request, context)
model_name = request.model
try:
if self.processor is None:
@@ -463,7 +523,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
padding=True,
return_tensors="pt",
)
tokens = self.max_tokens # No good place to set the "length" in TTS, so fall back to the model's max token budget
audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
print("[transformers-musicgen] TTS generated!", file=sys.stderr)
sampling_rate = self.model.config.audio_encoder.sampling_rate
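For reference, the outetts calls that the new branch chains together can be exercised on their own. The sketch below simply mirrors the load and generation flow above with the defaults hard-coded in the diff, assuming outetts 0.3.x as pulled in by the requirements changes below; the sample text, reference audio path, output path, and the 4096 token budget (standing in for self.max_tokens) are illustrative values, not part of the commit.

    import outetts

    # Model + tokenizer config for the Hugging Face-backed interface, as in the load path above.
    model_config = outetts.HFModelConfig_v2(
        model_path="OuteAI/OuteTTS-0.3-1B",
        tokenizer_path="OuteAI/OuteTTS-0.3-1B",
    )
    interface = outetts.InterfaceHF(model_version="0.3", cfg=model_config)

    # Either a built-in speaker profile, or one cloned from a reference recording.
    speaker = interface.load_default_speaker(name="en_male_1")
    # speaker = interface.create_speaker(audio_path="reference.wav")  # voice-cloning path

    # Same GenerationConfig fields the TTS handler fills in.
    gen_cfg = outetts.GenerationConfig(
        text="Speech synthesis is the artificial production of human speech.",
        temperature=0.1,
        repetition_penalty=1.1,
        max_length=4096,   # the backend passes self.max_tokens here
        speaker=speaker,
    )
    output = interface.generate(config=gen_cfg)
    output.save("output.wav")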

View File

@@ -1,4 +1,6 @@
torch==2.4.1
llvmlite==0.43.0
accelerate
transformers
bitsandbytes
outetts

View File

@@ -1,5 +1,7 @@
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118
llvmlite==0.43.0
accelerate
transformers
bitsandbytes
outetts

View File

@@ -1,4 +1,6 @@
torch==2.4.1
accelerate
llvmlite==0.43.0
transformers
bitsandbytes
outetts

View File

@@ -2,4 +2,6 @@
torch==2.4.1+rocm6.0
accelerate
transformers
llvmlite==0.43.0
bitsandbytes
outetts

View File

@@ -3,5 +3,7 @@ intel-extension-for-pytorch==2.3.110+xpu
torch==2.3.1+cxx11.abi
oneccl_bind_pt==2.3.100+xpu
optimum[openvino]
llvmlite==0.43.0
intel-extension-for-transformers
bitsandbytes
outetts

View File

@@ -2,4 +2,6 @@ grpcio==1.69.0
protobuf
certifi
setuptools
scipy==1.14.0
numpy>=2.0.0
numba==0.60.0