diff --git a/api/localai/backend_monitor.go b/api/localai/backend_monitor.go index e841c5f8..3aea8a84 100644 --- a/api/localai/backend_monitor.go +++ b/api/localai/backend_monitor.go @@ -123,13 +123,12 @@ func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error { return err } - client := bm.options.Loader.CheckIsLoaded(backendId) - - if client == "" { + model := bm.options.Loader.CheckIsLoaded(backendId) + if model == "" { return fmt.Errorf("backend %s is not currently loaded", backendId) } - status, rpcErr := client.GRPC().Status(context.TODO()) + status, rpcErr := model.GRPC(false).Status(context.TODO()) if rpcErr != nil { log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) val, slbErr := bm.SampleLocalBackendProcess(backendId) diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 0697ac69..673e2a54 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -14,14 +14,17 @@ import ( ) type Client struct { - address string - busy bool + address string + busy bool + parallel bool sync.Mutex + opMutex sync.Mutex } -func NewClient(address string) *Client { +func NewClient(address string, parallel bool) *Client { return &Client{ - address: address, + address: address, + parallel: parallel, } } @@ -38,6 +41,10 @@ func (c *Client) setBusy(v bool) { } func (c *Client) HealthCheck(ctx context.Context) bool { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } c.setBusy(true) defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -66,6 +73,10 @@ func (c *Client) HealthCheck(ctx context.Context) bool { } func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } c.setBusy(true) defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -79,6 +90,10 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ... } func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } c.setBusy(true) defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -92,6 +107,10 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp } func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } c.setBusy(true) defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -104,6 +123,10 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp } func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } c.setBusy(true) defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -135,6 +158,10 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun } func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } c.setBusy(true) defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -147,6 +174,10 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, } func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } c.setBusy(true) defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -159,6 +190,10 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp } func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } c.setBusy(true) defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -191,6 +226,10 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques } func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } c.setBusy(true) defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -209,6 +248,10 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts } func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } c.setBusy(true) defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index c303e64d..22a18ed6 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -121,7 +121,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string // Wait for the service to start up ready := false for i := 0; i < o.grpcAttempts; i++ { - if client.GRPC().HealthCheck(context.Background()) { + if client.GRPC(o.parallelRequests).HealthCheck(context.Background()) { log.Debug().Msgf("GRPC Service Ready") ready = true break @@ -140,7 +140,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string log.Debug().Msgf("GRPC: Loading model with options: %+v", options) - res, err := client.GRPC().LoadModel(o.context, &options) + res, err := client.GRPC(o.parallelRequests).LoadModel(o.context, &options) if err != nil { return "", fmt.Errorf("could not load model: %w", err) } @@ -154,11 +154,11 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.Client, error) { if parallel { - return addr.GRPC(), nil + return addr.GRPC(parallel), nil } if _, ok := ml.grpcClients[string(addr)]; !ok { - ml.grpcClients[string(addr)] = addr.GRPC() + ml.grpcClients[string(addr)] = addr.GRPC(parallel) } return ml.grpcClients[string(addr)], nil } diff --git a/pkg/model/loader.go b/pkg/model/loader.go index c9471f1c..60671301 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -67,8 +67,8 @@ type ModelLoader struct { type ModelAddress string -func (m ModelAddress) GRPC() *grpc.Client { - return grpc.NewClient(string(m)) +func (m ModelAddress) GRPC(parallel bool) *grpc.Client { + return grpc.NewClient(string(m), parallel) } func NewModelLoader(modelPath string) *ModelLoader { @@ -147,10 +147,16 @@ func (ml *ModelLoader) ShutdownModel(modelName string) error { } func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress { + var client *grpc.Client if m, ok := ml.models[s]; ok { log.Debug().Msgf("Model already loaded in memory: %s", s) + if c, ok := ml.grpcClients[s]; ok { + client = c + } else { + client = m.GRPC(false) + } - if !m.GRPC().HealthCheck(context.Background()) { + if !client.HealthCheck(context.Background()) { log.Debug().Msgf("GRPC Model not responding: %s", s) if !ml.grpcProcesses[s].IsAlive() { log.Debug().Msgf("GRPC Process is not responding: %s", s) diff --git a/pkg/model/process.go b/pkg/model/process.go index 7048499d..18f44a66 100644 --- a/pkg/model/process.go +++ b/pkg/model/process.go @@ -17,7 +17,7 @@ import ( func (ml *ModelLoader) StopAllExcept(s string) { ml.StopGRPC(func(id string, p *process.Process) bool { if id != s { - for ml.models[id].GRPC().IsBusy() { + for ml.models[id].GRPC(false).IsBusy() { log.Debug().Msgf("%s busy. Waiting.", id) time.Sleep(2 * time.Second) }