diff --git a/Makefile b/Makefile index 3384fb5e..0e48a056 100644 --- a/Makefile +++ b/Makefile @@ -34,6 +34,10 @@ STABLEDIFFUSION_VERSION?=4a3cd6aeae6f66ee57eae9a0075f8c58c3a6a38f TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057 +ONNX_VERSION?=1.20.0 +ONNX_ARCH?=x64 +ONNX_OS?=linux + export BUILD_TYPE?= export STABLE_BUILD_TYPE?=$(BUILD_TYPE) export CMAKE_ARGS?= @@ -89,7 +93,20 @@ ifeq ($(NATIVE),false) CMAKE_ARGS+=-DGGML_NATIVE=OFF endif +# Detect if we are running on arm64 +ifneq (,$(findstring aarch64,$(shell uname -m))) + ONNX_ARCH=aarch64 +endif + ifeq ($(OS),Darwin) + ONNX_OS=osx + ifneq (,$(findstring aarch64,$(shell uname -m))) + ONNX_ARCH=arm64 + else ifneq (,$(findstring arm64,$(shell uname -m))) + ONNX_ARCH=arm64 + else + ONNX_ARCH=x86_64 + endif ifeq ($(OSX_SIGNING_IDENTITY),) OSX_SIGNING_IDENTITY := $(shell security find-identity -v -p codesigning | grep '"' | head -n 1 | sed -E 's/.*"(.*)"/\1/') @@ -195,6 +212,7 @@ ALL_GRPC_BACKENDS+=backend-assets/util/llama-cpp-rpc-server ALL_GRPC_BACKENDS+=backend-assets/grpc/rwkv ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store +ALL_GRPC_BACKENDS+=backend-assets/grpc/silero-vad ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC) # Use filter-out to remove the specified backends ALL_GRPC_BACKENDS := $(filter-out $(SKIP_GRPC_BACKEND),$(ALL_GRPC_BACKENDS)) @@ -281,6 +299,20 @@ sources/go-stable-diffusion: sources/go-stable-diffusion/libstablediffusion.a: sources/go-stable-diffusion CPATH="$(CPATH):/usr/include/opencv4" $(MAKE) -C sources/go-stable-diffusion libstablediffusion.a +sources/onnxruntime: + mkdir -p sources/onnxruntime + curl -L https://github.com/microsoft/onnxruntime/releases/download/v$(ONNX_VERSION)/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz -o sources/onnxruntime/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz + cd sources/onnxruntime && tar -xvf onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz && rm onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz + cd sources/onnxruntime && mv onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION)/* ./ + +backend-assets/lib/libonnxruntime.so.1: backend-assets/lib sources/onnxruntime + cp -rfv sources/onnxruntime/lib/* backend-assets/lib/ +ifeq ($(OS),Darwin) + mv backend-assets/lib/libonnxruntime.$(ONNX_VERSION).dylib backend-assets/lib/libonnxruntime.dylib +else + mv backend-assets/lib/libonnxruntime.so.$(ONNX_VERSION) backend-assets/lib/libonnxruntime.so.1 +endif + ## tiny-dream sources/go-tiny-dream: mkdir -p sources/go-tiny-dream @@ -837,6 +869,13 @@ ifneq ($(UPX),) $(UPX) backend-assets/grpc/stablediffusion endif +backend-assets/grpc/silero-vad: backend-assets/grpc backend-assets/lib/libonnxruntime.so.1 + CGO_LDFLAGS="$(CGO_LDFLAGS)" CPATH="$(CPATH):$(CURDIR)/sources/onnxruntime/include/" LIBRARY_PATH=$(CURDIR)/backend-assets/lib \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/silero-vad ./backend/go/vad/silero +ifneq ($(UPX),) + $(UPX) backend-assets/grpc/silero-vad +endif + backend-assets/grpc/tinydream: sources/go-tiny-dream sources/go-tiny-dream/libtinydream.a backend-assets/grpc CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/go-tiny-dream \ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/tinydream ./backend/go/image/tinydream diff --git a/backend/backend.proto b/backend/backend.proto index 96f7c88f..d6e8f236 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -28,6 +28,8 @@ service Backend { rpc Rerank(RerankRequest) returns (RerankResult) {} rpc GetMetrics(MetricsRequest) returns (MetricsResponse); + + rpc VAD(VADRequest) returns (VADResponse) {} } // Define the empty request @@ -293,6 +295,19 @@ message TTSRequest { optional string language = 5; } +message VADRequest { + repeated float audio = 1; +} + +message VADSegment { + float start = 1; + float end = 2; +} + +message VADResponse { + repeated VADSegment segments = 1; +} + message SoundGenerationRequest { string text = 1; string model = 2; diff --git a/backend/go/vad/silero/main.go b/backend/go/vad/silero/main.go new file mode 100644 index 00000000..28f51e49 --- /dev/null +++ b/backend/go/vad/silero/main.go @@ -0,0 +1,21 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + grpc "github.com/mudler/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &VAD{}); err != nil { + panic(err) + } +} diff --git a/backend/go/vad/silero/vad.go b/backend/go/vad/silero/vad.go new file mode 100644 index 00000000..5a164d2a --- /dev/null +++ b/backend/go/vad/silero/vad.go @@ -0,0 +1,54 @@ +package main + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + "github.com/mudler/LocalAI/pkg/grpc/base" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/streamer45/silero-vad-go/speech" +) + +type VAD struct { + base.SingleThread + detector *speech.Detector +} + +func (vad *VAD) Load(opts *pb.ModelOptions) error { + v, err := speech.NewDetector(speech.DetectorConfig{ + ModelPath: opts.ModelFile, + SampleRate: 16000, + //WindowSize: 1024, + Threshold: 0.5, + MinSilenceDurationMs: 0, + SpeechPadMs: 0, + }) + if err != nil { + return fmt.Errorf("create silero detector: %w", err) + } + + vad.detector = v + return err +} + +func (vad *VAD) VAD(req *pb.VADRequest) (pb.VADResponse, error) { + audio := req.Audio + + segments, err := vad.detector.Detect(audio) + if err != nil { + return pb.VADResponse{}, fmt.Errorf("detect: %w", err) + } + + vadSegments := []*pb.VADSegment{} + for _, s := range segments { + vadSegments = append(vadSegments, &pb.VADSegment{ + Start: float32(s.SpeechStartAt), + End: float32(s.SpeechEndAt), + }) + } + + return pb.VADResponse{ + Segments: vadSegments, + }, nil +} diff --git a/core/http/endpoints/localai/vad.go b/core/http/endpoints/localai/vad.go new file mode 100644 index 00000000..c5a5d929 --- /dev/null +++ b/core/http/endpoints/localai/vad.go @@ -0,0 +1,68 @@ +package localai + +import ( + "github.com/gofiber/fiber/v2" + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + fiberContext "github.com/mudler/LocalAI/core/http/ctx" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/LocalAI/pkg/model" + "github.com/rs/zerolog/log" +) + +// VADEndpoint is Voice-Activation-Detection endpoint +// @Summary Detect voice fragments in an audio stream +// @Accept json +// @Param request body schema.VADRequest true "query params" +// @Success 200 {object} proto.VADResponse "Response" +// @Router /vad [post] +func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + + input := new(schema.VADRequest) + + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + + modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false) + if err != nil { + modelFile = input.Model + log.Warn().Msgf("Model not found in context: %s", input.Model) + } + + cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, + config.LoadOptionDebug(appConfig.Debug), + config.LoadOptionThreads(appConfig.Threads), + config.LoadOptionContextSize(appConfig.ContextSize), + config.LoadOptionF16(appConfig.F16), + ) + + if err != nil { + log.Err(err) + modelFile = input.Model + log.Warn().Msgf("Model not found in context: %s", input.Model) + } else { + modelFile = cfg.Model + } + log.Debug().Msgf("Request for model: %s", modelFile) + + opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend), model.WithModel(modelFile)) + + vadModel, err := ml.Load(opts...) + if err != nil { + return err + } + req := proto.VADRequest{ + Audio: input.Audio, + } + resp, err := vadModel.VAD(c.Context(), &req) + if err != nil { + return err + } + + return c.JSON(resp) + } +} diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index a2ef16a5..e7097741 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -34,6 +34,7 @@ func RegisterLocalAIRoutes(app *fiber.App, } app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig)) + app.Post("/vad", localai.VADEndpoint(cl, ml, appConfig)) // Stores sl := model.NewModelLoader("") diff --git a/core/schema/localai.go b/core/schema/localai.go index 861ed577..08afc6df 100644 --- a/core/schema/localai.go +++ b/core/schema/localai.go @@ -30,10 +30,16 @@ type TTSRequest struct { Input string `json:"input" yaml:"input"` // text input Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id Backend string `json:"backend" yaml:"backend"` - Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model + Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format } +// @Description VAD request body +type VADRequest struct { + Model string `json:"model" yaml:"model"` // model name or full path + Audio []float32 `json:"audio" yaml:"audio"` // model name or full path +} + type StoresSet struct { Store string `json:"store,omitempty" yaml:"store,omitempty"` diff --git a/go.mod b/go.mod index 93ef4779..de723044 100644 --- a/go.mod +++ b/go.mod @@ -86,6 +86,9 @@ require ( github.com/pion/turn/v2 v2.1.6 // indirect github.com/pion/webrtc/v3 v3.3.0 // indirect github.com/shirou/gopsutil/v4 v4.24.7 // indirect + github.com/streamer45/silero-vad-go v0.2.1 // indirect + github.com/urfave/cli/v2 v2.27.4 // indirect + github.com/valyala/fasttemplate v1.2.2 // indirect github.com/wlynxg/anet v0.0.4 // indirect go.uber.org/mock v0.4.0 // indirect ) diff --git a/go.sum b/go.sum index 932888ad..b92ae6a1 100644 --- a/go.sum +++ b/go.sum @@ -674,6 +674,11 @@ github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:Udh github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w= +github.com/spf13/cast v1.5.0/go.mod h1:SpXXQ5YoyJw6s3/6cMTQuxvgRl3PCJiyaX9p6b155UU= +github.com/streamer45/silero-vad-go v0.2.1 h1:Li1/tTC4H/3cyw6q4weX+U8GWwEL3lTekK/nYa1Cvuk= +github.com/streamer45/silero-vad-go v0.2.1/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs= github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index 637a6db1..21435891 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -53,4 +53,6 @@ type Backend interface { Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) + + VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) } diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 3356f86b..2e1fb209 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -92,6 +92,10 @@ func (llm *Base) StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error) return pb.StoresFindResult{}, fmt.Errorf("unimplemented") } +func (llm *Base) VAD(*pb.VADRequest) (pb.VADResponse, error) { + return pb.VADResponse{}, fmt.Errorf("unimplemented") +} + func memoryUsage() *pb.MemoryUsageData { mud := pb.MemoryUsageData{ Breakdown: make(map[string]uint64), diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 14481620..9c8b302e 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -392,3 +392,21 @@ func (c *Client) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opt client := pb.NewBackendClient(conn) return client.GetMetrics(ctx, in, opts...) } + +func (c *Client) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } + c.setBusy(true) + defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + return client.VAD(ctx, in, opts...) +} diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index cf624344..a5828a5f 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -87,6 +87,10 @@ func (e *embedBackend) Rerank(ctx context.Context, in *pb.RerankRequest, opts .. return e.s.Rerank(ctx, in) } +func (e *embedBackend) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) { + return e.s.VAD(ctx, in) +} + func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) { return e.s.GetMetrics(ctx, in) } diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 97b958cc..9214e3cf 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -24,6 +24,8 @@ type LLM interface { StoresDelete(*pb.StoresDeleteOptions) error StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error) StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error) + + VAD(*pb.VADRequest) (pb.VADResponse, error) } func newReply(s string) *pb.Reply { diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 0e602a42..0b2a167f 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -227,6 +227,18 @@ func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb. return &res, nil } +func (s *server) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } + res, err := s.llm.VAD(in) + if err != nil { + return nil, err + } + return &res, nil +} + func StartServer(address string, model LLM) error { lis, err := net.Listen("tcp", address) if err != nil {