feat: add --single-active-backend to allow only one backend active at a time (#925)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
parent 1079b18ff7
commit afdc0ebfd7
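
In short: the new --single-active-backend flag (also settable through the SINGLE_ACTIVE_BACKEND environment variable) makes the model loader stop every other running gRPC backend before it loads a model, so at most one backend stays resident at a time. The Go sketch below illustrates that behavior with a stubbed loader; the names and types are illustrative only, not LocalAI's actual ModelLoader API.

package main

import (
	"fmt"
	"sync"
)

// Loader is a stand-in for LocalAI's ModelLoader: it tracks running
// backends by model name and, when singleActive is set, stops every
// other backend before (re)loading one.
type Loader struct {
	mu           sync.Mutex
	backends     map[string]string // model name -> backend address (illustrative)
	singleActive bool
}

// stopAllExcept drops every backend except the one we are about to use.
func (l *Loader) stopAllExcept(keep string) {
	for id := range l.backends {
		if id != keep {
			fmt.Printf("[single-backend] Stopping %s\n", id)
			delete(l.backends, id)
		}
	}
}

// Load registers a backend for a model, enforcing the single-active rule.
func (l *Loader) Load(model, addr string) {
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.singleActive {
		l.stopAllExcept(model)
	}
	l.backends[model] = addr
}

func main() {
	l := &Loader{backends: map[string]string{}, singleActive: true}
	l.Load("llama", "127.0.0.1:50051")
	l.Load("whisper", "127.0.0.1:50052") // llama is stopped first
	fmt.Println(len(l.backends), "backend(s) active")
}

In the actual change the same idea is split across the new modelOpts helper (api/backend), WithSingleActiveBackend (pkg/model) and StopAllExcept (pkg/model/process.go), as the hunks below show.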
@@ -21,25 +21,13 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
     var inferenceModel interface{}
     var err error

-    opts := []model.Option{
+    opts := modelOpts(c, o, []model.Option{
         model.WithLoadGRPCLoadModelOpts(grpcOpts),
         model.WithThreads(uint32(c.Threads)),
         model.WithAssetDir(o.AssetsDestination),
         model.WithModel(modelFile),
         model.WithContext(o.Context),
-    }
-
-    if c.GRPC.Attempts != 0 {
-        opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
-    }
-
-    if c.GRPC.AttemptsSleepTime != 0 {
-        opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
-    }
-
-    for k, v := range o.ExternalGRPCBackends {
-        opts = append(opts, model.WithExternalBackend(k, v))
-    }
+    })

     if c.Backend == "" {
         inferenceModel, err = loader.GreedyLoader(opts...)
@@ -9,7 +9,7 @@ import (
 )

 func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
-    opts := []model.Option{
+    opts := modelOpts(c, o, []model.Option{
         model.WithBackendString(c.Backend),
         model.WithAssetDir(o.AssetsDestination),
         model.WithThreads(uint32(c.Threads)),
@@ -25,19 +25,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
             CLIPSubfolder: c.Diffusers.ClipSubFolder,
             CLIPSkip:      int32(c.Diffusers.ClipSkip),
         }),
-    }
-
-    if c.GRPC.Attempts != 0 {
-        opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
-    }
-
-    if c.GRPC.AttemptsSleepTime != 0 {
-        opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
-    }
-
-    for k, v := range o.ExternalGRPCBackends {
-        opts = append(opts, model.WithExternalBackend(k, v))
-    }
+    })

     inferenceModel, err := loader.BackendLoader(
         opts...,
@@ -33,25 +33,13 @@ func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c
     var inferenceModel *grpc.Client
     var err error

-    opts := []model.Option{
+    opts := modelOpts(c, o, []model.Option{
         model.WithLoadGRPCLoadModelOpts(grpcOpts),
         model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
         model.WithAssetDir(o.AssetsDestination),
         model.WithModel(modelFile),
         model.WithContext(o.Context),
-    }
-
-    if c.GRPC.Attempts != 0 {
-        opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
-    }
-
-    if c.GRPC.AttemptsSleepTime != 0 {
-        opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
-    }
-
-    for k, v := range o.ExternalGRPCBackends {
-        opts = append(opts, model.WithExternalBackend(k, v))
-    }
+    })

     if c.Backend != "" {
         opts = append(opts, model.WithBackendString(c.Backend))
@@ -5,10 +5,32 @@ import (
     "path/filepath"

     pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
     model "github.com/go-skynet/LocalAI/pkg/model"

+    config "github.com/go-skynet/LocalAI/api/config"
+    "github.com/go-skynet/LocalAI/api/options"
 )

+func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option {
+    if o.SingleBackend {
+        opts = append(opts, model.WithSingleActiveBackend())
+    }
+
+    if c.GRPC.Attempts != 0 {
+        opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
+    }
+
+    if c.GRPC.AttemptsSleepTime != 0 {
+        opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
+    }
+
+    for k, v := range o.ExternalGRPCBackends {
+        opts = append(opts, model.WithExternalBackend(k, v))
+    }
+
+    return opts
+}
+
 func gRPCModelOpts(c config.Config) *pb.ModelOptions {
     b := 512
     if c.Batch != 0 {
@@ -13,24 +13,14 @@ import (
 )

 func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) {
-    opts := []model.Option{
+
+    opts := modelOpts(c, o, []model.Option{
         model.WithBackendString(model.WhisperBackend),
         model.WithModel(c.Model),
         model.WithContext(o.Context),
         model.WithThreads(uint32(c.Threads)),
         model.WithAssetDir(o.AssetsDestination),
-    }
-
-    if c.GRPC.Attempts != 0 {
-        opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
-    }
-
-    if c.GRPC.AttemptsSleepTime != 0 {
-        opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
-    }
-    for k, v := range o.ExternalGRPCBackends {
-        opts = append(opts, model.WithExternalBackend(k, v))
-    }
+    })

     whisperModel, err := o.Loader.BackendLoader(opts...)
     if err != nil {
@@ -6,6 +6,7 @@ import (
     "os"
     "path/filepath"

+    api_config "github.com/go-skynet/LocalAI/api/config"
     "github.com/go-skynet/LocalAI/api/options"
     "github.com/go-skynet/LocalAI/pkg/grpc/proto"
     model "github.com/go-skynet/LocalAI/pkg/model"
@@ -33,17 +34,12 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *opt
     if bb == "" {
         bb = model.PiperBackend
     }
-    opts := []model.Option{
+    opts := modelOpts(api_config.Config{}, o, []model.Option{
         model.WithBackendString(bb),
         model.WithModel(modelFile),
         model.WithContext(o.Context),
         model.WithAssetDir(o.AssetsDestination),
-    }
-
-    for k, v := range o.ExternalGRPCBackends {
-        opts = append(opts, model.WithExternalBackend(k, v))
-    }
-
+    })
     piperModel, err := o.Loader.BackendLoader(opts...)
     if err != nil {
         return "", nil, err
@@ -33,6 +33,8 @@ type Option struct {
     ExternalGRPCBackends map[string]string

     AutoloadGalleries bool
+
+    SingleBackend bool
 }

 type AppOption func(*Option)
@@ -58,6 +60,10 @@ func WithCors(b bool) AppOption {
     }
 }

+var EnableSingleBackend = func(o *Option) {
+    o.SingleBackend = true
+}
+
 var EnableGalleriesAutoload = func(o *Option) {
     o.AutoloadGalleries = true
 }
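
EnableSingleBackend follows the same functional-option style as the existing AppOption values: an option is just a function that flips a field on the shared Option struct when applied. A minimal, self-contained sketch of that pattern (the apply helper here is assumed for illustration; it is not part of this diff):

package main

import "fmt"

// Option mirrors the shape of the API options struct: plain fields
// toggled by option functions.
type Option struct {
	SingleBackend     bool
	AutoloadGalleries bool
}

// AppOption is a functional option over *Option.
type AppOption func(*Option)

var EnableSingleBackend = func(o *Option) { o.SingleBackend = true }

// apply is an assumed helper: fold every AppOption into one Option value.
func apply(opts ...AppOption) *Option {
	o := &Option{}
	for _, fn := range opts {
		fn(o)
	}
	return o
}

func main() {
	o := apply(EnableSingleBackend)
	fmt.Println("single backend enabled:", o.SingleBackend)
}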
@@ -77,7 +77,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):


 def serve(address):
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
     backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
     server.add_insecure_port(address)
     server.start()
@@ -51,7 +51,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         return backend_pb2.Result(success=True)

 def serve(address):
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
     backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
     server.add_insecure_port(address)
     server.start()
@@ -267,7 +267,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         return backend_pb2.Result(message="Model loaded successfully", success=True)

 def serve(address):
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
     backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
     server.add_insecure_port(address)
     server.start()
@@ -110,7 +110,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):


 def serve(address):
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
     backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
     server.add_insecure_port(address)
     server.start()
@@ -34,7 +34,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):


 def serve(address):
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
     backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
     server.add_insecure_port(address)
     server.start()
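The five hunks above apply the same change to each of the Python gRPC backends: with futures.ThreadPoolExecutor(max_workers=1) a backend process serves one RPC at a time, which matches the busy tracking added to the Go gRPC client further down.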
go.mod (4 changes)
@@ -18,7 +18,7 @@ require (
     github.com/json-iterator/go v1.1.12
     github.com/mholt/archiver/v3 v3.5.1
     github.com/mudler/go-ggllm.cpp v0.0.0-20230709223052-862477d16eef
-    github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d
+    github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c
     github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af
     github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230815171941-a63093554fb5
     github.com/onsi/ginkgo/v2 v2.11.0
@@ -40,7 +40,7 @@ require (
     github.com/go-ole/go-ole v1.2.6 // indirect
     github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
     github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
-    github.com/shirou/gopsutil/v3 v3.23.6
+    github.com/shirou/gopsutil/v3 v3.23.7
     github.com/shoenig/go-m1cpu v0.1.6 // indirect
     github.com/tklauser/go-sysconf v0.3.11 // indirect
     github.com/tklauser/numcpus v0.6.0 // indirect
go.sum (4 changes)
@@ -119,6 +119,8 @@ github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm
 github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
 github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d h1:/lAg9vPAAU+s35cDMCx1IyeMn+4OYfCBPqi08Q8vXDg=
 github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d/go.mod h1:HGGAOJhipApckwNV8ZTliRJqxctUv3xRY+zbQEwuytc=
+github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c h1:CI5uGwqBpN8N7BrSKC+nmdfw+9nPQIDyjHHlaIiitZI=
+github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c/go.mod h1:gY3wyrhkRySJtmtI/JPt4a2mKv48h/M9pEZIW+SjeC0=
 github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks=
 github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw=
 github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230815171941-a63093554fb5 h1:b4EeYDaGxOLNlNm5LOVEmrUhaw1v6xq/V79ZwWVlY6I=
@@ -164,6 +166,8 @@ github.com/sashabaranov/go-openai v1.14.2 h1:5DPTtR9JBjKPJS008/A409I5ntFhUPPGCma
 github.com/sashabaranov/go-openai v1.14.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
 github.com/shirou/gopsutil/v3 v3.23.6 h1:5y46WPI9QBKBbK7EEccUPNXpJpNrvPuTD0O2zHEHT08=
 github.com/shirou/gopsutil/v3 v3.23.6/go.mod h1:j7QX50DrXYggrpN30W0Mo+I4/8U2UUIQrnrhqUeWrAU=
+github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4=
+github.com/shirou/gopsutil/v3 v3.23.7/go.mod h1:c4gnmoRC0hQuaLqvxnx1//VXQ0Ms/X9UnJF8pddY5z4=
 github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
 github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
 github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
main.go (9 changes)
@@ -49,6 +49,11 @@ func main() {
             Name:    "debug",
             EnvVars: []string{"DEBUG"},
         },
+        &cli.BoolFlag{
+            Name:    "single-active-backend",
+            EnvVars: []string{"SINGLE_ACTIVE_BACKEND"},
+            Usage:   "Allow only one backend to be running.",
+        },
         &cli.BoolFlag{
             Name:    "cors",
             EnvVars: []string{"CORS"},
@@ -181,6 +186,10 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
                 options.WithApiKeys(ctx.StringSlice("api-keys")),
             }

+            if ctx.Bool("single-active-backend") {
+                opts = append(opts, options.EnableSingleBackend)
+            }
+
             externalgRPC := ctx.StringSlice("external-grpc-backends")
             // split ":" to get backend name and the uri
             for _, v := range externalgRPC {
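
The flag itself is a plain urfave/cli boolean, so it can be passed on the command line (--single-active-backend) or through the SINGLE_ACTIVE_BACKEND environment variable. A minimal stand-alone sketch of declaring and reading such a flag (a toy program, not LocalAI's full main.go):

package main

import (
	"fmt"
	"log"
	"os"

	"github.com/urfave/cli/v2"
)

func main() {
	app := &cli.App{
		Flags: []cli.Flag{
			&cli.BoolFlag{
				Name:    "single-active-backend",
				EnvVars: []string{"SINGLE_ACTIVE_BACKEND"},
				Usage:   "Allow only one backend to be running.",
			},
		},
		Action: func(ctx *cli.Context) error {
			// In LocalAI this appends options.EnableSingleBackend to the
			// startup options; here we just report the value.
			fmt.Println("single-active-backend:", ctx.Bool("single-active-backend"))
			return nil
		},
	}
	if err := app.Run(os.Args); err != nil {
		log.Fatal(err)
	}
}

Running it as SINGLE_ACTIVE_BACKEND=true ./app prints the same result as passing --single-active-backend on the command line.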
@@ -4,6 +4,7 @@ import (
     "context"
     "fmt"
     "io"
+    "sync"
     "time"

     pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
@@ -14,6 +15,8 @@ import (

 type Client struct {
     address string
+    busy    bool
+    sync.Mutex
 }

 func NewClient(address string) *Client {
@@ -22,7 +25,21 @@ func NewClient(address string) *Client {
     }
 }

+func (c *Client) IsBusy() bool {
+    c.Lock()
+    defer c.Unlock()
+    return c.busy
+}
+
+func (c *Client) setBusy(v bool) {
+    c.Lock()
+    c.busy = v
+    c.Unlock()
+}
+
 func (c *Client) HealthCheck(ctx context.Context) bool {
+    c.setBusy(true)
+    defer c.setBusy(false)
     conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
     if err != nil {
         fmt.Println(err)
@@ -49,6 +66,8 @@ func (c *Client) HealthCheck(ctx context.Context) bool {
 }

 func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) {
+    c.setBusy(true)
+    defer c.setBusy(false)
     conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
     if err != nil {
         return nil, err
@@ -60,6 +79,8 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...
 }

 func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) {
+    c.setBusy(true)
+    defer c.setBusy(false)
     conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
     if err != nil {
         return nil, err
@@ -71,6 +92,8 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp
 }

 func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) {
+    c.setBusy(true)
+    defer c.setBusy(false)
     conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
     if err != nil {
         return nil, err
@@ -81,6 +104,8 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
 }

 func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
+    c.setBusy(true)
+    defer c.setBusy(false)
     conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
     if err != nil {
         return err
@@ -110,6 +135,8 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
 }

 func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) {
+    c.setBusy(true)
+    defer c.setBusy(false)
     conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
     if err != nil {
         return nil, err
@@ -120,6 +147,8 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
 }

 func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
+    c.setBusy(true)
+    defer c.setBusy(false)
     conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
     if err != nil {
         return nil, err
@@ -130,6 +159,8 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
 }

 func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*api.Result, error) {
+    c.setBusy(true)
+    defer c.setBusy(false)
     conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
     if err != nil {
         return nil, err
@@ -160,6 +191,8 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
 }

 func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
+    c.setBusy(true)
+    defer c.setBusy(false)
     conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
     if err != nil {
         return nil, err
@@ -176,6 +209,8 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts
 }

 func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) {
+    c.setBusy(true)
+    defer c.setBusy(false)
     conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
     if err != nil {
         return nil, err
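
Each RPC wrapper on the client now brackets its work with setBusy(true) / defer setBusy(false), and IsBusy reads the flag under the embedded mutex; the loader (pkg/model/process.go below) polls IsBusy before it stops a backend. A small, self-contained sketch of that bracket pattern:

package main

import (
	"fmt"
	"sync"
	"time"
)

// Client mimics the busy tracking added to the gRPC client: an embedded
// mutex guards a single busy flag.
type Client struct {
	busy bool
	sync.Mutex
}

func (c *Client) IsBusy() bool {
	c.Lock()
	defer c.Unlock()
	return c.busy
}

func (c *Client) setBusy(v bool) {
	c.Lock()
	c.busy = v
	c.Unlock()
}

// Predict stands in for any RPC wrapper: mark busy for the whole call.
func (c *Client) Predict() string {
	c.setBusy(true)
	defer c.setBusy(false)
	time.Sleep(100 * time.Millisecond) // pretend to do the RPC
	return "ok"
}

func main() {
	c := &Client{}
	go c.Predict()
	time.Sleep(10 * time.Millisecond)
	fmt.Println("busy while predicting:", c.IsBusy())
}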
@@ -4,20 +4,14 @@ import (
     "context"
     "fmt"
     "os"
-    "os/signal"
     "path/filepath"
-    "strconv"
     "strings"
-    "syscall"
     "time"

     grpc "github.com/go-skynet/LocalAI/pkg/grpc"
     "github.com/hashicorp/go-multierror"
-    "github.com/hpcloud/tail"
     "github.com/phayes/freeport"
     "github.com/rs/zerolog/log"
-
-    process "github.com/mudler/go-processmanager"
 )

 const (
@@ -65,89 +59,6 @@ var AutoLoadBackends []string = []string{
     PiperBackend,
 }

-func (ml *ModelLoader) GetGRPCPID(id string) (int, error) {
-    p, exists := ml.grpcProcesses[id]
-    if !exists {
-        return -1, fmt.Errorf("no grpc backend found for %s", id)
-    }
-    return strconv.Atoi(p.PID)
-}
-
-type GRPCProcessFilter = func(p *process.Process) bool
-
-func includeAllProcesses(_ *process.Process) bool {
-    return true
-}
-
-func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) {
-    for _, p := range ml.grpcProcesses {
-        if filter(p) {
-            p.Stop()
-        }
-    }
-}
-
-func (ml *ModelLoader) StopAllGRPC() {
-    ml.StopGRPC(includeAllProcesses)
-    // for _, p := range ml.grpcProcesses {
-    //     p.Stop()
-    // }
-}
-
-func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
-    // Make sure the process is executable
-    if err := os.Chmod(grpcProcess, 0755); err != nil {
-        return err
-    }
-
-    log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess)
-
-    log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)
-
-    grpcControlProcess := process.New(
-        process.WithTemporaryStateDir(),
-        process.WithName(grpcProcess),
-        process.WithArgs("--addr", serverAddress),
-        process.WithEnvironment(os.Environ()...),
-    )
-
-    ml.grpcProcesses[id] = grpcControlProcess
-
-    if err := grpcControlProcess.Run(); err != nil {
-        return err
-    }
-
-    log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir())
-    // clean up process
-    go func() {
-        c := make(chan os.Signal, 1)
-        signal.Notify(c, os.Interrupt, syscall.SIGTERM)
-        <-c
-        grpcControlProcess.Stop()
-    }()
-
-    go func() {
-        t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
-        if err != nil {
-            log.Debug().Msgf("Could not tail stderr")
-        }
-        for line := range t.Lines {
-            log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
-        }
-    }()
-    go func() {
-        t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true})
-        if err != nil {
-            log.Debug().Msgf("Could not tail stdout")
-        }
-        for line := range t.Lines {
-            log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
-        }
-    }()
-
-    return nil
-}
-
 // starts the grpcModelProcess for the backend, and returns a grpc client
 // It also loads the model
 func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (*grpc.Client, error) {
@@ -248,6 +159,13 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er

     backend := strings.ToLower(o.backendString)

+    if o.singleActiveBackend {
+        ml.mu.Lock()
+        log.Debug().Msgf("Stopping all backends except '%s'", o.model)
+        ml.StopAllExcept(o.model)
+        ml.mu.Unlock()
+    }
+
     // if an external backend is provided, use it
     _, externalBackendExists := o.externalBackends[backend]
     if externalBackendExists {
@@ -274,14 +192,21 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
 func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
     o := NewOptions(opts...)

-    // Is this really needed? BackendLoader already does this
     ml.mu.Lock()
+    // Return earlier if we have a model already loaded
+    // (avoid looping through all the backends)
     if m := ml.CheckIsLoaded(o.model); m != nil {
         log.Debug().Msgf("Model '%s' already loaded", o.model)
         ml.mu.Unlock()
         return m, nil
     }
+    // If we can have only one backend active, kill all the others (except external backends)
+    if o.singleActiveBackend {
+        log.Debug().Msgf("Stopping all backends except '%s'", o.model)
+        ml.StopAllExcept(o.model)
+    }
     ml.mu.Unlock()
+
     var err error

     // autoload also external backends
@@ -137,9 +137,7 @@ func (ml *ModelLoader) CheckIsLoaded(s string) *grpc.Client {
         if !ml.grpcProcesses[s].IsAlive() {
             log.Debug().Msgf("GRPC Process is not responding: %s", s)
             // stop and delete the process, this forces to re-load the model and re-create again the service
-            ml.grpcProcesses[s].Stop()
-            delete(ml.grpcProcesses, s)
-            delete(ml.models, s)
+            ml.deleteProcess(s)
             return nil
         }
     }
@@ -19,6 +19,7 @@ type Options struct {

     grpcAttempts        int
     grpcAttemptsDelay   int
+    singleActiveBackend bool
 }

 type Option func(*Options)
@@ -80,6 +81,12 @@ func WithContext(ctx context.Context) Option {
     }
 }

+func WithSingleActiveBackend() Option {
+    return func(o *Options) {
+        o.singleActiveBackend = true
+    }
+}
+
 func NewOptions(opts ...Option) *Options {
     o := &Options{
         gRPCOptions: &pb.ModelOptions{},
pkg/model/process.go (new file, 118 lines)
@@ -0,0 +1,118 @@
+package model
+
+import (
+    "fmt"
+    "os"
+    "os/signal"
+    "strconv"
+    "strings"
+    "syscall"
+    "time"
+
+    "github.com/hpcloud/tail"
+    process "github.com/mudler/go-processmanager"
+    "github.com/rs/zerolog/log"
+)
+
+func (ml *ModelLoader) StopAllExcept(s string) {
+    ml.StopGRPC(func(id string, p *process.Process) bool {
+        if id != s {
+            for ml.models[id].IsBusy() {
+                log.Debug().Msgf("%s busy. Waiting.", id)
+                time.Sleep(2 * time.Second)
+            }
+            log.Debug().Msgf("[single-backend] Stopping %s", id)
+            return true
+        }
+        return false
+    })
+}
+
+func (ml *ModelLoader) deleteProcess(s string) error {
+    if err := ml.grpcProcesses[s].Stop(); err != nil {
+        return err
+    }
+    delete(ml.grpcProcesses, s)
+    delete(ml.models, s)
+    return nil
+}
+
+type GRPCProcessFilter = func(id string, p *process.Process) bool
+
+func includeAllProcesses(_ string, _ *process.Process) bool {
+    return true
+}
+
+func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) {
+    for k, p := range ml.grpcProcesses {
+        if filter(k, p) {
+            ml.deleteProcess(k)
+        }
+    }
+}
+
+func (ml *ModelLoader) StopAllGRPC() {
+    ml.StopGRPC(includeAllProcesses)
+}
+
+func (ml *ModelLoader) GetGRPCPID(id string) (int, error) {
+    p, exists := ml.grpcProcesses[id]
+    if !exists {
+        return -1, fmt.Errorf("no grpc backend found for %s", id)
+    }
+    return strconv.Atoi(p.PID)
+}
+
+func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
+    // Make sure the process is executable
+    if err := os.Chmod(grpcProcess, 0755); err != nil {
+        return err
+    }
+
+    log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess)
+
+    log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)
+
+    grpcControlProcess := process.New(
+        process.WithTemporaryStateDir(),
+        process.WithName(grpcProcess),
+        process.WithArgs("--addr", serverAddress),
+        process.WithEnvironment(os.Environ()...),
+    )
+
+    ml.grpcProcesses[id] = grpcControlProcess
+
+    if err := grpcControlProcess.Run(); err != nil {
+        return err
+    }
+
+    log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir())
+    // clean up process
+    go func() {
+        c := make(chan os.Signal, 1)
+        signal.Notify(c, os.Interrupt, syscall.SIGTERM)
+        <-c
+        grpcControlProcess.Stop()
+    }()
+
+    go func() {
+        t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
+        if err != nil {
+            log.Debug().Msgf("Could not tail stderr")
+        }
+        for line := range t.Lines {
+            log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
+        }
+    }()
+    go func() {
+        t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true})
+        if err != nil {
+            log.Debug().Msgf("Could not tail stdout")
+        }
+        for line := range t.Lines {
+            log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
+        }
+    }()
+
+    return nil
+}