feat(transformers): add support to OuteTTS (#4622)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
parent a761e01944
commit ee7904f170
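In summary, this commit wires the OuteTTS text-to-speech library into the transformers backend: LoadModel gains an "OuteTTS" model type that builds an outetts.InterfaceHF (with the tokenizer, interface version, and speaker configurable through model options, and optional voice cloning from request.AudioPath), a dedicated OuteTTS handler generates audio and writes it to request.dst, the legacy TTS endpoint delegates to it when an OuteTTS model is loaded, two hardcoded token limits are replaced with max_tokens-aware logic, and the per-accelerator requirements files gain outetts plus llvmlite, numba, and numpy pins.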
@@ -24,7 +24,7 @@ XPU=os.environ.get("XPU", "0") == "1"

 from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
 from transformers import AutoProcessor, MusicgenForConditionalGeneration
 from scipy.io import wavfile
+import outetts

 _ONE_DAY_IN_SECONDS = 60 * 60 * 24

@@ -87,6 +87,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):

         self.CUDA = torch.cuda.is_available()
         self.OV=False
+        self.OuteTTS=False

         device_map="cpu"

@@ -195,7 +196,45 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):

             self.OV = True
         elif request.Type == "MusicgenForConditionalGeneration":
             self.processor = AutoProcessor.from_pretrained(model_name)
             self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
+        elif request.Type == "OuteTTS":
+            options = request.Options
+            MODELNAME = "OuteAI/OuteTTS-0.3-1B"
+            TOKENIZER = "OuteAI/OuteTTS-0.3-1B"
+            VERSION = "0.3"
+            SPEAKER = "en_male_1"
+            for opt in options:
+                if opt.startswith("tokenizer:"):
+                    TOKENIZER = opt.split(":")[1]
+                    break
+                if opt.startswith("version:"):
+                    VERSION = opt.split(":")[1]
+                    break
+                if opt.startswith("speaker:"):
+                    SPEAKER = opt.split(":")[1]
+                    break
+
+            if model_name != "":
+                MODELNAME = model_name
+
+            # Configure the model
+            model_config = outetts.HFModelConfig_v2(
+                model_path=MODELNAME,
+                tokenizer_path=TOKENIZER
+            )
+            # Initialize the interface
+            self.interface = outetts.InterfaceHF(model_version=VERSION, cfg=model_config)
+            self.OuteTTS = True
+
+            self.interface.print_default_speakers()
+            if request.AudioPath:
+                if os.path.isabs(request.AudioPath):
+                    self.AudioPath = request.AudioPath
+                else:
+                    self.AudioPath = os.path.join(request.ModelPath, request.AudioPath)
+                self.speaker = self.interface.create_speaker(audio_path=self.AudioPath)
+            else:
+                self.speaker = self.interface.load_default_speaker(name=SPEAKER)
         else:
             print("Automodel", file=sys.stderr)
             self.model = AutoModel.from_pretrained(model_name,

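One thing to note about the option loop above: each `break` exits the whole loop after the first matching prefix, so at most one of tokenizer:, version:, and speaker: is ever honored, and `opt.split(":")[1]` drops anything after a second colon. A more permissive parse, as a hypothetical sketch (not part of the commit), could look like this:

# Hypothetical, more permissive option parsing (not in the commit):
# honors every option and preserves values that contain ':'.
options = ["tokenizer:OuteAI/OuteTTS-0.3-1B", "speaker:en_male_2"]  # example input
TOKENIZER, VERSION, SPEAKER = "OuteAI/OuteTTS-0.3-1B", "0.3", "en_male_1"

for opt in options:
    if ":" not in opt:
        continue
    key, value = opt.split(":", 1)   # split only on the first ':'
    if key == "tokenizer":
        TOKENIZER = value
    elif key == "version":
        VERSION = value
    elif key == "speaker":
        SPEAKER = value
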
@@ -206,7 +245,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):

                 torch_dtype=compute)
             if request.ContextSize > 0:
                 self.max_tokens = request.ContextSize
-            elif request.Type != "MusicgenForConditionalGeneration":
+            elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
                 self.max_tokens = self.model.config.max_position_embeddings
             else:
                 self.max_tokens = 512

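The guard change is the substantive fix here: previously any non-Musicgen model was assumed to expose max_position_embeddings, which crashes for models whose config lacks it; now the attribute is only consulted when it exists. An equivalent compact fragment, as an illustrative sketch in the same method context as the diff:

# Illustrative equivalent of the new fallback chain (not in the commit):
config = getattr(self.model, "config", None)
self.max_tokens = (
    request.ContextSize
    if request.ContextSize > 0
    else getattr(config, "max_position_embeddings", 512)
)
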
@@ -445,9 +484,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):

             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         return backend_pb2.Result(success=True)

+    def OuteTTS(self, request, context):
+        try:
+            print("[OuteTTS] generating TTS", file=sys.stderr)
+            gen_cfg = outetts.GenerationConfig(
+                text="Speech synthesis is the artificial production of human speech.",
+                temperature=0.1,
+                repetition_penalty=1.1,
+                max_length=self.max_tokens,
+                speaker=self.speaker,
+                # voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"
+            )
+            output = self.interface.generate(config=gen_cfg)
+            print("[OuteTTS] Generated TTS", file=sys.stderr)
+            output.save(request.dst)
+            print("[OuteTTS] TTS done", file=sys.stderr)
+        except Exception as err:
+            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+        return backend_pb2.Result(success=True)
+
+    # The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
     def TTS(self, request, context):
+        if self.OuteTTS:
+            return self.OuteTTS(request, context)
+
         model_name = request.model
         try:
             if self.processor is None:

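Two details worth flagging in this hunk. First, LoadModel sets the instance attribute `self.OuteTTS = True`, which shadows the OuteTTS method of the same name, so `return self.OuteTTS(request, context)` would raise TypeError: 'bool' object is not callable. Second, GenerationConfig is built with a hardcoded `text=` string rather than the text carried by the request, so every call would speak the same sentence. A minimal demonstration of the shadowing issue plus a hypothetical fix (class, method stubs, and flag names here are illustrative, not from the commit):

# Minimal demonstration of attribute-shadows-method, with a fix:
class Servicer:
    def __init__(self):
        self.is_outetts = False   # non-colliding flag (the commit uses self.OuteTTS)

    def load_outetts(self):
        # the commit's `self.OuteTTS = True` would shadow the OuteTTS method:
        # a later self.OuteTTS(request, context) raises
        # TypeError: 'bool' object is not callable
        self.is_outetts = True

    def OuteTTS(self, request, context):
        return f"synthesized: {request}"  # stand-in for the real handler

    def TTS(self, request, context):
        if self.is_outetts:
            return self.OuteTTS(request, context)  # resolves to the bound method
        return "musicgen path"

s = Servicer()
s.load_outetts()
print(s.TTS("hello", None))  # -> synthesized: hello

For the second issue, passing `text=request.text` (assuming the TTS request message carries the text in a `text` field, as the gRPC TTS endpoints elsewhere do) would make the handler synthesize the requested string.
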
@@ -463,7 +523,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):

                 padding=True,
                 return_tensors="pt",
             )
-            tokens = 512 # No good place to set the "length" in TTS, so use 10s as a sane default
+            tokens = self.max_tokens # No good place to set the "length" in TTS, so use 10s as a sane default
             audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
             print("[transformers-musicgen] TTS generated!", file=sys.stderr)
             sampling_rate = self.model.config.audio_encoder.sampling_rate

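Taken together, the backend is a thin wrapper over the outetts API. As a minimal standalone sketch of the same call sequence (defaults taken from the diff; the input text, token budget, and output path are illustrative):

import outetts

# Configure the model and tokenizer (defaults from the diff).
cfg = outetts.HFModelConfig_v2(
    model_path="OuteAI/OuteTTS-0.3-1B",
    tokenizer_path="OuteAI/OuteTTS-0.3-1B",
)
# Initialize the HF interface for interface version 0.3.
interface = outetts.InterfaceHF(model_version="0.3", cfg=cfg)

# Pick the default speaker the backend falls back to.
speaker = interface.load_default_speaker(name="en_male_1")

# Generate speech and save it as a WAV file.
output = interface.generate(config=outetts.GenerationConfig(
    text="Hello from OuteTTS.",  # illustrative input
    temperature=0.1,
    repetition_penalty=1.1,
    max_length=4096,             # stands in for self.max_tokens
    speaker=speaker,
))
output.save("output.wav")        # illustrative destination
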
@@ -1,4 +1,6 @@

 torch==2.4.1
+llvmlite==0.43.0
 accelerate
 transformers
 bitsandbytes
+outetts

@@ -1,5 +1,7 @@

 --extra-index-url https://download.pytorch.org/whl/cu118
 torch==2.4.1+cu118
+llvmlite==0.43.0
 accelerate
 transformers
 bitsandbytes
+outetts

@@ -1,4 +1,6 @@

 torch==2.4.1
 accelerate
+llvmlite==0.43.0
 transformers
 bitsandbytes
+outetts

@@ -2,4 +2,6 @@

 torch==2.4.1+rocm6.0
 accelerate
 transformers
+llvmlite==0.43.0
 bitsandbytes
+outetts

@@ -3,5 +3,7 @@ intel-extension-for-pytorch==2.3.110+xpu

 torch==2.3.1+cxx11.abi
 oneccl_bind_pt==2.3.100+xpu
 optimum[openvino]
+llvmlite==0.43.0
 intel-extension-for-transformers
 bitsandbytes
+outetts

@@ -2,4 +2,6 @@ grpcio==1.69.0

 protobuf
 certifi
 setuptools
 scipy==1.14.0
+numpy>=2.0.0
+numba==0.60.0