diff --git a/backend/backend.proto b/backend/backend.proto index 0d3d5f7f..4a8f31a9 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -16,6 +16,7 @@ service Backend { rpc GenerateImage(GenerateImageRequest) returns (Result) {} rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {} rpc TTS(TTSRequest) returns (Result) {} + rpc SoundGeneration(SoundGenerationRequest) returns (Result) {} rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {} rpc Status(HealthMessage) returns (StatusResponse) {} @@ -270,6 +271,17 @@ message TTSRequest { optional string language = 5; } +message SoundGenerationRequest { + string text = 1; + string model = 2; + string dst = 3; + optional float duration = 4; + optional float temperature = 5; + optional bool sample = 6; + optional string src = 7; + optional int32 src_divisor = 8; +} + message TokenizationResponse { int32 length = 1; repeated int32 tokens = 2; diff --git a/backend/python/transformers-musicgen/backend.py b/backend/python/transformers-musicgen/backend.py index d41d9a5c..b9f1facf 100644 --- a/backend/python/transformers-musicgen/backend.py +++ b/backend/python/transformers-musicgen/backend.py @@ -15,7 +15,7 @@ import backend_pb2_grpc import grpc -from scipy.io.wavfile import write as write_wav +from scipy.io import wavfile from transformers import AutoProcessor, MusicgenForConditionalGeneration _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -63,6 +63,61 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): return backend_pb2.Result(message="Model loaded successfully", success=True) + def SoundGeneration(self, request, context): + model_name = request.model + if model_name == "": + return backend_pb2.Result(success=False, message="request.model is required") + try: + self.processor = AutoProcessor.from_pretrained(model_name) + self.model = MusicgenForConditionalGeneration.from_pretrained(model_name) + inputs = None + if request.text == "": + inputs = self.model.get_unconditional_inputs(num_samples=1) + elif request.HasField('src'): + # TODO SECURITY CODE GOES HERE LOL + # WHO KNOWS IF THIS WORKS??? + sample_rate, wsamples = wavfile.read('path_to_your_file.wav') + + if request.HasField('src_divisor'): + wsamples = wsamples[: len(wsamples) // request.src_divisor] + + inputs = self.processor( + audio=wsamples, + sampling_rate=sample_rate, + text=[request.text], + padding=True, + return_tensors="pt", + ) + else: + inputs = self.processor( + text=[request.text], + padding=True, + return_tensors="pt", + ) + + tokens = 256 + if request.HasField('duration'): + tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second + guidance = 3.0 + if request.HasField('temperature'): + guidance = request.temperature + dosample = True + if request.HasField('sample'): + dosample = request.sample + audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens) + print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr) + sampling_rate = self.model.config.audio_encoder.sampling_rate + wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy()) + print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr) + print("[transformers-musicgen] SoundGeneration for", file=sys.stderr) + print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr) + print(request, file=sys.stderr) + except Exception as err: + return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") + return backend_pb2.Result(success=True) + + +# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons def TTS(self, request, context): model_name = request.model if model_name == "": @@ -75,8 +130,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): padding=True, return_tensors="pt", ) - tokens = 256 - # TODO get tokens from request? + tokens = 512 # No good place to set the "length" in TTS, so use 10s as a sane default audio_values = self.model.generate(**inputs, max_new_tokens=tokens) print("[transformers-musicgen] TTS generated!", file=sys.stderr) sampling_rate = self.model.config.audio_encoder.sampling_rate diff --git a/backend/python/transformers-musicgen/test.py b/backend/python/transformers-musicgen/test.py index 777b399a..295de65e 100644 --- a/backend/python/transformers-musicgen/test.py +++ b/backend/python/transformers-musicgen/test.py @@ -63,7 +63,7 @@ class TestBackendServicer(unittest.TestCase): def test_tts(self): """ - This method tests if the embeddings are generated successfully + This method tests if TTS is generated successfully """ try: self.setUp() @@ -77,5 +77,24 @@ class TestBackendServicer(unittest.TestCase): except Exception as err: print(err) self.fail("TTS service failed") + finally: + self.tearDown() + + def test_sound_generation(self): + """ + This method tests if SoundGeneration is generated successfully + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small")) + self.assertTrue(response.success) + sg_request = backend_pb2.SoundGenerationRequest(text="80s TV news production music hit for tonight's biggest story") + sg_response = stub.SoundGeneration(sg_request) + self.assertIsNotNone(sg_response) + except Exception as err: + print(err) + self.fail("SoundGeneration service failed") finally: self.tearDown() \ No newline at end of file diff --git a/core/backend/llm.go b/core/backend/llm.go index 9268fbbc..72c4ad9f 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -87,7 +87,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im case string: protoMessages[i].Content = ct default: - return nil, fmt.Errorf("Unsupported type for schema.Message.Content for inference: %T", ct) + return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct) } } } diff --git a/core/backend/soundgeneration.go b/core/backend/soundgeneration.go new file mode 100644 index 00000000..abd5221b --- /dev/null +++ b/core/backend/soundgeneration.go @@ -0,0 +1,74 @@ +package backend + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/utils" +) + +func SoundGeneration( + backend string, + modelFile string, + text string, + duration *float32, + temperature *float32, + doSample *bool, + sourceFile *string, + sourceDivisor *int32, + loader *model.ModelLoader, + appConfig *config.ApplicationConfig, + backendConfig config.BackendConfig, +) (string, *proto.Result, error) { + if backend == "" { + return "", nil, fmt.Errorf("backend is a required parameter") + } + + grpcOpts := gRPCModelOpts(backendConfig) + opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{ + model.WithBackendString(backend), + model.WithModel(modelFile), + model.WithContext(appConfig.Context), + model.WithAssetDir(appConfig.AssetsDestination), + model.WithLoadGRPCLoadModelOpts(grpcOpts), + }) + + soundGenModel, err := loader.BackendLoader(opts...) + if err != nil { + return "", nil, err + } + + if soundGenModel == nil { + return "", nil, fmt.Errorf("could not load sound generation model") + } + + if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil { + return "", nil, fmt.Errorf("failed creating audio directory: %s", err) + } + + fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "sound_generation", ".wav") + filePath := filepath.Join(appConfig.AudioDir, fileName) + + res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{ + Text: text, + Model: modelFile, + Dst: filePath, + Sample: doSample, + Duration: duration, + Temperature: temperature, + Src: sourceFile, + SrcDivisor: sourceDivisor, + }) + + // return RPC error if any + if !res.Success { + return "", nil, fmt.Errorf(res.Message) + } + + return filePath, res, err +} diff --git a/core/backend/tts.go b/core/backend/tts.go index ced73e13..13a851ba 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -9,31 +9,15 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/grpc/proto" - model "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" ) -func generateUniqueFileName(dir, baseName, ext string) string { - counter := 1 - fileName := baseName + ext - - for { - filePath := filepath.Join(dir, fileName) - _, err := os.Stat(filePath) - if os.IsNotExist(err) { - return fileName - } - - counter++ - fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) - } -} - func ModelTTS( backend, text, modelFile, - voice , + voice, language string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, @@ -66,7 +50,7 @@ func ModelTTS( return "", nil, fmt.Errorf("failed creating audio directory: %s", err) } - fileName := generateUniqueFileName(appConfig.AudioDir, "tts", ".wav") + fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav") filePath := filepath.Join(appConfig.AudioDir, fileName) // If the model file is not empty, we pass it joined with the model path @@ -88,10 +72,10 @@ func ModelTTS( } res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{ - Text: text, - Model: modelPath, - Voice: voice, - Dst: filePath, + Text: text, + Model: modelPath, + Voice: voice, + Dst: filePath, Language: &language, }) diff --git a/core/cli/cli.go b/core/cli/cli.go index 2073778d..aed75d8a 100644 --- a/core/cli/cli.go +++ b/core/cli/cli.go @@ -8,12 +8,13 @@ import ( var CLI struct { cliContext.Context `embed:""` - Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"` - Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"` - Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"` - TTS TTSCMD `cmd:"" help:"Convert text to speech"` - Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"` - Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"` - Util UtilCMD `cmd:"" help:"Utility commands"` - Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"` + Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"` + Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"` + Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"` + TTS TTSCMD `cmd:"" help:"Convert text to speech"` + SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"` + Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"` + Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"` + Util UtilCMD `cmd:"" help:"Utility commands"` + Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"` } diff --git a/core/cli/soundgeneration.go b/core/cli/soundgeneration.go new file mode 100644 index 00000000..5711b199 --- /dev/null +++ b/core/cli/soundgeneration.go @@ -0,0 +1,110 @@ +package cli + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/mudler/LocalAI/core/backend" + cliContext "github.com/mudler/LocalAI/core/cli/context" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/pkg/model" + "github.com/rs/zerolog/log" +) + +type SoundGenerationCMD struct { + Text []string `arg:""` + + Backend string `short:"b" required:"" help:"Backend to run the SoundGeneration model"` + Model string `short:"m" required:"" help:"Model name to run the SoundGeneration"` + Duration string `short:"d" help:"If specified, the length of audio to generate in seconds"` + Temperature string `short:"t" help:"If specified, the temperature of the generation"` + InputFile string `short:"i" help:"If specified, the input file to condition generation upon"` + InputFileSampleDivisor string `short:"f" help:"If InputFile and this divisor is specified, the first portion of the sample file will be used"` + DoSample bool `short:"s" default:"true" help:"Enables sampling from the model. Better quality at the cost of speed. Defaults to enabled."` + OutputFile string `short:"o" type:"path" help:"The path to write the output wav file"` + ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` + BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"` + ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"` +} + +func parseToFloat32Ptr(input string) *float32 { + f, err := strconv.ParseFloat(input, 32) + if err != nil { + return nil + } + f2 := float32(f) + return &f2 +} + +func parseToInt32Ptr(input string) *int32 { + i, err := strconv.ParseInt(input, 10, 32) + if err != nil { + return nil + } + i2 := int32(i) + return &i2 +} + +func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error { + outputFile := t.OutputFile + outputDir := t.BackendAssetsPath + if outputFile != "" { + outputDir = filepath.Dir(outputFile) + } + + text := strings.Join(t.Text, " ") + + externalBackends := make(map[string]string) + // split ":" to get backend name and the uri + for _, v := range t.ExternalGRPCBackends { + backend := v[:strings.IndexByte(v, ':')] + uri := v[strings.IndexByte(v, ':')+1:] + externalBackends[backend] = uri + fmt.Printf("TMP externalBackends[%q]=%q\n\n", backend, uri) + } + + opts := &config.ApplicationConfig{ + ModelPath: t.ModelsPath, + Context: context.Background(), + AudioDir: outputDir, + AssetsDestination: t.BackendAssetsPath, + ExternalGRPCBackends: externalBackends, + } + ml := model.NewModelLoader(opts.ModelPath) + + defer func() { + err := ml.StopAllGRPC() + if err != nil { + log.Error().Err(err).Msg("unable to stop all grpc processes") + } + }() + + options := config.BackendConfig{} + options.SetDefaults() + + var inputFile *string + if t.InputFile != "" { + inputFile = &t.InputFile + } + + filePath, _, err := backend.SoundGeneration(t.Backend, t.Model, text, + parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample, + inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options) + + if err != nil { + return err + } + if outputFile != "" { + if err := os.Rename(filePath, outputFile); err != nil { + return err + } + fmt.Printf("Generate file %s\n", outputFile) + } else { + fmt.Printf("Generate file %s\n", filePath) + } + return nil +} diff --git a/core/http/endpoints/elevenlabs/soundgeneration.go b/core/http/endpoints/elevenlabs/soundgeneration.go new file mode 100644 index 00000000..619544d8 --- /dev/null +++ b/core/http/endpoints/elevenlabs/soundgeneration.go @@ -0,0 +1,65 @@ +package elevenlabs + +import ( + "github.com/gofiber/fiber/v2" + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + fiberContext "github.com/mudler/LocalAI/core/http/ctx" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/model" + "github.com/rs/zerolog/log" +) + +// SoundGenerationEndpoint is the ElevenLabs SoundGeneration endpoint https://elevenlabs.io/docs/api-reference/sound-generation +// @Summary Generates audio from the input text. +// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params" +// @Success 200 {string} binary "Response" +// @Router /v1/sound-generation [post] +func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(schema.ElevenLabsSoundGenerationRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + + modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false) + if err != nil { + modelFile = input.ModelID + log.Warn().Str("ModelID", input.ModelID).Msg("Model not found in context") + } + + cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, + config.LoadOptionDebug(appConfig.Debug), + config.LoadOptionThreads(appConfig.Threads), + config.LoadOptionContextSize(appConfig.ContextSize), + config.LoadOptionF16(appConfig.F16), + ) + if err != nil { + modelFile = input.ModelID + log.Warn().Str("Request ModelID", input.ModelID).Err(err).Msg("error during LoadBackendConfigFileByName, using request ModelID") + } else { + if input.ModelID != "" { + modelFile = input.ModelID + } else { + modelFile = cfg.Model + } + } + log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend") + + if input.Duration != nil { + log.Debug().Float32("duration", *input.Duration).Msg("duration set") + } + if input.Temperature != nil { + log.Debug().Float32("temperature", *input.Temperature).Msg("temperature set") + } + + // TODO: Support uploading files? + filePath, _, err := backend.SoundGeneration(cfg.Backend, modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg) + if err != nil { + return err + } + return c.Download(filePath) + + } +} diff --git a/core/http/routes/elevenlabs.go b/core/http/routes/elevenlabs.go index 4f9e666f..b20dec75 100644 --- a/core/http/routes/elevenlabs.go +++ b/core/http/routes/elevenlabs.go @@ -16,4 +16,6 @@ func RegisterElevenLabsRoutes(app *fiber.App, // Elevenlabs app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig)) + app.Post("/v1/sound-generation", auth, elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig)) + } diff --git a/core/schema/elevenlabs.go b/core/schema/elevenlabs.go index 8bd6be3b..119e0a58 100644 --- a/core/schema/elevenlabs.go +++ b/core/schema/elevenlabs.go @@ -4,3 +4,11 @@ type ElevenLabsTTSRequest struct { Text string `json:"text" yaml:"text"` ModelID string `json:"model_id" yaml:"model_id"` } + +type ElevenLabsSoundGenerationRequest struct { + Text string `json:"text" yaml:"text"` + ModelID string `json:"model_id" yaml:"model_id"` + Duration *float32 `json:"duration_seconds,omitempty" yaml:"duration_seconds,omitempty"` + Temperature *float32 `json:"prompt_influence,omitempty" yaml:"prompt_influence,omitempty"` + DoSample *bool `json:"do_sample,omitempty" yaml:"do_sample,omitempty"` +} diff --git a/examples/bruno/LocalAI Test Requests/Sound Generation/musicgen.bru b/examples/bruno/LocalAI Test Requests/Sound Generation/musicgen.bru new file mode 100644 index 00000000..471756f5 --- /dev/null +++ b/examples/bruno/LocalAI Test Requests/Sound Generation/musicgen.bru @@ -0,0 +1,23 @@ +meta { + name: musicgen + type: http + seq: 1 +} + +post { + url: {{PROTOCOL}}{{HOST}}:{{PORT}}/v1/sound-generation + body: json + auth: none +} + +headers { + Content-Type: application/json +} + +body:json { + { + "model_id": "facebook/musicgen-small", + "text": "Exciting 80s Newscast Interstitial", + "duration_seconds": 8 + } +} diff --git a/go.sum b/go.sum index e09af5ce..85800fd6 100644 --- a/go.sum +++ b/go.sum @@ -509,6 +509,9 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/ginkgo/v2 v2.20.0 h1:PE84V2mHqoT1sglvHc8ZdQtPcwmvvt29WLEEO3xmdZw= github.com/onsi/ginkgo/v2 v2.20.0/go.mod h1:lG9ey2Z29hR41WMVthyJBGUBcBhGOtoPF2VFMvBXFCI= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index 596a7589..5abc34ab 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -41,6 +41,7 @@ type Backend interface { PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) + SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) Status(ctx context.Context) (*pb.StatusResponse, error) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 515022ec..21dd1578 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -61,6 +61,10 @@ func (llm *Base) TTS(*pb.TTSRequest) error { return fmt.Errorf("unimplemented") } +func (llm *Base) SoundGeneration(*pb.SoundGenerationRequest) error { + return fmt.Errorf("unimplemented") +} + func (llm *Base) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) { return pb.TokenizationResponse{}, fmt.Errorf("unimplemented") } diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index cfae5875..827275cf 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -210,6 +210,26 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp return client.TTS(ctx, in, opts...) } +func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } + c.setBusy(true) + defer c.setBusy(false) + if c.wd != nil { + c.wd.Mark(c.address) + defer c.wd.UnMark(c.address) + } + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + return client.SoundGeneration(ctx, in, opts...) +} + func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { if !c.parallel { c.opMutex.Lock() diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index 2b776b39..67d83e27 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -53,6 +53,10 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc. return e.s.TTS(ctx, in) } +func (e *embedBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) { + return e.s.SoundGeneration(ctx, in) +} + func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { r, err := e.s.AudioTranscription(ctx, in) if err != nil { diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 313c8ff5..731dcd5b 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -17,6 +17,7 @@ type LLM interface { GenerateImage(*pb.GenerateImageRequest) error AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) TTS(*pb.TTSRequest) error + SoundGeneration(*pb.SoundGenerationRequest) error TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) Status() (pb.StatusResponse, error) diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 784aac7f..0e602a42 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -84,7 +84,19 @@ func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) if err != nil { return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err } - return &pb.Result{Message: "Audio generated", Success: true}, nil + return &pb.Result{Message: "TTS audio generated", Success: true}, nil +} + +func (s *server) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } + err := s.llm.SoundGeneration(in) + if err != nil { + return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err + } + return &pb.Result{Message: "Sound Generation audio generated", Success: true}, nil } func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) { diff --git a/pkg/utils/path.go b/pkg/utils/path.go index c1d3e86d..1ae11d12 100644 --- a/pkg/utils/path.go +++ b/pkg/utils/path.go @@ -38,3 +38,19 @@ func SanitizeFileName(fileName string) string { safeName := strings.ReplaceAll(baseName, "..", "") return safeName } + +func GenerateUniqueFileName(dir, baseName, ext string) string { + counter := 1 + fileName := baseName + ext + + for { + filePath := filepath.Join(dir, fileName) + _, err := os.Stat(filePath) + if os.IsNotExist(err) { + return fileName + } + + counter++ + fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) + } +}