mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-19 20:57:54 +00:00
feat: elevenlabs sound-generation
api (#3355)
* initial version of elevenlabs compatible soundgeneration api and cli command Signed-off-by: Dave Lee <dave@gray101.com> * minor cleanup Signed-off-by: Dave Lee <dave@gray101.com> * restore TTS, add test Signed-off-by: Dave Lee <dave@gray101.com> * remove stray s Signed-off-by: Dave Lee <dave@gray101.com> * fix Signed-off-by: Dave Lee <dave@gray101.com> --------- Signed-off-by: Dave Lee <dave@gray101.com> Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
parent
84d6e5a987
commit
81ae92f017
@ -16,6 +16,7 @@ service Backend {
|
|||||||
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
|
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
|
||||||
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
|
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
|
||||||
rpc TTS(TTSRequest) returns (Result) {}
|
rpc TTS(TTSRequest) returns (Result) {}
|
||||||
|
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
|
||||||
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
|
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
|
||||||
rpc Status(HealthMessage) returns (StatusResponse) {}
|
rpc Status(HealthMessage) returns (StatusResponse) {}
|
||||||
|
|
||||||
@ -270,6 +271,17 @@ message TTSRequest {
|
|||||||
optional string language = 5;
|
optional string language = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SoundGenerationRequest carries the parameters of one SoundGeneration RPC.
// The optional fields use explicit presence so a backend can tell "unset"
// apart from a zero value and apply its own default.
message SoundGenerationRequest {
  // Prompt describing the audio to generate.
  string text = 1;
  // Model identifier the backend should use.
  string model = 2;
  // Destination path the backend writes the generated audio file to.
  string dst = 3;
  // Requested length of the generated audio, in seconds.
  optional float duration = 4;
  // Generation temperature; how it is applied is backend-specific.
  optional float temperature = 5;
  // Whether the backend should sample during generation.
  optional bool sample = 6;
  // Path of an audio file to condition the generation on.
  optional string src = 7;
  // When set together with src, only the first 1/src_divisor of the source
  // samples is used.
  optional int32 src_divisor = 8;
}
|
||||||
|
|
||||||
message TokenizationResponse {
|
message TokenizationResponse {
|
||||||
int32 length = 1;
|
int32 length = 1;
|
||||||
repeated int32 tokens = 2;
|
repeated int32 tokens = 2;
|
||||||
|
@ -15,7 +15,7 @@ import backend_pb2_grpc
|
|||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
|
||||||
from scipy.io.wavfile import write as write_wav
|
from scipy.io import wavfile
|
||||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
@ -63,6 +63,61 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
|
|
||||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||||
|
|
||||||
|
def SoundGeneration(self, request, context):
    """
    Generate audio with a MusicGen model and write it to ``request.dst`` as a WAV file.

    Args:
        request: A SoundGenerationRequest. ``model`` is required. ``text`` is the
            prompt (an empty prompt triggers unconditional generation). The
            optional ``src`` names a WAV file to condition on, and
            ``src_divisor`` keeps only the first ``1/src_divisor`` of its
            samples. ``duration`` (seconds), ``temperature`` (passed to the
            model as ``guidance_scale``) and ``sample`` (``do_sample``)
            override the defaults used below.
        context: The gRPC context (unused).

    Returns:
        A backend_pb2.Result; ``success`` is False with an explanatory message
        when the model name is missing, the divisor is invalid, or any
        exception is raised while loading or generating.
    """
    model_name = request.model
    if model_name == "":
        return backend_pb2.Result(success=False, message="request.model is required")
    try:
        # The processor and model are (re)loaded from the request's model name on every call.
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
        inputs = None
        if request.text == "":
            inputs = self.model.get_unconditional_inputs(num_samples=1)
        elif request.HasField('src'):
            # NOTE(security): request.src is a caller-supplied path that is read as-is.
            # TODO: restrict it to an allowed directory before exposing this to untrusted callers.
            sample_rate, wsamples = wavfile.read(request.src)

            if request.HasField('src_divisor'):
                # A zero divisor would raise and a negative one would slice from the end,
                # so reject both with a clear message.
                if request.src_divisor <= 0:
                    return backend_pb2.Result(success=False, message="request.src_divisor must be a positive integer")
                wsamples = wsamples[: len(wsamples) // request.src_divisor]

            inputs = self.processor(
                audio=wsamples,
                sampling_rate=sample_rate,
                text=[request.text],
                padding=True,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=[request.text],
                padding=True,
                return_tensors="pt",
            )

        tokens = 256
        if request.HasField('duration'):
            tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
        guidance = 3.0
        if request.HasField('temperature'):
            guidance = request.temperature
        dosample = True
        if request.HasField('sample'):
            dosample = request.sample
        audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens)
        print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
        sampling_rate = self.model.config.audio_encoder.sampling_rate
        wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
        print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr)
        print("[transformers-musicgen] SoundGeneration for", file=sys.stderr)
        print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr)
        print(request, file=sys.stderr)
    except Exception as err:
        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
    return backend_pb2.Result(success=True)
|
||||||
|
|
||||||
|
|
||||||
|
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
|
||||||
def TTS(self, request, context):
|
def TTS(self, request, context):
|
||||||
model_name = request.model
|
model_name = request.model
|
||||||
if model_name == "":
|
if model_name == "":
|
||||||
@ -75,8 +130,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
padding=True,
|
padding=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
tokens = 256
|
tokens = 512 # No good place to set the "length" in TTS, so use 10s as a sane default
|
||||||
# TODO get tokens from request?
|
|
||||||
audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
|
audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
|
||||||
print("[transformers-musicgen] TTS generated!", file=sys.stderr)
|
print("[transformers-musicgen] TTS generated!", file=sys.stderr)
|
||||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||||
|
@ -63,7 +63,7 @@ class TestBackendServicer(unittest.TestCase):
|
|||||||
|
|
||||||
def test_tts(self):
|
def test_tts(self):
|
||||||
"""
|
"""
|
||||||
This method tests if the embeddings are generated successfully
|
This method tests if TTS is generated successfully
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self.setUp()
|
self.setUp()
|
||||||
@ -77,5 +77,24 @@ class TestBackendServicer(unittest.TestCase):
|
|||||||
except Exception as err:
|
except Exception as err:
|
||||||
print(err)
|
print(err)
|
||||||
self.fail("TTS service failed")
|
self.fail("TTS service failed")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
||||||
|
|
||||||
|
def test_sound_generation(self):
    """
    Smoke-test the SoundGeneration RPC: load the model, send a text prompt,
    and check that the server answers.

    NOTE(review): the request sets only ``text`` (no ``model`` or ``dst``), and
    the assertion checks only that a response object came back, not that
    ``success`` is true.
    """
    try:
        self.setUp()
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            load_response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small"))
            self.assertTrue(load_response.success)
            prompt = "80s TV news production music hit for tonight's biggest story"
            sg_response = stub.SoundGeneration(backend_pb2.SoundGenerationRequest(text=prompt))
            self.assertIsNotNone(sg_response)
    except Exception as err:
        print(err)
        self.fail("SoundGeneration service failed")
|
||||||
finally:
|
finally:
|
||||||
self.tearDown()
|
self.tearDown()
|
@ -87,7 +87,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
|
|||||||
case string:
|
case string:
|
||||||
protoMessages[i].Content = ct
|
protoMessages[i].Content = ct
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unsupported type for schema.Message.Content for inference: %T", ct)
|
return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
74
core/backend/soundgeneration.go
Normal file
74
core/backend/soundgeneration.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SoundGeneration(
|
||||||
|
backend string,
|
||||||
|
modelFile string,
|
||||||
|
text string,
|
||||||
|
duration *float32,
|
||||||
|
temperature *float32,
|
||||||
|
doSample *bool,
|
||||||
|
sourceFile *string,
|
||||||
|
sourceDivisor *int32,
|
||||||
|
loader *model.ModelLoader,
|
||||||
|
appConfig *config.ApplicationConfig,
|
||||||
|
backendConfig config.BackendConfig,
|
||||||
|
) (string, *proto.Result, error) {
|
||||||
|
if backend == "" {
|
||||||
|
return "", nil, fmt.Errorf("backend is a required parameter")
|
||||||
|
}
|
||||||
|
|
||||||
|
grpcOpts := gRPCModelOpts(backendConfig)
|
||||||
|
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
|
||||||
|
model.WithBackendString(backend),
|
||||||
|
model.WithModel(modelFile),
|
||||||
|
model.WithContext(appConfig.Context),
|
||||||
|
model.WithAssetDir(appConfig.AssetsDestination),
|
||||||
|
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||||
|
})
|
||||||
|
|
||||||
|
soundGenModel, err := loader.BackendLoader(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if soundGenModel == nil {
|
||||||
|
return "", nil, fmt.Errorf("could not load sound generation model")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
|
||||||
|
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "sound_generation", ".wav")
|
||||||
|
filePath := filepath.Join(appConfig.AudioDir, fileName)
|
||||||
|
|
||||||
|
res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
|
||||||
|
Text: text,
|
||||||
|
Model: modelFile,
|
||||||
|
Dst: filePath,
|
||||||
|
Sample: doSample,
|
||||||
|
Duration: duration,
|
||||||
|
Temperature: temperature,
|
||||||
|
Src: sourceFile,
|
||||||
|
SrcDivisor: sourceDivisor,
|
||||||
|
})
|
||||||
|
|
||||||
|
// return RPC error if any
|
||||||
|
if !res.Success {
|
||||||
|
return "", nil, fmt.Errorf(res.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filePath, res, err
|
||||||
|
}
|
@ -9,31 +9,15 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
model "github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func generateUniqueFileName(dir, baseName, ext string) string {
|
|
||||||
counter := 1
|
|
||||||
fileName := baseName + ext
|
|
||||||
|
|
||||||
for {
|
|
||||||
filePath := filepath.Join(dir, fileName)
|
|
||||||
_, err := os.Stat(filePath)
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return fileName
|
|
||||||
}
|
|
||||||
|
|
||||||
counter++
|
|
||||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ModelTTS(
|
func ModelTTS(
|
||||||
backend,
|
backend,
|
||||||
text,
|
text,
|
||||||
modelFile,
|
modelFile,
|
||||||
voice ,
|
voice,
|
||||||
language string,
|
language string,
|
||||||
loader *model.ModelLoader,
|
loader *model.ModelLoader,
|
||||||
appConfig *config.ApplicationConfig,
|
appConfig *config.ApplicationConfig,
|
||||||
@ -66,7 +50,7 @@ func ModelTTS(
|
|||||||
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := generateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
|
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
|
||||||
filePath := filepath.Join(appConfig.AudioDir, fileName)
|
filePath := filepath.Join(appConfig.AudioDir, fileName)
|
||||||
|
|
||||||
// If the model file is not empty, we pass it joined with the model path
|
// If the model file is not empty, we pass it joined with the model path
|
||||||
@ -88,10 +72,10 @@ func ModelTTS(
|
|||||||
}
|
}
|
||||||
|
|
||||||
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
|
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
|
||||||
Text: text,
|
Text: text,
|
||||||
Model: modelPath,
|
Model: modelPath,
|
||||||
Voice: voice,
|
Voice: voice,
|
||||||
Dst: filePath,
|
Dst: filePath,
|
||||||
Language: &language,
|
Language: &language,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -8,12 +8,13 @@ import (
|
|||||||
var CLI struct {
|
var CLI struct {
|
||||||
cliContext.Context `embed:""`
|
cliContext.Context `embed:""`
|
||||||
|
|
||||||
Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
|
Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
|
||||||
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
|
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
|
||||||
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
|
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
|
||||||
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
|
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
|
||||||
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
|
SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"`
|
||||||
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
|
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
|
||||||
Util UtilCMD `cmd:"" help:"Utility commands"`
|
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
|
||||||
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
|
Util UtilCMD `cmd:"" help:"Utility commands"`
|
||||||
|
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
|
||||||
}
|
}
|
||||||
|
110
core/cli/soundgeneration.go
Normal file
110
core/cli/soundgeneration.go
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SoundGenerationCMD holds the arguments and flags of the sound-generation
// CLI command. The numeric options (Duration, Temperature,
// InputFileSampleDivisor) are declared as strings so that "not specified"
// can be told apart from a zero value; they are parsed when the command runs.
type SoundGenerationCMD struct {
	// Text is the prompt; multiple positional arguments are joined with spaces.
	Text []string `arg:""`

	// Backend and Model select what generates the audio; both are required.
	Backend string `short:"b" required:"" help:"Backend to run the SoundGeneration model"`
	Model   string `short:"m" required:"" help:"Model name to run the SoundGeneration"`

	// Optional generation parameters.
	Duration               string `short:"d" help:"If specified, the length of audio to generate in seconds"`
	Temperature            string `short:"t" help:"If specified, the temperature of the generation"`
	InputFile              string `short:"i" help:"If specified, the input file to condition generation upon"`
	InputFileSampleDivisor string `short:"f" help:"If InputFile and this divisor is specified, the first portion of the sample file will be used"`
	DoSample               bool   `short:"s" default:"true" help:"Enables sampling from the model. Better quality at the cost of speed. Defaults to enabled."`

	// OutputFile is where the generated wav file ends up, when given.
	OutputFile string `short:"o" type:"path" help:"The path to write the output wav file"`

	// Storage and backend discovery settings.
	ModelsPath           string   `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
	BackendAssetsPath    string   `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
	ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"`
}
|
||||||
|
|
||||||
|
// parseToFloat32Ptr parses input as a 32-bit float and returns a pointer to
// the value. It returns nil when input cannot be parsed, which includes the
// empty string and values outside the float32 range.
func parseToFloat32Ptr(input string) *float32 {
	parsed, err := strconv.ParseFloat(input, 32)
	if err != nil {
		return nil
	}
	result := new(float32)
	*result = float32(parsed)
	return result
}
|
||||||
|
|
||||||
|
// parseToInt32Ptr parses input as a base-10 32-bit integer and returns a
// pointer to the value. It returns nil when input cannot be parsed, which
// includes the empty string and values outside the int32 range.
func parseToInt32Ptr(input string) *int32 {
	parsed, err := strconv.ParseInt(input, 10, 32)
	if err != nil {
		return nil
	}
	result := new(int32)
	*result = int32(parsed)
	return result
}
|
||||||
|
|
||||||
|
func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
|
||||||
|
outputFile := t.OutputFile
|
||||||
|
outputDir := t.BackendAssetsPath
|
||||||
|
if outputFile != "" {
|
||||||
|
outputDir = filepath.Dir(outputFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
text := strings.Join(t.Text, " ")
|
||||||
|
|
||||||
|
externalBackends := make(map[string]string)
|
||||||
|
// split ":" to get backend name and the uri
|
||||||
|
for _, v := range t.ExternalGRPCBackends {
|
||||||
|
backend := v[:strings.IndexByte(v, ':')]
|
||||||
|
uri := v[strings.IndexByte(v, ':')+1:]
|
||||||
|
externalBackends[backend] = uri
|
||||||
|
fmt.Printf("TMP externalBackends[%q]=%q\n\n", backend, uri)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := &config.ApplicationConfig{
|
||||||
|
ModelPath: t.ModelsPath,
|
||||||
|
Context: context.Background(),
|
||||||
|
AudioDir: outputDir,
|
||||||
|
AssetsDestination: t.BackendAssetsPath,
|
||||||
|
ExternalGRPCBackends: externalBackends,
|
||||||
|
}
|
||||||
|
ml := model.NewModelLoader(opts.ModelPath)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := ml.StopAllGRPC()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("unable to stop all grpc processes")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
options := config.BackendConfig{}
|
||||||
|
options.SetDefaults()
|
||||||
|
|
||||||
|
var inputFile *string
|
||||||
|
if t.InputFile != "" {
|
||||||
|
inputFile = &t.InputFile
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath, _, err := backend.SoundGeneration(t.Backend, t.Model, text,
|
||||||
|
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
|
||||||
|
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if outputFile != "" {
|
||||||
|
if err := os.Rename(filePath, outputFile); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf("Generate file %s\n", outputFile)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Generate file %s\n", filePath)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
65
core/http/endpoints/elevenlabs/soundgeneration.go
Normal file
65
core/http/endpoints/elevenlabs/soundgeneration.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package elevenlabs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
||||||
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SoundGenerationEndpoint is the ElevenLabs SoundGeneration endpoint https://elevenlabs.io/docs/api-reference/sound-generation
|
||||||
|
// @Summary Generates audio from the input text.
|
||||||
|
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
|
||||||
|
// @Success 200 {string} binary "Response"
|
||||||
|
// @Router /v1/sound-generation [post]
|
||||||
|
func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
input := new(schema.ElevenLabsSoundGenerationRequest)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.ModelID
|
||||||
|
log.Warn().Str("ModelID", input.ModelID).Msg("Model not found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||||
|
config.LoadOptionDebug(appConfig.Debug),
|
||||||
|
config.LoadOptionThreads(appConfig.Threads),
|
||||||
|
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||||
|
config.LoadOptionF16(appConfig.F16),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.ModelID
|
||||||
|
log.Warn().Str("Request ModelID", input.ModelID).Err(err).Msg("error during LoadBackendConfigFileByName, using request ModelID")
|
||||||
|
} else {
|
||||||
|
if input.ModelID != "" {
|
||||||
|
modelFile = input.ModelID
|
||||||
|
} else {
|
||||||
|
modelFile = cfg.Model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
|
||||||
|
|
||||||
|
if input.Duration != nil {
|
||||||
|
log.Debug().Float32("duration", *input.Duration).Msg("duration set")
|
||||||
|
}
|
||||||
|
if input.Temperature != nil {
|
||||||
|
log.Debug().Float32("temperature", *input.Temperature).Msg("temperature set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Support uploading files?
|
||||||
|
filePath, _, err := backend.SoundGeneration(cfg.Backend, modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Download(filePath)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
@ -16,4 +16,6 @@ func RegisterElevenLabsRoutes(app *fiber.App,
|
|||||||
// Elevenlabs
|
// Elevenlabs
|
||||||
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
|
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
app.Post("/v1/sound-generation", auth, elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -4,3 +4,11 @@ type ElevenLabsTTSRequest struct {
|
|||||||
Text string `json:"text" yaml:"text"`
|
Text string `json:"text" yaml:"text"`
|
||||||
ModelID string `json:"model_id" yaml:"model_id"`
|
ModelID string `json:"model_id" yaml:"model_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ElevenLabsSoundGenerationRequest is the request body of the
// /v1/sound-generation endpoint. The pointer fields are optional: when a key
// is absent from the body the field stays nil.
type ElevenLabsSoundGenerationRequest struct {
	// Text is the prompt describing the audio to generate.
	Text string `json:"text" yaml:"text"`
	// ModelID names the model to use.
	ModelID string `json:"model_id" yaml:"model_id"`
	// Duration is read from the "duration_seconds" key.
	Duration *float32 `json:"duration_seconds,omitempty" yaml:"duration_seconds,omitempty"`
	// Temperature is read from the "prompt_influence" key.
	Temperature *float32 `json:"prompt_influence,omitempty" yaml:"prompt_influence,omitempty"`
	// DoSample is read from the "do_sample" key.
	DoSample *bool `json:"do_sample,omitempty" yaml:"do_sample,omitempty"`
}
|
||||||
|
@ -0,0 +1,23 @@
|
|||||||
|
meta {
|
||||||
|
name: musicgen
|
||||||
|
type: http
|
||||||
|
seq: 1
|
||||||
|
}
|
||||||
|
|
||||||
|
post {
|
||||||
|
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/v1/sound-generation
|
||||||
|
body: json
|
||||||
|
auth: none
|
||||||
|
}
|
||||||
|
|
||||||
|
headers {
|
||||||
|
Content-Type: application/json
|
||||||
|
}
|
||||||
|
|
||||||
|
body:json {
|
||||||
|
{
|
||||||
|
"model_id": "facebook/musicgen-small",
|
||||||
|
"text": "Exciting 80s Newscast Interstitial",
|
||||||
|
"duration_seconds": 8
|
||||||
|
}
|
||||||
|
}
|
3
go.sum
3
go.sum
@ -509,6 +509,9 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
|||||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||||
github.com/onsi/ginkgo/v2 v2.20.0 h1:PE84V2mHqoT1sglvHc8ZdQtPcwmvvt29WLEEO3xmdZw=
|
github.com/onsi/ginkgo/v2 v2.20.0 h1:PE84V2mHqoT1sglvHc8ZdQtPcwmvvt29WLEEO3xmdZw=
|
||||||
github.com/onsi/ginkgo/v2 v2.20.0/go.mod h1:lG9ey2Z29hR41WMVthyJBGUBcBhGOtoPF2VFMvBXFCI=
|
github.com/onsi/ginkgo/v2 v2.20.0/go.mod h1:lG9ey2Z29hR41WMVthyJBGUBcBhGOtoPF2VFMvBXFCI=
|
||||||
|
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||||
|
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||||
|
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
||||||
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
|
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
|
||||||
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
|
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
|
||||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
|
@ -41,6 +41,7 @@ type Backend interface {
|
|||||||
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
|
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
|
||||||
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
|
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error)
|
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error)
|
||||||
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
|
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
|
||||||
Status(ctx context.Context) (*pb.StatusResponse, error)
|
Status(ctx context.Context) (*pb.StatusResponse, error)
|
||||||
|
@ -61,6 +61,10 @@ func (llm *Base) TTS(*pb.TTSRequest) error {
|
|||||||
return fmt.Errorf("unimplemented")
|
return fmt.Errorf("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (llm *Base) SoundGeneration(*pb.SoundGenerationRequest) error {
|
||||||
|
return fmt.Errorf("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (llm *Base) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
|
func (llm *Base) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
|
||||||
return pb.TokenizationResponse{}, fmt.Errorf("unimplemented")
|
return pb.TokenizationResponse{}, fmt.Errorf("unimplemented")
|
||||||
}
|
}
|
||||||
|
@ -210,6 +210,26 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
|
|||||||
return client.TTS(ctx, in, opts...)
|
return client.TTS(ctx, in, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||||
|
if !c.parallel {
|
||||||
|
c.opMutex.Lock()
|
||||||
|
defer c.opMutex.Unlock()
|
||||||
|
}
|
||||||
|
c.setBusy(true)
|
||||||
|
defer c.setBusy(false)
|
||||||
|
if c.wd != nil {
|
||||||
|
c.wd.Mark(c.address)
|
||||||
|
defer c.wd.UnMark(c.address)
|
||||||
|
}
|
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
client := pb.NewBackendClient(conn)
|
||||||
|
return client.SoundGeneration(ctx, in, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
|
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
|
||||||
if !c.parallel {
|
if !c.parallel {
|
||||||
c.opMutex.Lock()
|
c.opMutex.Lock()
|
||||||
|
@ -53,6 +53,10 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.
|
|||||||
return e.s.TTS(ctx, in)
|
return e.s.TTS(ctx, in)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *embedBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||||
|
return e.s.SoundGeneration(ctx, in)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
|
func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
|
||||||
r, err := e.s.AudioTranscription(ctx, in)
|
r, err := e.s.AudioTranscription(ctx, in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -17,6 +17,7 @@ type LLM interface {
|
|||||||
GenerateImage(*pb.GenerateImageRequest) error
|
GenerateImage(*pb.GenerateImageRequest) error
|
||||||
AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error)
|
AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error)
|
||||||
TTS(*pb.TTSRequest) error
|
TTS(*pb.TTSRequest) error
|
||||||
|
SoundGeneration(*pb.SoundGenerationRequest) error
|
||||||
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
|
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
|
||||||
Status() (pb.StatusResponse, error)
|
Status() (pb.StatusResponse, error)
|
||||||
|
|
||||||
|
@ -84,7 +84,19 @@ func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err
|
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err
|
||||||
}
|
}
|
||||||
return &pb.Result{Message: "Audio generated", Success: true}, nil
|
return &pb.Result{Message: "TTS audio generated", Success: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) {
|
||||||
|
if s.llm.Locking() {
|
||||||
|
s.llm.Lock()
|
||||||
|
defer s.llm.Unlock()
|
||||||
|
}
|
||||||
|
err := s.llm.SoundGeneration(in)
|
||||||
|
if err != nil {
|
||||||
|
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err
|
||||||
|
}
|
||||||
|
return &pb.Result{Message: "Sound Generation audio generated", Success: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) {
|
func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) {
|
||||||
|
@ -38,3 +38,19 @@ func SanitizeFileName(fileName string) string {
|
|||||||
safeName := strings.ReplaceAll(baseName, "..", "")
|
safeName := strings.ReplaceAll(baseName, "..", "")
|
||||||
return safeName
|
return safeName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateUniqueFileName returns a file name (not a full path) that is not
// currently in use inside dir. It tries baseName+ext first and then
// baseName_2+ext, baseName_3+ext, and so on.
//
// If a candidate cannot be inspected for any reason other than already
// existing (for example dir is not a directory or is not accessible), that
// candidate is returned as-is: the same error would repeat for every further
// candidate, so probing on would never end, and the caller will get a
// meaningful error when it tries to create the file.
//
// The check and the later file creation are separate steps, so the name is
// only a best-effort guarantee when several writers share dir.
func GenerateUniqueFileName(dir, baseName, ext string) string {
	fileName := baseName + ext

	for counter := 2; ; counter++ {
		if _, err := os.Stat(filepath.Join(dir, fileName)); err != nil {
			// Either the name is free (not-exist) or dir cannot be probed.
			return fileName
		}
		fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
	}
}
|
||||||
|
Loading…
Reference in New Issue
Block a user