mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
feat(silero): add Silero-vad backend (#4204)
* feat(vad): add silero-vad backend (WIP) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(vad): add API endpoint Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(vad): correctly place the onnxruntime libs Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore(vad): hook silero-vad to binary and container builds Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(gRPC): register VAD Server Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(Makefile): consume ONNX_OS consistently Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(Makefile): handle macOS Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
parent
9892d7d584
commit
b1ea9318e6
39
Makefile
39
Makefile
@ -34,6 +34,10 @@ STABLEDIFFUSION_VERSION?=4a3cd6aeae6f66ee57eae9a0075f8c58c3a6a38f
|
|||||||
TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream
|
TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream
|
||||||
TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057
|
TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057
|
||||||
|
|
||||||
|
ONNX_VERSION?=1.20.0
|
||||||
|
ONNX_ARCH?=x64
|
||||||
|
ONNX_OS?=linux
|
||||||
|
|
||||||
export BUILD_TYPE?=
|
export BUILD_TYPE?=
|
||||||
export STABLE_BUILD_TYPE?=$(BUILD_TYPE)
|
export STABLE_BUILD_TYPE?=$(BUILD_TYPE)
|
||||||
export CMAKE_ARGS?=
|
export CMAKE_ARGS?=
|
||||||
@ -89,7 +93,20 @@ ifeq ($(NATIVE),false)
|
|||||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
# Detect if we are running on arm64
|
||||||
|
ifneq (,$(findstring aarch64,$(shell uname -m)))
|
||||||
|
ONNX_ARCH=aarch64
|
||||||
|
endif
|
||||||
|
|
||||||
ifeq ($(OS),Darwin)
|
ifeq ($(OS),Darwin)
|
||||||
|
ONNX_OS=osx
|
||||||
|
ifneq (,$(findstring aarch64,$(shell uname -m)))
|
||||||
|
ONNX_ARCH=arm64
|
||||||
|
else ifneq (,$(findstring arm64,$(shell uname -m)))
|
||||||
|
ONNX_ARCH=arm64
|
||||||
|
else
|
||||||
|
ONNX_ARCH=x86_64
|
||||||
|
endif
|
||||||
|
|
||||||
ifeq ($(OSX_SIGNING_IDENTITY),)
|
ifeq ($(OSX_SIGNING_IDENTITY),)
|
||||||
OSX_SIGNING_IDENTITY := $(shell security find-identity -v -p codesigning | grep '"' | head -n 1 | sed -E 's/.*"(.*)"/\1/')
|
OSX_SIGNING_IDENTITY := $(shell security find-identity -v -p codesigning | grep '"' | head -n 1 | sed -E 's/.*"(.*)"/\1/')
|
||||||
@ -195,6 +212,7 @@ ALL_GRPC_BACKENDS+=backend-assets/util/llama-cpp-rpc-server
|
|||||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/rwkv
|
ALL_GRPC_BACKENDS+=backend-assets/grpc/rwkv
|
||||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
|
ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
|
||||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
|
ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
|
||||||
|
ALL_GRPC_BACKENDS+=backend-assets/grpc/silero-vad
|
||||||
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
|
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
|
||||||
# Use filter-out to remove the specified backends
|
# Use filter-out to remove the specified backends
|
||||||
ALL_GRPC_BACKENDS := $(filter-out $(SKIP_GRPC_BACKEND),$(ALL_GRPC_BACKENDS))
|
ALL_GRPC_BACKENDS := $(filter-out $(SKIP_GRPC_BACKEND),$(ALL_GRPC_BACKENDS))
|
||||||
@ -281,6 +299,20 @@ sources/go-stable-diffusion:
|
|||||||
sources/go-stable-diffusion/libstablediffusion.a: sources/go-stable-diffusion
|
sources/go-stable-diffusion/libstablediffusion.a: sources/go-stable-diffusion
|
||||||
CPATH="$(CPATH):/usr/include/opencv4" $(MAKE) -C sources/go-stable-diffusion libstablediffusion.a
|
CPATH="$(CPATH):/usr/include/opencv4" $(MAKE) -C sources/go-stable-diffusion libstablediffusion.a
|
||||||
|
|
||||||
|
# Download and unpack the prebuilt ONNX Runtime release selected by
# ONNX_VERSION, ONNX_OS and ONNX_ARCH into sources/onnxruntime.
#
# curl runs with -f so that an HTTP error (for example an OS/arch combination
# with no published archive) fails the recipe instead of saving the error page
# as the .tgz. Because the target is the directory itself, it is removed again
# on any failure; otherwise a later `make` would see the directory, consider
# the target up to date and never retry the download.
sources/onnxruntime:
	mkdir -p sources/onnxruntime
	( curl -fL https://github.com/microsoft/onnxruntime/releases/download/v$(ONNX_VERSION)/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz -o sources/onnxruntime/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz && \
	  cd sources/onnxruntime && \
	  tar -xvf onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz && \
	  rm onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz && \
	  mv onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION)/* ./ \
	) || ( rm -rf sources/onnxruntime; exit 1 )
|
||||||
|
|
||||||
|
# Copy the ONNX Runtime libraries from sources/onnxruntime/lib into
# backend-assets/lib, then rename the versioned library: to
# libonnxruntime.dylib when OS is Darwin, to libonnxruntime.so.1 otherwise.
# NOTE(review): in the Darwin branch the recipe never creates the file this
# target is named after, so make will presumably re-run it on every build -
# confirm whether that is intended.
backend-assets/lib/libonnxruntime.so.1: backend-assets/lib sources/onnxruntime
|
||||||
|
cp -rfv sources/onnxruntime/lib/* backend-assets/lib/
|
||||||
|
ifeq ($(OS),Darwin)
|
||||||
|
mv backend-assets/lib/libonnxruntime.$(ONNX_VERSION).dylib backend-assets/lib/libonnxruntime.dylib
|
||||||
|
else
|
||||||
|
mv backend-assets/lib/libonnxruntime.so.$(ONNX_VERSION) backend-assets/lib/libonnxruntime.so.1
|
||||||
|
endif
|
||||||
|
|
||||||
## tiny-dream
|
## tiny-dream
|
||||||
sources/go-tiny-dream:
|
sources/go-tiny-dream:
|
||||||
mkdir -p sources/go-tiny-dream
|
mkdir -p sources/go-tiny-dream
|
||||||
@ -837,6 +869,13 @@ ifneq ($(UPX),)
|
|||||||
$(UPX) backend-assets/grpc/stablediffusion
|
$(UPX) backend-assets/grpc/stablediffusion
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
# Build the silero-vad gRPC backend binary from ./backend/go/vad/silero.
# CPATH is extended with the ONNX Runtime headers in sources/onnxruntime/include
# and LIBRARY_PATH points at backend-assets/lib, where the ONNX Runtime
# libraries are copied by the libonnxruntime prerequisite. When UPX is set the
# resulting binary is compressed with it.
backend-assets/grpc/silero-vad: backend-assets/grpc backend-assets/lib/libonnxruntime.so.1
|
||||||
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" CPATH="$(CPATH):$(CURDIR)/sources/onnxruntime/include/" LIBRARY_PATH=$(CURDIR)/backend-assets/lib \
|
||||||
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/silero-vad ./backend/go/vad/silero
|
||||||
|
ifneq ($(UPX),)
|
||||||
|
$(UPX) backend-assets/grpc/silero-vad
|
||||||
|
endif
|
||||||
|
|
||||||
backend-assets/grpc/tinydream: sources/go-tiny-dream sources/go-tiny-dream/libtinydream.a backend-assets/grpc
|
backend-assets/grpc/tinydream: sources/go-tiny-dream sources/go-tiny-dream/libtinydream.a backend-assets/grpc
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/go-tiny-dream \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/go-tiny-dream \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/tinydream ./backend/go/image/tinydream
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/tinydream ./backend/go/image/tinydream
|
||||||
|
@ -28,6 +28,8 @@ service Backend {
|
|||||||
rpc Rerank(RerankRequest) returns (RerankResult) {}
|
rpc Rerank(RerankRequest) returns (RerankResult) {}
|
||||||
|
|
||||||
rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
|
rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
|
||||||
|
|
||||||
|
rpc VAD(VADRequest) returns (VADResponse) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Define the empty request
|
// Define the empty request
|
||||||
@ -293,6 +295,19 @@ message TTSRequest {
|
|||||||
optional string language = 5;
|
optional string language = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// VADRequest carries the audio samples on which voice activity detection is
// run.
message VADRequest {
|
||||||
|
repeated float audio = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// VADSegment is one detected speech interval, delimited by its start and end
// positions.
// NOTE(review): the unit of start/end is not stated here (presumably seconds,
// as reported by the detector) - confirm.
message VADSegment {
|
||||||
|
float start = 1;
|
||||||
|
float end = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// VADResponse lists the speech segments detected in the request audio.
message VADResponse {
|
||||||
|
repeated VADSegment segments = 1;
|
||||||
|
}
|
||||||
|
|
||||||
message SoundGenerationRequest {
|
message SoundGenerationRequest {
|
||||||
string text = 1;
|
string text = 1;
|
||||||
string model = 2;
|
string model = 2;
|
||||||
|
21
backend/go/vad/silero/main.go
Normal file
21
backend/go/vad/silero/main.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
|
||||||
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// addr is the -addr flag; main passes its value to grpc.StartServer.
// It defaults to localhost:50051.
var (
|
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &VAD{}); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
54
backend/go/vad/silero/vad.go
Normal file
54
backend/go/vad/silero/vad.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// This is a wrapper to satisfy the gRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
"github.com/streamer45/silero-vad-go/speech"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VAD is the Silero voice activity detection backend. It embeds
// base.SingleThread and holds the speech detector created by Load.
type VAD struct {
|
||||||
|
base.SingleThread
|
||||||
|
detector *speech.Detector
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vad *VAD) Load(opts *pb.ModelOptions) error {
|
||||||
|
v, err := speech.NewDetector(speech.DetectorConfig{
|
||||||
|
ModelPath: opts.ModelFile,
|
||||||
|
SampleRate: 16000,
|
||||||
|
//WindowSize: 1024,
|
||||||
|
Threshold: 0.5,
|
||||||
|
MinSilenceDurationMs: 0,
|
||||||
|
SpeechPadMs: 0,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create silero detector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
vad.detector = v
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vad *VAD) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
|
||||||
|
audio := req.Audio
|
||||||
|
|
||||||
|
segments, err := vad.detector.Detect(audio)
|
||||||
|
if err != nil {
|
||||||
|
return pb.VADResponse{}, fmt.Errorf("detect: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
vadSegments := []*pb.VADSegment{}
|
||||||
|
for _, s := range segments {
|
||||||
|
vadSegments = append(vadSegments, &pb.VADSegment{
|
||||||
|
Start: float32(s.SpeechStartAt),
|
||||||
|
End: float32(s.SpeechEndAt),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return pb.VADResponse{
|
||||||
|
Segments: vadSegments,
|
||||||
|
}, nil
|
||||||
|
}
|
68
core/http/endpoints/localai/vad.go
Normal file
68
core/http/endpoints/localai/vad.go
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
||||||
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VADEndpoint is Voice-Activation-Detection endpoint
|
||||||
|
// @Summary Detect voice fragments in an audio stream
|
||||||
|
// @Accept json
|
||||||
|
// @Param request body schema.VADRequest true "query params"
|
||||||
|
// @Success 200 {object} proto.VADResponse "Response"
|
||||||
|
// @Router /vad [post]
|
||||||
|
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
|
input := new(schema.VADRequest)
|
||||||
|
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.Model
|
||||||
|
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||||
|
config.LoadOptionDebug(appConfig.Debug),
|
||||||
|
config.LoadOptionThreads(appConfig.Threads),
|
||||||
|
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||||
|
config.LoadOptionF16(appConfig.F16),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Err(err)
|
||||||
|
modelFile = input.Model
|
||||||
|
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||||
|
} else {
|
||||||
|
modelFile = cfg.Model
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Request for model: %s", modelFile)
|
||||||
|
|
||||||
|
opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend), model.WithModel(modelFile))
|
||||||
|
|
||||||
|
vadModel, err := ml.Load(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req := proto.VADRequest{
|
||||||
|
Audio: input.Audio,
|
||||||
|
}
|
||||||
|
resp, err := vadModel.VAD(c.Context(), &req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(resp)
|
||||||
|
}
|
||||||
|
}
|
@ -34,6 +34,7 @@ func RegisterLocalAIRoutes(app *fiber.App,
|
|||||||
}
|
}
|
||||||
|
|
||||||
app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
|
app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// Stores
|
// Stores
|
||||||
sl := model.NewModelLoader("")
|
sl := model.NewModelLoader("")
|
||||||
|
@ -30,10 +30,16 @@ type TTSRequest struct {
|
|||||||
Input string `json:"input" yaml:"input"` // text input
|
Input string `json:"input" yaml:"input"` // text input
|
||||||
Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id
|
Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id
|
||||||
Backend string `json:"backend" yaml:"backend"`
|
Backend string `json:"backend" yaml:"backend"`
|
||||||
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
|
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
|
||||||
Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
|
Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// @Description VAD request body
|
||||||
|
type VADRequest struct {
|
||||||
|
Model string `json:"model" yaml:"model"` // model name or full path
|
||||||
|
Audio []float32 `json:"audio" yaml:"audio"` // audio samples to run voice activity detection on
|
||||||
|
}
|
||||||
|
|
||||||
type StoresSet struct {
|
type StoresSet struct {
|
||||||
Store string `json:"store,omitempty" yaml:"store,omitempty"`
|
Store string `json:"store,omitempty" yaml:"store,omitempty"`
|
||||||
|
|
||||||
|
3
go.mod
3
go.mod
@ -86,6 +86,9 @@ require (
|
|||||||
github.com/pion/turn/v2 v2.1.6 // indirect
|
github.com/pion/turn/v2 v2.1.6 // indirect
|
||||||
github.com/pion/webrtc/v3 v3.3.0 // indirect
|
github.com/pion/webrtc/v3 v3.3.0 // indirect
|
||||||
github.com/shirou/gopsutil/v4 v4.24.7 // indirect
|
github.com/shirou/gopsutil/v4 v4.24.7 // indirect
|
||||||
|
github.com/streamer45/silero-vad-go v0.2.1 // indirect
|
||||||
|
github.com/urfave/cli/v2 v2.27.4 // indirect
|
||||||
|
github.com/valyala/fasttemplate v1.2.2 // indirect
|
||||||
github.com/wlynxg/anet v0.0.4 // indirect
|
github.com/wlynxg/anet v0.0.4 // indirect
|
||||||
go.uber.org/mock v0.4.0 // indirect
|
go.uber.org/mock v0.4.0 // indirect
|
||||||
)
|
)
|
||||||
|
5
go.sum
5
go.sum
@ -674,6 +674,11 @@ github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:Udh
|
|||||||
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA=
|
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA=
|
||||||
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
|
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
|
||||||
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||||
|
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
|
||||||
|
github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w=
|
||||||
|
github.com/spf13/cast v1.5.0/go.mod h1:SpXXQ5YoyJw6s3/6cMTQuxvgRl3PCJiyaX9p6b155UU=
|
||||||
|
github.com/streamer45/silero-vad-go v0.2.1 h1:Li1/tTC4H/3cyw6q4weX+U8GWwEL3lTekK/nYa1Cvuk=
|
||||||
|
github.com/streamer45/silero-vad-go v0.2.1/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
|
||||||
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
|
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
|
||||||
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
@ -53,4 +53,6 @@ type Backend interface {
|
|||||||
Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)
|
Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)
|
||||||
|
|
||||||
GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error)
|
GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error)
|
||||||
|
|
||||||
|
VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error)
|
||||||
}
|
}
|
||||||
|
@ -92,6 +92,10 @@ func (llm *Base) StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
|
|||||||
return pb.StoresFindResult{}, fmt.Errorf("unimplemented")
|
return pb.StoresFindResult{}, fmt.Errorf("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (llm *Base) VAD(*pb.VADRequest) (pb.VADResponse, error) {
|
||||||
|
return pb.VADResponse{}, fmt.Errorf("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
func memoryUsage() *pb.MemoryUsageData {
|
func memoryUsage() *pb.MemoryUsageData {
|
||||||
mud := pb.MemoryUsageData{
|
mud := pb.MemoryUsageData{
|
||||||
Breakdown: make(map[string]uint64),
|
Breakdown: make(map[string]uint64),
|
||||||
|
@ -392,3 +392,21 @@ func (c *Client) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opt
|
|||||||
client := pb.NewBackendClient(conn)
|
client := pb.NewBackendClient(conn)
|
||||||
return client.GetMetrics(ctx, in, opts...)
|
return client.GetMetrics(ctx, in, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) {
|
||||||
|
if !c.parallel {
|
||||||
|
c.opMutex.Lock()
|
||||||
|
defer c.opMutex.Unlock()
|
||||||
|
}
|
||||||
|
c.setBusy(true)
|
||||||
|
defer c.setBusy(false)
|
||||||
|
c.wdMark()
|
||||||
|
defer c.wdUnMark()
|
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
client := pb.NewBackendClient(conn)
|
||||||
|
return client.VAD(ctx, in, opts...)
|
||||||
|
}
|
||||||
|
@ -87,6 +87,10 @@ func (e *embedBackend) Rerank(ctx context.Context, in *pb.RerankRequest, opts ..
|
|||||||
return e.s.Rerank(ctx, in)
|
return e.s.Rerank(ctx, in)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *embedBackend) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) {
|
||||||
|
return e.s.VAD(ctx, in)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) {
|
func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) {
|
||||||
return e.s.GetMetrics(ctx, in)
|
return e.s.GetMetrics(ctx, in)
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,8 @@ type LLM interface {
|
|||||||
StoresDelete(*pb.StoresDeleteOptions) error
|
StoresDelete(*pb.StoresDeleteOptions) error
|
||||||
StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error)
|
StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error)
|
||||||
StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
|
StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
|
||||||
|
|
||||||
|
VAD(*pb.VADRequest) (pb.VADResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newReply(s string) *pb.Reply {
|
func newReply(s string) *pb.Reply {
|
||||||
|
@ -227,6 +227,18 @@ func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.
|
|||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *server) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, error) {
|
||||||
|
if s.llm.Locking() {
|
||||||
|
s.llm.Lock()
|
||||||
|
defer s.llm.Unlock()
|
||||||
|
}
|
||||||
|
res, err := s.llm.VAD(in)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &res, nil
|
||||||
|
}
|
||||||
|
|
||||||
func StartServer(address string, model LLM) error {
|
func StartServer(address string, model LLM) error {
|
||||||
lis, err := net.Listen("tcp", address)
|
lis, err := net.Listen("tcp", address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user