feat: track internally started models by ID (#3693)

* chore(refactor): track internally started models by ID

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Just extend options, no need to copy

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Improve debugging for rerankers failures

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Simplify model loading with rerankers

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Be more consistent when generating model options

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Uncommitted code

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Make deleteProcess more idiomatic

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Adapt CLI for sound generation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fixup threads definition

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Handle corner case where c.Seed is nil

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Consistently use ModelOptions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Adapt new code to refactoring

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Dave <dave@gray101.com>
Ettore Di Giacinto authored 2024-10-02 08:55:58 +02:00, committed by GitHub
commit 0965c6cd68, parent db704199dc
20 changed files with 169 additions and 185 deletions

View File

@@ -10,20 +10,11 @@ import (
 )
 func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
-	modelFile := backendConfig.Model
-	grpcOpts := GRPCModelOpts(backendConfig)
 	var inferenceModel interface{}
 	var err error
-	opts := modelOpts(backendConfig, appConfig, []model.Option{
-		model.WithLoadGRPCLoadModelOpts(grpcOpts),
-		model.WithThreads(uint32(*backendConfig.Threads)),
-		model.WithAssetDir(appConfig.AssetsDestination),
-		model.WithModel(modelFile),
-		model.WithContext(appConfig.Context),
-	})
+	opts := ModelOptions(backendConfig, appConfig, []model.Option{})
 	if backendConfig.Backend == "" {
 		inferenceModel, err = loader.GreedyLoader(opts...)

View File

@@ -8,19 +8,8 @@ import (
 )
 func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
-	threads := backendConfig.Threads
-	if *threads == 0 && appConfig.Threads != 0 {
-		threads = &appConfig.Threads
-	}
-	gRPCOpts := GRPCModelOpts(backendConfig)
-	opts := modelOpts(backendConfig, appConfig, []model.Option{
-		model.WithBackendString(backendConfig.Backend),
-		model.WithAssetDir(appConfig.AssetsDestination),
-		model.WithThreads(uint32(*threads)),
-		model.WithContext(appConfig.Context),
-		model.WithModel(backendConfig.Model),
-		model.WithLoadGRPCLoadModelOpts(gRPCOpts),
-	})
+	opts := ModelOptions(backendConfig, appConfig, []model.Option{})
 	inferenceModel, err := loader.BackendLoader(
 		opts...,

View File

@@ -33,22 +33,11 @@ type TokenUsage struct {
 func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
 	modelFile := c.Model
-	threads := c.Threads
-	if *threads == 0 && o.Threads != 0 {
-		threads = &o.Threads
-	}
-	grpcOpts := GRPCModelOpts(c)
 	var inferenceModel grpc.Backend
 	var err error
-	opts := modelOpts(c, o, []model.Option{
-		model.WithLoadGRPCLoadModelOpts(grpcOpts),
-		model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup
-		model.WithAssetDir(o.AssetsDestination),
-		model.WithModel(modelFile),
-		model.WithContext(o.Context),
-	})
+	opts := ModelOptions(c, o, []model.Option{})
 	if c.Backend != "" {
 		opts = append(opts, model.WithBackendString(c.Backend))

View File

@@ -11,32 +11,65 @@ import (
 	"github.com/rs/zerolog/log"
 )
-func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
+func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
+	name := c.Name
+	if name == "" {
+		name = c.Model
+	}
+	defOpts := []model.Option{
+		model.WithBackendString(c.Backend),
+		model.WithModel(c.Model),
+		model.WithAssetDir(so.AssetsDestination),
+		model.WithContext(so.Context),
+		model.WithModelID(name),
+	}
+	threads := 1
+	if c.Threads != nil {
+		threads = *c.Threads
+	}
+	if so.Threads != 0 {
+		threads = so.Threads
+	}
+	c.Threads = &threads
+	grpcOpts := grpcModelOpts(c)
+	defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts))
 	if so.SingleBackend {
-		opts = append(opts, model.WithSingleActiveBackend())
+		defOpts = append(defOpts, model.WithSingleActiveBackend())
 	}
 	if so.ParallelBackendRequests {
-		opts = append(opts, model.EnableParallelRequests)
+		defOpts = append(defOpts, model.EnableParallelRequests)
 	}
 	if c.GRPC.Attempts != 0 {
-		opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
+		defOpts = append(defOpts, model.WithGRPCAttempts(c.GRPC.Attempts))
 	}
 	if c.GRPC.AttemptsSleepTime != 0 {
-		opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
+		defOpts = append(defOpts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
 	}
 	for k, v := range so.ExternalGRPCBackends {
-		opts = append(opts, model.WithExternalBackend(k, v))
+		defOpts = append(defOpts, model.WithExternalBackend(k, v))
 	}
-	return opts
+	return append(defOpts, opts...)
 }
 func getSeed(c config.BackendConfig) int32 {
-	seed := int32(*c.Seed)
+	var seed int32 = config.RAND_SEED
+	if c.Seed != nil {
+		seed = int32(*c.Seed)
+	}
 	if seed == config.RAND_SEED {
 		seed = rand.Int31()
 	}
@@ -44,11 +77,47 @@ func getSeed(c config.BackendConfig) int32 {
 	return seed
 }
-func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
+func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
 	b := 512
 	if c.Batch != 0 {
 		b = c.Batch
 	}
+	f16 := false
+	if c.F16 != nil {
+		f16 = *c.F16
+	}
+	embeddings := false
+	if c.Embeddings != nil {
+		embeddings = *c.Embeddings
+	}
+	lowVRAM := false
+	if c.LowVRAM != nil {
+		lowVRAM = *c.LowVRAM
+	}
+	mmap := false
+	if c.MMap != nil {
+		mmap = *c.MMap
+	}
+	ctxSize := 1024
+	if c.ContextSize != nil {
+		ctxSize = *c.ContextSize
+	}
+	mmlock := false
+	if c.MMlock != nil {
+		mmlock = *c.MMlock
+	}
+	nGPULayers := 9999999
+	if c.NGPULayers != nil {
+		nGPULayers = *c.NGPULayers
+	}
 	return &pb.ModelOptions{
 		CUDA: c.CUDA || c.Diffusers.CUDA,
 		SchedulerType: c.Diffusers.SchedulerType,
@@ -56,14 +125,14 @@ func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
 		CFGScale: c.Diffusers.CFGScale,
 		LoraAdapter: c.LoraAdapter,
 		LoraScale: c.LoraScale,
-		F16Memory: *c.F16,
+		F16Memory: f16,
 		LoraBase: c.LoraBase,
 		IMG2IMG: c.Diffusers.IMG2IMG,
 		CLIPModel: c.Diffusers.ClipModel,
 		CLIPSubfolder: c.Diffusers.ClipSubFolder,
 		CLIPSkip: int32(c.Diffusers.ClipSkip),
 		ControlNet: c.Diffusers.ControlNet,
-		ContextSize: int32(*c.ContextSize),
+		ContextSize: int32(ctxSize),
 		Seed: getSeed(c),
 		NBatch: int32(b),
 		NoMulMatQ: c.NoMulMatQ,
@@ -85,16 +154,16 @@ func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
 		YarnBetaSlow: c.YarnBetaSlow,
 		NGQA: c.NGQA,
 		RMSNormEps: c.RMSNormEps,
-		MLock: *c.MMlock,
+		MLock: mmlock,
 		RopeFreqBase: c.RopeFreqBase,
 		RopeScaling: c.RopeScaling,
 		Type: c.ModelType,
 		RopeFreqScale: c.RopeFreqScale,
 		NUMA: c.NUMA,
-		Embeddings: *c.Embeddings,
-		LowVRAM: *c.LowVRAM,
-		NGPULayers: int32(*c.NGPULayers),
-		MMap: *c.MMap,
+		Embeddings: embeddings,
+		LowVRAM: lowVRAM,
+		NGPULayers: int32(nGPULayers),
+		MMap: mmap,
 		MainGPU: c.MainGPU,
 		Threads: int32(*c.Threads),
 		TensorSplit: c.TensorSplit,
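The net effect of this file is that callers no longer assemble threads, asset dirs, and gRPC options by hand; they pass only the options that differ from the defaults. Below is a minimal sketch of a hypothetical caller (loadWhisper is not part of this commit), using the ModelOptions and BackendLoader signatures shown in these diffs:

package backend

import (
	"github.com/mudler/LocalAI/core/config"
	model "github.com/mudler/LocalAI/pkg/model"
)

// loadWhisper is a hypothetical helper that only illustrates the intended
// call pattern: defaults (model ID, asset dir, context, threads, gRPC options)
// come from ModelOptions, and the caller appends what is specific to this site.
func loadWhisper(loader *model.ModelLoader, cfg config.BackendConfig, appCfg *config.ApplicationConfig) error {
	opts := ModelOptions(cfg, appCfg, []model.Option{
		model.WithBackendString(model.WhisperBackend),
	})
	_, err := loader.BackendLoader(opts...)
	return err
}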

View File

@@ -9,21 +9,9 @@ import (
 	model "github.com/mudler/LocalAI/pkg/model"
 )
-func Rerank(backend, modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
-	bb := backend
-	if bb == "" {
-		return nil, fmt.Errorf("backend is required")
-	}
-	grpcOpts := GRPCModelOpts(backendConfig)
-	opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
-		model.WithBackendString(bb),
-		model.WithModel(modelFile),
-		model.WithContext(appConfig.Context),
-		model.WithAssetDir(appConfig.AssetsDestination),
-		model.WithLoadGRPCLoadModelOpts(grpcOpts),
-	})
+func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
+	opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
 	rerankModel, err := loader.BackendLoader(opts...)
 	if err != nil {
 		return nil, err

View File

@@ -13,7 +13,6 @@ import (
 )
 func SoundGeneration(
-	backend string,
 	modelFile string,
 	text string,
 	duration *float32,
@@ -25,18 +24,8 @@ func SoundGeneration(
 	appConfig *config.ApplicationConfig,
 	backendConfig config.BackendConfig,
 ) (string, *proto.Result, error) {
-	if backend == "" {
-		return "", nil, fmt.Errorf("backend is a required parameter")
-	}
-	grpcOpts := GRPCModelOpts(backendConfig)
-	opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
-		model.WithBackendString(backend),
-		model.WithModel(modelFile),
-		model.WithContext(appConfig.Context),
-		model.WithAssetDir(appConfig.AssetsDestination),
-		model.WithLoadGRPCLoadModelOpts(grpcOpts),
-	})
+	opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
 	soundGenModel, err := loader.BackendLoader(opts...)
 	if err != nil {

View File

@@ -10,24 +10,13 @@ import (
 )
 func TokenMetrics(
-	backend,
 	modelFile string,
 	loader *model.ModelLoader,
 	appConfig *config.ApplicationConfig,
 	backendConfig config.BackendConfig) (*proto.MetricsResponse, error) {
-	bb := backend
-	if bb == "" {
-		return nil, fmt.Errorf("backend is required")
-	}
-	grpcOpts := GRPCModelOpts(backendConfig)
-	opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
-		model.WithBackendString(bb),
+	opts := ModelOptions(backendConfig, appConfig, []model.Option{
 		model.WithModel(modelFile),
-		model.WithContext(appConfig.Context),
-		model.WithAssetDir(appConfig.AssetsDestination),
-		model.WithLoadGRPCLoadModelOpts(grpcOpts),
 	})
 	model, err := loader.BackendLoader(opts...)
 	if err != nil {

View File

@@ -14,13 +14,11 @@ import (
 func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
-	opts := modelOpts(backendConfig, appConfig, []model.Option{
-		model.WithBackendString(model.WhisperBackend),
-		model.WithModel(backendConfig.Model),
-		model.WithContext(appConfig.Context),
-		model.WithThreads(uint32(*backendConfig.Threads)),
-		model.WithAssetDir(appConfig.AssetsDestination),
-	})
+	if backendConfig.Backend == "" {
+		backendConfig.Backend = model.WhisperBackend
+	}
+	opts := ModelOptions(backendConfig, appConfig, []model.Option{})
 	transcriptionModel, err := ml.BackendLoader(opts...)
 	if err != nil {

View File

@@ -28,14 +28,9 @@ func ModelTTS(
 		bb = model.PiperBackend
 	}
-	grpcOpts := GRPCModelOpts(backendConfig)
-	opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
+	opts := ModelOptions(config.BackendConfig{}, appConfig, []model.Option{
 		model.WithBackendString(bb),
 		model.WithModel(modelFile),
-		model.WithContext(appConfig.Context),
-		model.WithAssetDir(appConfig.AssetsDestination),
-		model.WithLoadGRPCLoadModelOpts(grpcOpts),
 	})
 	ttsModel, err := loader.BackendLoader(opts...)
 	if err != nil {

View File

@@ -85,13 +85,14 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
 	options := config.BackendConfig{}
 	options.SetDefaults()
+	options.Backend = t.Backend
 	var inputFile *string
 	if t.InputFile != "" {
 		inputFile = &t.InputFile
 	}
-	filePath, _, err := backend.SoundGeneration(t.Backend, t.Model, text,
+	filePath, _, err := backend.SoundGeneration(t.Model, text,
 		parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
 		inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)

View File

@@ -55,7 +55,7 @@ func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
 	}
 	// TODO: Support uploading files?
-	filePath, _, err := backend.SoundGeneration(cfg.Backend, modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
+	filePath, _, err := backend.SoundGeneration(modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
 	if err != nil {
 		return err
 	}

View File

@@ -45,13 +45,13 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
 		config.LoadOptionContextSize(appConfig.ContextSize),
 		config.LoadOptionF16(appConfig.F16),
 	)
 	if err != nil {
 		modelFile = input.Model
 		log.Warn().Msgf("Model not found in context: %s", input.Model)
 	} else {
 		modelFile = cfg.Model
 	}
 	log.Debug().Msgf("Request for model: %s", modelFile)
 	if input.Backend != "" {
@@ -64,7 +64,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
 		Documents: req.Documents,
 	}
-	results, err := backend.Rerank(cfg.Backend, modelFile, request, ml, appConfig, *cfg)
+	results, err := backend.Rerank(modelFile, request, ml, appConfig, *cfg)
 	if err != nil {
 		return err
 	}

View File

@@ -51,7 +51,7 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
 	}
 	log.Debug().Msgf("Token Metrics for model: %s", modelFile)
-	response, err := backend.TokenMetrics(cfg.Backend, modelFile, ml, appConfig, *cfg)
+	response, err := backend.TokenMetrics(modelFile, ml, appConfig, *cfg)
 	if err != nil {
 		return err
 	}

View File

@@ -160,13 +160,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 		log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model)
-		grpcOpts := backend.GRPCModelOpts(*cfg)
-		o := []model.Option{
-			model.WithModel(cfg.Model),
-			model.WithAssetDir(options.AssetsDestination),
-			model.WithThreads(uint32(options.Threads)),
-			model.WithLoadGRPCLoadModelOpts(grpcOpts),
-		}
+		o := backend.ModelOptions(*cfg, options, []model.Option{})
 		var backendErr error
 		if cfg.Backend != "" {

View File

@@ -268,10 +268,10 @@ func selectGRPCProcess(backend, assetDir string, f16 bool) string {
 // starts the grpcModelProcess for the backend, and returns a grpc client
 // It also loads the model
-func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (*Model, error) {
-	return func(modelName, modelFile string) (*Model, error) {
-		log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o)
+func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string, string) (*Model, error) {
+	return func(modelID, modelName, modelFile string) (*Model, error) {
+		log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelID, modelFile, backend, *o)
 		var client *Model
@@ -304,7 +304,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 				return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
 			}
 			// Make sure the process is executable
-			process, err := ml.startProcess(uri, o.model, serverAddress)
+			process, err := ml.startProcess(uri, modelID, serverAddress)
 			if err != nil {
 				log.Error().Err(err).Str("path", uri).Msg("failed to launch ")
 				return nil, err
@@ -312,11 +312,11 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 			log.Debug().Msgf("GRPC Service Started")
-			client = NewModel(modelName, serverAddress, process)
+			client = NewModel(modelID, serverAddress, process)
 		} else {
 			log.Debug().Msg("external backend is uri")
 			// address
-			client = NewModel(modelName, uri, nil)
+			client = NewModel(modelID, uri, nil)
 		}
 	} else {
 		grpcProcess := backendPath(o.assetDir, backend)
@@ -347,14 +347,14 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 		args, grpcProcess = library.LoadLDSO(o.assetDir, args, grpcProcess)
 		// Make sure the process is executable in any circumstance
-		process, err := ml.startProcess(grpcProcess, o.model, serverAddress, args...)
+		process, err := ml.startProcess(grpcProcess, modelID, serverAddress, args...)
 		if err != nil {
 			return nil, err
 		}
 		log.Debug().Msgf("GRPC Service Started")
-		client = NewModel(modelName, serverAddress, process)
+		client = NewModel(modelID, serverAddress, process)
 	}
 	log.Debug().Msgf("Wait for the service to start up")
@@ -407,11 +407,7 @@ func (ml *ModelLoader) ListAvailableBackends(assetdir string) ([]string, error)
 func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) {
 	o := NewOptions(opts...)
-	if o.model != "" {
-		log.Info().Msgf("Loading model '%s' with backend %s", o.model, o.backendString)
-	} else {
-		log.Info().Msgf("Loading model with backend %s", o.backendString)
-	}
+	log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString)
 	backend := strings.ToLower(o.backendString)
 	if realBackend, exists := Aliases[backend]; exists {
@@ -420,10 +416,10 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
 	}
 	if o.singleActiveBackend {
-		log.Debug().Msgf("Stopping all backends except '%s'", o.model)
-		err := ml.StopGRPC(allExcept(o.model))
+		log.Debug().Msgf("Stopping all backends except '%s'", o.modelID)
+		err := ml.StopGRPC(allExcept(o.modelID))
 		if err != nil {
-			log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel")
+			log.Error().Err(err).Str("keptModel", o.modelID).Msg("error while shutting down all backends except for the keptModel")
 		}
 	}
@@ -437,7 +433,7 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
 		backendToConsume = backend
 	}
-	model, err := ml.LoadModel(o.model, ml.grpcModel(backendToConsume, o))
+	model, err := ml.LoadModel(o.modelID, o.model, ml.grpcModel(backendToConsume, o))
 	if err != nil {
 		return nil, err
 	}
@@ -450,18 +446,18 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
 	// Return earlier if we have a model already loaded
 	// (avoid looping through all the backends)
-	if m := ml.CheckIsLoaded(o.model); m != nil {
-		log.Debug().Msgf("Model '%s' already loaded", o.model)
+	if m := ml.CheckIsLoaded(o.modelID); m != nil {
+		log.Debug().Msgf("Model '%s' already loaded", o.modelID)
 		return m.GRPC(o.parallelRequests, ml.wd), nil
 	}
 	// If we can have only one backend active, kill all the others (except external backends)
 	if o.singleActiveBackend {
-		log.Debug().Msgf("Stopping all backends except '%s'", o.model)
-		err := ml.StopGRPC(allExcept(o.model))
+		log.Debug().Msgf("Stopping all backends except '%s'", o.modelID)
+		err := ml.StopGRPC(allExcept(o.modelID))
 		if err != nil {
-			log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel - greedyloader continuing")
+			log.Error().Err(err).Str("keptModel", o.modelID).Msg("error while shutting down all backends except for the keptModel - greedyloader continuing")
 		}
 	}
@@ -480,23 +476,13 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
 	log.Debug().Msgf("Loading from the following backends (in order): %+v", autoLoadBackends)
-	if o.model != "" {
-		log.Info().Msgf("Trying to load the model '%s' with the backend '%s'", o.model, autoLoadBackends)
-	}
+	log.Info().Msgf("Trying to load the model '%s' with the backend '%s'", o.modelID, autoLoadBackends)
 	for _, key := range autoLoadBackends {
 		log.Info().Msgf("[%s] Attempting to load", key)
-		options := []Option{
+		options := append(opts, []Option{
 			WithBackendString(key),
-			WithModel(o.model),
-			WithLoadGRPCLoadModelOpts(o.gRPCOptions),
-			WithThreads(o.threads),
-			WithAssetDir(o.assetDir),
-		}
-		for k, v := range o.externalBackends {
-			options = append(options, WithExternalBackend(k, v))
-		}
+		}...)
 		model, modelerr := ml.BackendLoader(options...)
 		if modelerr == nil && model != nil {
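Since loaded processes are now keyed by model ID rather than by file name, two configurations pointing at the same file are tracked and stopped independently. Below is a minimal sketch of that idea with hypothetical IDs, file, and backend name; it only uses the option helpers and BackendLoader signature shown in these diffs:

package backend

import (
	model "github.com/mudler/LocalAI/pkg/model"
)

// loadTwice is a hypothetical helper: it loads the same on-disk model under
// two distinct IDs. The loader's internal map (see LoadModel below) is keyed
// by the ID passed via WithModelID, so the two entries do not collide.
func loadTwice(ml *model.ModelLoader) error {
	for _, id := range []string{"chat-fast", "chat-precise"} {
		opts := []model.Option{
			model.WithBackendString("llama-cpp"), // assumed backend name
			model.WithModel("model.gguf"),        // same file for both IDs
			model.WithModelID(id),
		}
		if _, err := ml.BackendLoader(opts...); err != nil {
			return err
		}
	}
	return nil
}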

View File

@@ -114,9 +114,9 @@ func (ml *ModelLoader) ListModels() []Model {
 	return models
 }
-func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (*Model, error)) (*Model, error) {
+func (ml *ModelLoader) LoadModel(modelID, modelName string, loader func(string, string, string) (*Model, error)) (*Model, error) {
 	// Check if we already have a loaded model
-	if model := ml.CheckIsLoaded(modelName); model != nil {
+	if model := ml.CheckIsLoaded(modelID); model != nil {
 		return model, nil
 	}
@@ -126,7 +126,7 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
 	ml.mu.Lock()
 	defer ml.mu.Unlock()
-	model, err := loader(modelName, modelFile)
+	model, err := loader(modelID, modelName, modelFile)
 	if err != nil {
 		return nil, err
 	}
@@ -135,7 +135,7 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
 		return nil, fmt.Errorf("loader didn't return a model")
 	}
-	ml.models[modelName] = model
+	ml.models[modelID] = model
 	return model, nil
 }

View File

@@ -65,22 +65,22 @@ var _ = Describe("ModelLoader", func() {
 		It("should load a model and keep it in memory", func() {
 			mockModel = model.NewModel("foo", "test.model", nil)
-			mockLoader := func(modelName, modelFile string) (*model.Model, error) {
+			mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
 				return mockModel, nil
 			}
-			model, err := modelLoader.LoadModel("test.model", mockLoader)
+			model, err := modelLoader.LoadModel("foo", "test.model", mockLoader)
 			Expect(err).To(BeNil())
 			Expect(model).To(Equal(mockModel))
-			Expect(modelLoader.CheckIsLoaded("test.model")).To(Equal(mockModel))
+			Expect(modelLoader.CheckIsLoaded("foo")).To(Equal(mockModel))
 		})
 		It("should return an error if loading the model fails", func() {
-			mockLoader := func(modelName, modelFile string) (*model.Model, error) {
+			mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
 				return nil, errors.New("failed to load model")
 			}
-			model, err := modelLoader.LoadModel("test.model", mockLoader)
+			model, err := modelLoader.LoadModel("foo", "test.model", mockLoader)
 			Expect(err).To(HaveOccurred())
 			Expect(model).To(BeNil())
 		})
@@ -88,18 +88,16 @@ var _ = Describe("ModelLoader", func() {
 	Context("ShutdownModel", func() {
 		It("should shutdown a loaded model", func() {
-			mockModel = model.NewModel("foo", "test.model", nil)
-			mockLoader := func(modelName, modelFile string) (*model.Model, error) {
-				return mockModel, nil
+			mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
+				return model.NewModel("foo", "test.model", nil), nil
 			}
-			_, err := modelLoader.LoadModel("test.model", mockLoader)
+			_, err := modelLoader.LoadModel("foo", "test.model", mockLoader)
 			Expect(err).To(BeNil())
-			err = modelLoader.ShutdownModel("test.model")
+			err = modelLoader.ShutdownModel("foo")
 			Expect(err).To(BeNil())
-			Expect(modelLoader.CheckIsLoaded("test.model")).To(BeNil())
+			Expect(modelLoader.CheckIsLoaded("foo")).To(BeNil())
 		})
 	})
 })

View File

@@ -9,7 +9,7 @@ import (
 type Options struct {
 	backendString string
 	model         string
-	threads       uint32
+	modelID       string
 	assetDir      string
 	context       context.Context
@@ -68,12 +68,6 @@ func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option {
 	}
 }
-func WithThreads(threads uint32) Option {
-	return func(o *Options) {
-		o.threads = threads
-	}
-}
 func WithAssetDir(assetDir string) Option {
 	return func(o *Options) {
 		o.assetDir = assetDir
@@ -92,6 +86,12 @@ func WithSingleActiveBackend() Option {
 	}
 }
+func WithModelID(id string) Option {
+	return func(o *Options) {
+		o.modelID = id
+	}
+}
 func NewOptions(opts ...Option) *Options {
 	o := &Options{
 		gRPCOptions: &pb.ModelOptions{},

View File

@@ -16,16 +16,26 @@ import (
 )
 func (ml *ModelLoader) deleteProcess(s string) error {
-	if m, exists := ml.models[s]; exists {
-		process := m.Process()
-		if process != nil {
-			if err := process.Stop(); err != nil {
-				log.Error().Err(err).Msgf("(deleteProcess) error while deleting process %s", s)
-			}
-		}
-	}
-	delete(ml.models, s)
-	return nil
+	defer delete(ml.models, s)
+	m, exists := ml.models[s]
+	if !exists {
+		// Nothing to do
+		return nil
+	}
+	process := m.Process()
+	if process == nil {
+		// Nothing to do as there is no process
+		return nil
+	}
+	err := process.Stop()
+	if err != nil {
+		log.Error().Err(err).Msgf("(deleteProcess) error while deleting process %s", s)
+	}
+	return err
 }
 func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) error {

View File

@@ -260,11 +260,9 @@ var _ = Describe("E2E test", func() {
 			resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized))
 			Expect(err).To(BeNil())
 			Expect(resp).ToNot(BeNil())
-			Expect(resp.StatusCode).To(Equal(200))
 			body, err := io.ReadAll(resp.Body)
-			Expect(err).To(BeNil())
-			Expect(body).ToNot(BeNil())
+			Expect(err).ToNot(HaveOccurred())
+			Expect(resp.StatusCode).To(Equal(200), fmt.Sprintf("body: %s, response: %+v", body, resp))
 			deserializedResponse := schema.JINARerankResponse{}
 			err = json.Unmarshal(body, &deserializedResponse)