From a53392f91953bf53c77041a8cd25282cd65eb71a Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Tue, 17 Sep 2024 16:51:40 +0200 Subject: [PATCH] chore(refactor): drop duplicated shutdown logics (#3589) * chore(refactor): drop duplicated shutdown logics - Handle locking in Shutdown and CheckModelIsLoaded in a more go-idiomatic way - Drop duplicated code and re-organize shutdown code Signed-off-by: Ettore Di Giacinto * fix: drop leftover Signed-off-by: Ettore Di Giacinto * chore: improve logging and add missing locks Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- core/http/routes/localai.go | 2 +- pkg/model/filters.go | 17 +++++++++++++++++ pkg/model/initializers.go | 16 ++++++---------- pkg/model/loader.go | 7 ++++--- pkg/model/process.go | 28 ++++------------------------ 5 files changed, 32 insertions(+), 38 deletions(-) create mode 100644 pkg/model/filters.go diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 29fef378..247596c0 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -69,6 +69,6 @@ func RegisterLocalAIRoutes(app *fiber.App, }{Version: internal.PrintableVersion()}) }) - app.Get("/system", auth, localai.SystemInformations(ml, appConfig)) + app.Get("/system", localai.SystemInformations(ml, appConfig)) } diff --git a/pkg/model/filters.go b/pkg/model/filters.go new file mode 100644 index 00000000..79b72d5b --- /dev/null +++ b/pkg/model/filters.go @@ -0,0 +1,17 @@ +package model + +import ( + process "github.com/mudler/go-processmanager" +) + +type GRPCProcessFilter = func(id string, p *process.Process) bool + +func all(_ string, _ *process.Process) bool { + return true +} + +func allExcept(s string) GRPCProcessFilter { + return func(id string, p *process.Process) bool { + return id != s + } +} diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 3d2255cc..7099bf33 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -320,7 +320,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string } else { grpcProcess := backendPath(o.assetDir, backend) if err := utils.VerifyPath(grpcProcess, o.assetDir); err != nil { - return nil, fmt.Errorf("grpc process not found in assetdir: %s", err.Error()) + return nil, fmt.Errorf("refering to a backend not in asset dir: %s", err.Error()) } if autoDetect { @@ -332,7 +332,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string // Check if the file exists if _, err := os.Stat(grpcProcess); os.IsNotExist(err) { - return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess) + return nil, fmt.Errorf("backend not found: %s", grpcProcess) } serverAddress, err := getFreeAddress() @@ -355,6 +355,8 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string client = NewModel(serverAddress) } + log.Debug().Msgf("Wait for the service to start up") + // Wait for the service to start up ready := false for i := 0; i < o.grpcAttempts; i++ { @@ -413,10 +415,8 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e } if o.singleActiveBackend { - ml.mu.Lock() log.Debug().Msgf("Stopping all backends except '%s'", o.model) - err := ml.StopAllExcept(o.model) - ml.mu.Unlock() + err := ml.StopGRPC(allExcept(o.model)) if err != nil { log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel") return nil, err @@ -444,13 +444,10 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { o := NewOptions(opts...) - ml.mu.Lock() - // Return earlier if we have a model already loaded // (avoid looping through all the backends) if m := ml.CheckIsLoaded(o.model); m != nil { log.Debug().Msgf("Model '%s' already loaded", o.model) - ml.mu.Unlock() return m.GRPC(o.parallelRequests, ml.wd), nil } @@ -458,12 +455,11 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { // If we can have only one backend active, kill all the others (except external backends) if o.singleActiveBackend { log.Debug().Msgf("Stopping all backends except '%s'", o.model) - err := ml.StopAllExcept(o.model) + err := ml.StopGRPC(allExcept(o.model)) if err != nil { log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel - greedyloader continuing") } } - ml.mu.Unlock() var err error diff --git a/pkg/model/loader.go b/pkg/model/loader.go index b9865f73..f70d2cea 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -118,9 +118,6 @@ func (ml *ModelLoader) ListModels() []*Model { } func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (*Model, error)) (*Model, error) { - ml.mu.Lock() - defer ml.mu.Unlock() - // Check if we already have a loaded model if model := ml.CheckIsLoaded(modelName); model != nil { return model, nil @@ -139,6 +136,8 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) ( return nil, fmt.Errorf("loader didn't return a model") } + ml.mu.Lock() + defer ml.mu.Unlock() ml.models[modelName] = model return model, nil @@ -168,6 +167,8 @@ func (ml *ModelLoader) ShutdownModel(modelName string) error { } func (ml *ModelLoader) CheckIsLoaded(s string) *Model { + ml.mu.Lock() + defer ml.mu.Unlock() m, ok := ml.models[s] if !ok { return nil diff --git a/pkg/model/process.go b/pkg/model/process.go index 50afbb1c..bcd1fccb 100644 --- a/pkg/model/process.go +++ b/pkg/model/process.go @@ -9,28 +9,12 @@ import ( "strconv" "strings" "syscall" - "time" "github.com/hpcloud/tail" process "github.com/mudler/go-processmanager" "github.com/rs/zerolog/log" ) -func (ml *ModelLoader) StopAllExcept(s string) error { - return ml.StopGRPC(func(id string, p *process.Process) bool { - if id == s { - return false - } - - for ml.models[id].GRPC(false, ml.wd).IsBusy() { - log.Debug().Msgf("%s busy. Waiting.", id) - time.Sleep(2 * time.Second) - } - log.Debug().Msgf("[single-backend] Stopping %s", id) - return true - }) -} - func (ml *ModelLoader) deleteProcess(s string) error { if _, exists := ml.grpcProcesses[s]; exists { if err := ml.grpcProcesses[s].Stop(); err != nil { @@ -42,17 +26,11 @@ func (ml *ModelLoader) deleteProcess(s string) error { return nil } -type GRPCProcessFilter = func(id string, p *process.Process) bool - -func includeAllProcesses(_ string, _ *process.Process) bool { - return true -} - func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) error { var err error = nil for k, p := range ml.grpcProcesses { if filter(k, p) { - e := ml.deleteProcess(k) + e := ml.ShutdownModel(k) err = errors.Join(err, e) } } @@ -60,10 +38,12 @@ func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) error { } func (ml *ModelLoader) StopAllGRPC() error { - return ml.StopGRPC(includeAllProcesses) + return ml.StopGRPC(all) } func (ml *ModelLoader) GetGRPCPID(id string) (int, error) { + ml.mu.Lock() + defer ml.mu.Unlock() p, exists := ml.grpcProcesses[id] if !exists { return -1, fmt.Errorf("no grpc backend found for %s", id)