chore(refactor): drop unnecessary code in loader (#4096)

* chore: simplify passing options to ModelOptions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore(refactor): do not expose internal backend Loader

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-11-08 21:54:25 +01:00 committed by GitHub
parent a0cdd19038
commit 6daef00d30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 39 additions and 73 deletions

View File

@ -11,17 +11,9 @@ import (
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
var inferenceModel interface{} opts := ModelOptions(backendConfig, appConfig)
var err error
opts := ModelOptions(backendConfig, appConfig, []model.Option{}) inferenceModel, err := loader.Load(opts...)
if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -9,9 +9,8 @@ import (
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
opts := ModelOptions(backendConfig, appConfig, []model.Option{}) opts := ModelOptions(backendConfig, appConfig)
inferenceModel, err := loader.Load(
inferenceModel, err := loader.BackendLoader(
opts..., opts...,
) )
if err != nil { if err != nil {

View File

@ -16,7 +16,6 @@ import (
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model" model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/pkg/utils"
@ -35,15 +34,6 @@ type TokenUsage struct {
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model modelFile := c.Model
var inferenceModel grpc.Backend
var err error
opts := ModelOptions(c, o, []model.Option{})
if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
}
// Check if the modelFile exists, if it doesn't try to load it from the gallery // Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental if o.AutoloadGalleries { // experimental
if _, err := os.Stat(modelFile); os.IsNotExist(err) { if _, err := os.Stat(modelFile); os.IsNotExist(err) {
@ -56,12 +46,8 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
} }
} }
if c.Backend == "" { opts := ModelOptions(c, o)
inferenceModel, err = loader.GreedyLoader(opts...) inferenceModel, err := loader.Load(opts...)
} else {
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -11,7 +11,7 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option { func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
name := c.Name name := c.Name
if name == "" { if name == "" {
name = c.Model name = c.Model

View File

@ -11,8 +11,8 @@ import (
func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) { func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)}) opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
rerankModel, err := loader.BackendLoader(opts...) rerankModel, err := loader.Load(opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -25,9 +25,8 @@ func SoundGeneration(
backendConfig config.BackendConfig, backendConfig config.BackendConfig,
) (string, *proto.Result, error) { ) (string, *proto.Result, error) {
opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)}) opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
soundGenModel, err := loader.Load(opts...)
soundGenModel, err := loader.BackendLoader(opts...)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }

View File

@ -8,16 +8,15 @@ import (
) )
func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) (grpc.Backend, error) { func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) (grpc.Backend, error) {
if storeName == "" { if storeName == "" {
storeName = "default" storeName = "default"
} }
sc := []model.Option{ sc := []model.Option{
model.WithBackendString(model.LocalStoreBackend), model.WithBackendString(model.LocalStoreBackend),
model.WithAssetDir(appConfig.AssetsDestination), model.WithAssetDir(appConfig.AssetsDestination),
model.WithModel(storeName), model.WithModel(storeName),
} }
return sl.BackendLoader(sc...) return sl.Load(sc...)
} }

View File

@ -15,10 +15,8 @@ func TokenMetrics(
appConfig *config.ApplicationConfig, appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig) (*proto.MetricsResponse, error) { backendConfig config.BackendConfig) (*proto.MetricsResponse, error) {
opts := ModelOptions(backendConfig, appConfig, []model.Option{ opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
model.WithModel(modelFile), model, err := loader.Load(opts...)
})
model, err := loader.BackendLoader(opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -14,15 +14,13 @@ func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.Bac
var inferenceModel grpc.Backend var inferenceModel grpc.Backend
var err error var err error
opts := ModelOptions(backendConfig, appConfig, []model.Option{ opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
model.WithModel(modelFile),
})
if backendConfig.Backend == "" { if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...) inferenceModel, err = loader.Load(opts...)
} else { } else {
opts = append(opts, model.WithBackendString(backendConfig.Backend)) opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...) inferenceModel, err = loader.Load(opts...)
} }
if err != nil { if err != nil {
return schema.TokenizeResponse{}, err return schema.TokenizeResponse{}, err

View File

@ -18,9 +18,9 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
backendConfig.Backend = model.WhisperBackend backendConfig.Backend = model.WhisperBackend
} }
opts := ModelOptions(backendConfig, appConfig, []model.Option{}) opts := ModelOptions(backendConfig, appConfig)
transcriptionModel, err := ml.BackendLoader(opts...) transcriptionModel, err := ml.Load(opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -28,12 +28,8 @@ func ModelTTS(
bb = model.PiperBackend bb = model.PiperBackend
} }
opts := ModelOptions(backendConfig, appConfig, []model.Option{ opts := ModelOptions(backendConfig, appConfig, model.WithBackendString(bb), model.WithModel(modelFile))
model.WithBackendString(bb), ttsModel, err := loader.Load(opts...)
model.WithModel(modelFile),
})
ttsModel, err := loader.BackendLoader(opts...)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }

View File

@ -160,15 +160,10 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model) log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model)
o := backend.ModelOptions(*cfg, options, []model.Option{}) o := backend.ModelOptions(*cfg, options)
var backendErr error var backendErr error
if cfg.Backend != "" { _, backendErr = ml.Load(o...)
o = append(o, model.WithBackendString(cfg.Backend))
_, backendErr = ml.BackendLoader(o...)
} else {
_, backendErr = ml.GreedyLoader(o...)
}
if backendErr != nil { if backendErr != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }

View File

@ -455,7 +455,7 @@ func (ml *ModelLoader) ListAvailableBackends(assetdir string) ([]string, error)
return orderBackends(backends) return orderBackends(backends)
} }
func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) { func (ml *ModelLoader) backendLoader(opts ...Option) (client grpc.Backend, err error) {
o := NewOptions(opts...) o := NewOptions(opts...)
log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString) log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString)
@ -500,7 +500,7 @@ func (ml *ModelLoader) stopActiveBackends(modelID string, singleActiveBackend bo
} }
} }
func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
o := NewOptions(opts...) o := NewOptions(opts...)
// Return earlier if we have a model already loaded // Return earlier if we have a model already loaded
@ -513,6 +513,10 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
ml.stopActiveBackends(o.modelID, o.singleActiveBackend) ml.stopActiveBackends(o.modelID, o.singleActiveBackend)
if o.backendString != "" {
return ml.backendLoader(opts...)
}
var err error var err error
// get backends embedded in the binary // get backends embedded in the binary
@ -536,7 +540,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
WithBackendString(key), WithBackendString(key),
}...) }...)
model, modelerr := ml.BackendLoader(options...) model, modelerr := ml.backendLoader(options...)
if modelerr == nil && model != nil { if modelerr == nil && model != nil {
log.Info().Msgf("[%s] Loads OK", key) log.Info().Msgf("[%s] Loads OK", key)
return model, nil return model, nil

View File

@ -57,7 +57,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
} }
sl = model.NewModelLoader("") sl = model.NewModelLoader("")
sc, err = sl.BackendLoader(storeOpts...) sc, err = sl.Load(storeOpts...)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(sc).ToNot(BeNil()) Expect(sc).ToNot(BeNil())
}) })