chore(refactor): drop duplicated shutdown logic (#3589)

* chore(refactor): drop duplicated shutdown logic

- Handle locking in Shutdown and CheckModelIsLoaded in a more Go-idiomatic way
- Drop duplicated code and reorganize the shutdown code

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix: drop leftover

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: improve logging and add missing locks

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Ettore Di Giacinto authored on 2024-09-17 16:51:40 +02:00, committed by GitHub
commit a53392f919 (parent eee1fb2c75)
5 changed files with 32 additions and 38 deletions

@@ -69,6 +69,6 @@ func RegisterLocalAIRoutes(app *fiber.App,
 		}{Version: internal.PrintableVersion()})
 	})
-	app.Get("/system", auth, localai.SystemInformations(ml, appConfig))
+	app.Get("/system", localai.SystemInformations(ml, appConfig))
 }

pkg/model/filters.go (new file, 17 lines)

@@ -0,0 +1,17 @@
+package model
+
+import (
+	process "github.com/mudler/go-processmanager"
+)
+
+type GRPCProcessFilter = func(id string, p *process.Process) bool
+
+func all(_ string, _ *process.Process) bool {
+	return true
+}
+
+func allExcept(s string) GRPCProcessFilter {
+	return func(id string, p *process.Process) bool {
+		return id != s
+	}
+}
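
To show how the new filter type is meant to be used, here is a minimal, self-contained sketch. The backend ids and the main function are made-up examples for illustration only, not part of this commit.

package main

import (
	"fmt"

	process "github.com/mudler/go-processmanager"
)

// GRPCProcessFilter mirrors the type added in pkg/model/filters.go.
type GRPCProcessFilter = func(id string, p *process.Process) bool

// allExcept reproduces the filter above: report true (meaning "stop it")
// for every id except the one we want to keep running.
func allExcept(s string) GRPCProcessFilter {
	return func(id string, p *process.Process) bool {
		return id != s
	}
}

func main() {
	// Hypothetical backend ids, only for illustration.
	keep := allExcept("llama-cpp")
	for _, id := range []string{"llama-cpp", "whisper", "piper"} {
		fmt.Printf("stop %q? %v\n", id, keep(id, nil))
	}
}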

@@ -320,7 +320,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 	} else {
 		grpcProcess := backendPath(o.assetDir, backend)
 		if err := utils.VerifyPath(grpcProcess, o.assetDir); err != nil {
-			return nil, fmt.Errorf("grpc process not found in assetdir: %s", err.Error())
+			return nil, fmt.Errorf("refering to a backend not in asset dir: %s", err.Error())
 		}

 		if autoDetect {
@@ -332,7 +332,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 		// Check if the file exists
 		if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
-			return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
+			return nil, fmt.Errorf("backend not found: %s", grpcProcess)
 		}

 		serverAddress, err := getFreeAddress()
@@ -355,6 +355,8 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 		client = NewModel(serverAddress)
 	}

+	log.Debug().Msgf("Wait for the service to start up")
+
 	// Wait for the service to start up
 	ready := false
 	for i := 0; i < o.grpcAttempts; i++ {
@@ -413,10 +415,8 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
 	}

 	if o.singleActiveBackend {
-		ml.mu.Lock()
 		log.Debug().Msgf("Stopping all backends except '%s'", o.model)
-		err := ml.StopAllExcept(o.model)
-		ml.mu.Unlock()
+		err := ml.StopGRPC(allExcept(o.model))
 		if err != nil {
 			log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel")
 			return nil, err
@@ -444,13 +444,10 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
 func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
 	o := NewOptions(opts...)

-	ml.mu.Lock()
 	// Return earlier if we have a model already loaded
 	// (avoid looping through all the backends)
 	if m := ml.CheckIsLoaded(o.model); m != nil {
 		log.Debug().Msgf("Model '%s' already loaded", o.model)
-		ml.mu.Unlock()
 		return m.GRPC(o.parallelRequests, ml.wd), nil
 	}
@@ -458,12 +455,11 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
 	// If we can have only one backend active, kill all the others (except external backends)
 	if o.singleActiveBackend {
 		log.Debug().Msgf("Stopping all backends except '%s'", o.model)
-		err := ml.StopAllExcept(o.model)
+		err := ml.StopGRPC(allExcept(o.model))
 		if err != nil {
 			log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel - greedyloader continuing")
 		}
 	}
-	ml.mu.Unlock()

 	var err error
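
The hunk at line 355 above adds a debug log in front of the startup wait loop, which polls the freshly spawned gRPC backend a bounded number of times before giving up. As a rough, self-contained sketch of that wait pattern: the attempt count, interval, and isReady probe below are assumptions for illustration, not the actual backend options.

package main

import (
	"fmt"
	"time"
)

// waitForBackend polls isReady up to `attempts` times, sleeping `interval`
// between tries, and reports whether the backend ever became ready.
func waitForBackend(isReady func() bool, attempts int, interval time.Duration) bool {
	for i := 0; i < attempts; i++ {
		if isReady() {
			return true
		}
		time.Sleep(interval)
	}
	return false
}

func main() {
	start := time.Now()
	// Pretend the backend needs ~30ms to come up.
	ready := waitForBackend(func() bool {
		return time.Since(start) > 30*time.Millisecond
	}, 10, 10*time.Millisecond)
	fmt.Println("backend ready:", ready)
}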

@@ -118,9 +118,6 @@ func (ml *ModelLoader) ListModels() []*Model {
 }

 func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (*Model, error)) (*Model, error) {
-	ml.mu.Lock()
-	defer ml.mu.Unlock()
-
 	// Check if we already have a loaded model
 	if model := ml.CheckIsLoaded(modelName); model != nil {
 		return model, nil
@@ -139,6 +136,8 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
 		return nil, fmt.Errorf("loader didn't return a model")
 	}

+	ml.mu.Lock()
+	defer ml.mu.Unlock()
 	ml.models[modelName] = model

 	return model, nil
@@ -168,6 +167,8 @@ func (ml *ModelLoader) ShutdownModel(modelName string) error {
 }

 func (ml *ModelLoader) CheckIsLoaded(s string) *Model {
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
 	m, ok := ml.models[s]
 	if !ok {
 		return nil
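
The loader changes above move mutex handling inside CheckIsLoaded itself (lock on entry, release with defer), so callers such as LoadModel no longer juggle ml.mu around the check. A minimal self-contained sketch of that idiom; the registry type and its fields are made-up stand-ins, not the real ModelLoader.

package main

import (
	"fmt"
	"sync"
)

// registry is an illustrative stand-in for the loader's shared state.
type registry struct {
	mu     sync.Mutex
	models map[string]string
}

// checkIsLoaded mirrors the pattern above: take the lock on entry and defer
// the unlock, so every return path releases it.
func (r *registry) checkIsLoaded(name string) (string, bool) {
	r.mu.Lock()
	defer r.mu.Unlock()
	m, ok := r.models[name]
	return m, ok
}

func main() {
	r := &registry{models: map[string]string{"llama-cpp": "loaded"}}
	if m, ok := r.checkIsLoaded("llama-cpp"); ok {
		fmt.Println("already loaded:", m)
	}
}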

@@ -9,28 +9,12 @@ import (
 	"strconv"
 	"strings"
 	"syscall"
-	"time"

 	"github.com/hpcloud/tail"
 	process "github.com/mudler/go-processmanager"
 	"github.com/rs/zerolog/log"
 )

-func (ml *ModelLoader) StopAllExcept(s string) error {
-	return ml.StopGRPC(func(id string, p *process.Process) bool {
-		if id == s {
-			return false
-		}
-
-		for ml.models[id].GRPC(false, ml.wd).IsBusy() {
-			log.Debug().Msgf("%s busy. Waiting.", id)
-			time.Sleep(2 * time.Second)
-		}
-		log.Debug().Msgf("[single-backend] Stopping %s", id)
-		return true
-	})
-}
-
 func (ml *ModelLoader) deleteProcess(s string) error {
 	if _, exists := ml.grpcProcesses[s]; exists {
 		if err := ml.grpcProcesses[s].Stop(); err != nil {
@@ -42,17 +26,11 @@ func (ml *ModelLoader) deleteProcess(s string) error {
 	return nil
 }

-type GRPCProcessFilter = func(id string, p *process.Process) bool
-
-func includeAllProcesses(_ string, _ *process.Process) bool {
-	return true
-}
-
 func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) error {
 	var err error = nil
 	for k, p := range ml.grpcProcesses {
 		if filter(k, p) {
-			e := ml.deleteProcess(k)
+			e := ml.ShutdownModel(k)
 			err = errors.Join(err, e)
 		}
 	}
@@ -60,10 +38,12 @@ func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) error {
 }

 func (ml *ModelLoader) StopAllGRPC() error {
-	return ml.StopGRPC(includeAllProcesses)
+	return ml.StopGRPC(all)
 }

 func (ml *ModelLoader) GetGRPCPID(id string) (int, error) {
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
 	p, exists := ml.grpcProcesses[id]
 	if !exists {
 		return -1, fmt.Errorf("no grpc backend found for %s", id)
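
StopGRPC keeps iterating even when an individual backend fails to shut down, collecting every failure with errors.Join instead of aborting on the first error. A small self-contained sketch of that pattern; the process map, ids, and stop functions are invented for illustration.

package main

import (
	"errors"
	"fmt"
)

// stopAll applies filter to each entry and, where it matches, calls the stop
// function, joining any failures into a single error value.
func stopAll(procs map[string]func() error, filter func(string) bool) error {
	var err error
	for id, stop := range procs {
		if filter(id) {
			err = errors.Join(err, stop())
		}
	}
	return err
}

func main() {
	procs := map[string]func() error{
		"whisper": func() error { return nil },
		"piper":   func() error { return fmt.Errorf("piper: stop timed out") },
	}
	// Stop everything; the joined error reports the one backend that failed.
	if err := stopAll(procs, func(string) bool { return true }); err != nil {
		fmt.Println("shutdown finished with errors:", err)
	}
}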