From 10b0e13882892db6df6273f9a555a9c148f6a59f Mon Sep 17 00:00:00 2001 From: Dave Date: Wed, 23 Aug 2023 12:38:37 -0400 Subject: [PATCH] feat: backend monitor shutdown endpoint, process based (#938) This PR adds a new endpoint to the backend monitor section `/backend/shutdown` which terminates the grpc process for the related model. --- api/api.go | 1 + api/localai/backend_monitor.go | 59 +++++++++++++++++++++++----------- pkg/model/loader.go | 10 ++++++ 3 files changed, 51 insertions(+), 19 deletions(-) diff --git a/api/api.go b/api/api.go index 57cf968f..c07077d9 100644 --- a/api/api.go +++ b/api/api.go @@ -218,6 +218,7 @@ func App(opts ...options.AppOption) (*fiber.App, error) { // Experimental Backend Statistics Module backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor)) + app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor)) // models app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) diff --git a/api/localai/backend_monitor.go b/api/localai/backend_monitor.go index f723cddf..e8b53556 100644 --- a/api/localai/backend_monitor.go +++ b/api/localai/backend_monitor.go @@ -92,39 +92,49 @@ func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonit }, nil } +func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) { + input := new(BackendMonitorRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return "", err + } + + config, exists := bm.configLoader.GetConfig(input.Model) + var backendId string + if exists { + backendId = config.Model + } else { + // Last ditch effort: use it raw, see if a backend happens to match. + backendId = input.Model + } + + if !strings.HasSuffix(backendId, ".bin") { + backendId = fmt.Sprintf("%s.bin", backendId) + } + + return backendId, nil +} + func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - input := new(BackendMonitorRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { + + backendId, err := bm.getModelLoaderIDFromCtx(c) + if err != nil { return err } - config, exists := bm.configLoader.GetConfig(input.Model) - var backendId string - if exists { - backendId = config.Model - } else { - // Last ditch effort: use it raw, see if a backend happens to match. - backendId = input.Model - } - - if !strings.HasSuffix(backendId, ".bin") { - backendId = fmt.Sprintf("%s.bin", backendId) - } - client := bm.options.Loader.CheckIsLoaded(backendId) if client == nil { - return fmt.Errorf("backend %s is not currently loaded", input.Model) + return fmt.Errorf("backend %s is not currently loaded", backendId) } status, rpcErr := client.Status(context.TODO()) if rpcErr != nil { - log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", input.Model, rpcErr.Error()) + log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) val, slbErr := bm.SampleLocalBackendProcess(backendId) if slbErr != nil { - return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", input.Model, rpcErr.Error(), slbErr.Error()) + return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error()) } return c.JSON(proto.StatusResponse{ State: proto.StatusResponse_ERROR, @@ -140,3 +150,14 @@ func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error { return c.JSON(status) } } + +func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + backendId, err := bm.getModelLoaderIDFromCtx(c) + if err != nil { + return err + } + + return bm.options.Loader.ShutdownModel(backendId) + } +} diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 8d129e46..e4a4437c 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -128,6 +128,16 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) ( return model, nil } +func (ml *ModelLoader) ShutdownModel(modelName string) error { + ml.mu.Lock() + defer ml.mu.Unlock() + if _, ok := ml.models[modelName]; !ok { + return fmt.Errorf("model %s not found", modelName) + } + + return ml.deleteProcess(modelName) +} + func (ml *ModelLoader) CheckIsLoaded(s string) *grpc.Client { if m, ok := ml.models[s]; ok { log.Debug().Msgf("Model already loaded in memory: %s", s)