mirror of https://github.com/mudler/LocalAI.git

feat(transformers): add support to OuteTTS (#4622)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

parent a761e01944
commit ee7904f170
@@ -24,7 +24,7 @@ XPU=os.environ.get("XPU", "0") == "1"
 from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
 from transformers import AutoProcessor, MusicgenForConditionalGeneration
 from scipy.io import wavfile
+import outetts

 _ONE_DAY_IN_SECONDS = 60 * 60 * 24

@@ -87,6 +87,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):

         self.CUDA = torch.cuda.is_available()
         self.OV=False
+        self.OuteTTS=False

         device_map="cpu"

@@ -196,6 +197,44 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         elif request.Type == "MusicgenForConditionalGeneration":
             self.processor = AutoProcessor.from_pretrained(model_name)
             self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
+        elif request.Type == "OuteTTS":
+            options = request.Options
+            MODELNAME = "OuteAI/OuteTTS-0.3-1B"
+            TOKENIZER = "OuteAI/OuteTTS-0.3-1B"
+            VERSION = "0.3"
+            SPEAKER = "en_male_1"
+            for opt in options:
+                if opt.startswith("tokenizer:"):
+                    TOKENIZER = opt.split(":")[1]
+                    break
+                if opt.startswith("version:"):
+                    VERSION = opt.split(":")[1]
+                    break
+                if opt.startswith("speaker:"):
+                    SPEAKER = opt.split(":")[1]
+                    break
+
+            if model_name != "":
+                MODELNAME = model_name
+
+            # Configure the model
+            model_config = outetts.HFModelConfig_v2(
+                model_path=MODELNAME,
+                tokenizer_path=TOKENIZER
+            )
+            # Initialize the interface
+            self.interface = outetts.InterfaceHF(model_version=VERSION, cfg=model_config)
+            self.OuteTTS = True
+
+            self.interface.print_default_speakers()
+            if request.AudioPath:
+                if os.path.isabs(request.AudioPath):
+                    self.AudioPath = request.AudioPath
+                else:
+                    self.AudioPath = os.path.join(request.ModelPath, request.AudioPath)
+                self.speaker = self.interface.create_speaker(audio_path=self.AudioPath)
+            else:
+                self.speaker = self.interface.load_default_speaker(name=SPEAKER)
         else:
             print("Automodel", file=sys.stderr)
             self.model = AutoModel.from_pretrained(model_name,
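The option handling above is easier to see in isolation. Below is a minimal standalone sketch of the same parsing (the sample Options list at the bottom is hypothetical). It also surfaces two quirks carried over from the hunk: each branch ends with a break, so only the first matching option in the list takes effect, and split(":")[1] keeps only the first colon-separated field, truncating any value that itself contains a colon.

def parse_outetts_options(options):
    # Defaults mirror the hunk above.
    tokenizer = "OuteAI/OuteTTS-0.3-1B"
    version = "0.3"
    speaker = "en_male_1"
    for opt in options:
        if opt.startswith("tokenizer:"):
            tokenizer = opt.split(":")[1]
            break
        if opt.startswith("version:"):
            version = opt.split(":")[1]
            break
        if opt.startswith("speaker:"):
            speaker = opt.split(":")[1]
            break
    return tokenizer, version, speaker

# Hypothetical gRPC Options payload:
print(parse_outetts_options(["speaker:en_female_1"]))
# -> ('OuteAI/OuteTTS-0.3-1B', '0.3', 'en_female_1')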
@@ -206,7 +245,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                                                    torch_dtype=compute)
             if request.ContextSize > 0:
                 self.max_tokens = request.ContextSize
-            elif request.Type != "MusicgenForConditionalGeneration":
+            elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
                 self.max_tokens = self.model.config.max_position_embeddings
             else:
                 self.max_tokens = 512
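This hunk turns the max_tokens fallback from a type check into a capability probe: any model whose HF config exposes max_position_embeddings uses it, and anything without that attribute drops to the 512 default. A small sketch with stand-in objects, assuming nothing beyond the logic shown above:

from types import SimpleNamespace

def resolve_max_tokens(request, model):
    # Same chain as the hunk: explicit ContextSize wins, then the
    # model config's max_position_embeddings, then a 512 default.
    if request.ContextSize > 0:
        return request.ContextSize
    elif hasattr(model, 'config') and hasattr(model.config, 'max_position_embeddings'):
        return model.config.max_position_embeddings
    return 512

request = SimpleNamespace(ContextSize=0)
llm = SimpleNamespace(config=SimpleNamespace(max_position_embeddings=2048))
audio_model = SimpleNamespace()  # no HF-style config at all

print(resolve_max_tokens(request, llm))          # 2048
print(resolve_max_tokens(request, audio_model))  # 512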
@@ -445,9 +484,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         return backend_pb2.Result(success=True)

+    def OuteTTS(self, request, context):
+        try:
+            print("[OuteTTS] generating TTS", file=sys.stderr)
+            gen_cfg = outetts.GenerationConfig(
+                text="Speech synthesis is the artificial production of human speech.",
+                temperature=0.1,
+                repetition_penalty=1.1,
+                max_length=self.max_tokens,
+                speaker=self.speaker,
+                # voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"
+            )
+            output = self.interface.generate(config=gen_cfg)
+            print("[OuteTTS] Generated TTS", file=sys.stderr)
+            output.save(request.dst)
+            print("[OuteTTS] TTS done", file=sys.stderr)
+        except Exception as err:
+            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+        return backend_pb2.Result(success=True)
+
     # The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
     def TTS(self, request, context):
+        if self.OuteTTS:
+            return self.OuteTTS(request, context)
+
         model_name = request.model
         try:
             if self.processor is None:
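For reference, the new code path strung together as a standalone script with the gRPC plumbing removed. Every call below appears in the hunks above; only the literal max_length value and the output filename are placeholders. As a reading aid, note that LoadModel sets the boolean flag self.OuteTTS under the same name as the OuteTTS method above, so once a model is loaded the instance attribute shadows the bound method.

import outetts

# Same configuration the backend builds, using LoadModel's defaults.
model_config = outetts.HFModelConfig_v2(
    model_path="OuteAI/OuteTTS-0.3-1B",
    tokenizer_path="OuteAI/OuteTTS-0.3-1B",
)
interface = outetts.InterfaceHF(model_version="0.3", cfg=model_config)
speaker = interface.load_default_speaker(name="en_male_1")

output = interface.generate(
    config=outetts.GenerationConfig(
        text="Speech synthesis is the artificial production of human speech.",
        temperature=0.1,
        repetition_penalty=1.1,
        max_length=4096,  # placeholder; the backend passes self.max_tokens
        speaker=speaker,
    )
)
output.save("output.wav")  # placeholder path; the backend saves to request.dst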
@@ -463,7 +523,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 padding=True,
                 return_tensors="pt",
             )
-            tokens = 512 # No good place to set the "length" in TTS, so use 10s as a sane default
+            tokens = self.max_tokens # No good place to set the "length" in TTS, so use 10s as a sane default
             audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
             print("[transformers-musicgen] TTS generated!", file=sys.stderr)
             sampling_rate = self.model.config.audio_encoder.sampling_rate
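The "10s" in the retained comment follows from MusicGen's output rate: its EnCodec codec produces roughly 50 audio tokens per second, so the old hardcoded 512 tokens is about ten seconds, and tying tokens to self.max_tokens makes clip length follow ContextSize instead. A back-of-the-envelope helper (the 50 tokens/s figure is MusicGen's, not something this diff sets):

TOKENS_PER_SECOND = 50  # approximate MusicGen/EnCodec frame rate

def approx_clip_seconds(max_new_tokens):
    return max_new_tokens / TOKENS_PER_SECOND

print(approx_clip_seconds(512))   # ~10.2s, the old default
print(approx_clip_seconds(1536))  # ~30.7s, e.g. with ContextSize=1536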
@@ -1,4 +1,6 @@
 torch==2.4.1
+llvmlite==0.43.0
 accelerate
 transformers
 bitsandbytes
+outetts
@@ -1,5 +1,7 @@
 --extra-index-url https://download.pytorch.org/whl/cu118
 torch==2.4.1+cu118
+llvmlite==0.43.0
 accelerate
 transformers
 bitsandbytes
+outetts
@@ -1,4 +1,6 @@
 torch==2.4.1
 accelerate
+llvmlite==0.43.0
 transformers
 bitsandbytes
+outetts
@@ -2,4 +2,6 @@
 torch==2.4.1+rocm6.0
 accelerate
 transformers
+llvmlite==0.43.0
 bitsandbytes
+outetts
@@ -3,5 +3,7 @@ intel-extension-for-pytorch==2.3.110+xpu
 torch==2.3.1+cxx11.abi
 oneccl_bind_pt==2.3.100+xpu
 optimum[openvino]
+llvmlite==0.43.0
 intel-extension-for-transformers
 bitsandbytes
+outetts
@@ -3,3 +3,5 @@ protobuf
 certifi
 setuptools
 scipy==1.14.0
+numpy>=2.0.0
+numba==0.60.0
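After installing any of the requirement sets above, the new pins can be sanity-checked from Python. A convenience snippet, not part of the commit (llvmlite and numba are presumably pinned for OuteTTS's audio dependencies):

from importlib.metadata import version, PackageNotFoundError

for pkg in ("outetts", "llvmlite", "numba", "numpy"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")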