feat: track internally started models by ID (#3693)

* chore(refactor): track internally started models by ID

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Just extend options, no need to copy

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Improve debugging for rerankers failures

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Simplify model loading with rerankers

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Be more consistent when generating model options

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Uncommitted code

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Make deleteProcess more idiomatic

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Adapt CLI for sound generation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fixup threads definition

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Handle corner case where c.Seed is nil

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Consistently use ModelOptions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Adapt new code to refactoring

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Dave <dave@gray101.com>
This commit is contained in:
Ettore Di Giacinto 2024-10-02 08:55:58 +02:00 committed by GitHub
parent db704199dc
commit 0965c6cd68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 169 additions and 185 deletions

View File

@ -10,20 +10,11 @@ import (
)
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
modelFile := backendConfig.Model
grpcOpts := GRPCModelOpts(backendConfig)
var inferenceModel interface{}
var err error
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(*backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
})
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)

View File

@ -8,19 +8,8 @@ import (
)
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
threads := backendConfig.Threads
if *threads == 0 && appConfig.Threads != 0 {
threads = &appConfig.Threads
}
gRPCOpts := GRPCModelOpts(backendConfig)
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(backendConfig.Backend),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithThreads(uint32(*threads)),
model.WithContext(appConfig.Context),
model.WithModel(backendConfig.Model),
model.WithLoadGRPCLoadModelOpts(gRPCOpts),
})
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
inferenceModel, err := loader.BackendLoader(
opts...,

View File

@ -33,22 +33,11 @@ type TokenUsage struct {
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
threads := c.Threads
if *threads == 0 && o.Threads != 0 {
threads = &o.Threads
}
grpcOpts := GRPCModelOpts(c)
var inferenceModel grpc.Backend
var err error
opts := modelOpts(c, o, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
})
opts := ModelOptions(c, o, []model.Option{})
if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))

View File

@ -11,32 +11,65 @@ import (
"github.com/rs/zerolog/log"
)
func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
name := c.Name
if name == "" {
name = c.Model
}
defOpts := []model.Option{
model.WithBackendString(c.Backend),
model.WithModel(c.Model),
model.WithAssetDir(so.AssetsDestination),
model.WithContext(so.Context),
model.WithModelID(name),
}
threads := 1
if c.Threads != nil {
threads = *c.Threads
}
if so.Threads != 0 {
threads = so.Threads
}
c.Threads = &threads
grpcOpts := grpcModelOpts(c)
defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts))
if so.SingleBackend {
opts = append(opts, model.WithSingleActiveBackend())
defOpts = append(defOpts, model.WithSingleActiveBackend())
}
if so.ParallelBackendRequests {
opts = append(opts, model.EnableParallelRequests)
defOpts = append(defOpts, model.EnableParallelRequests)
}
if c.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
defOpts = append(defOpts, model.WithGRPCAttempts(c.GRPC.Attempts))
}
if c.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
defOpts = append(defOpts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
}
for k, v := range so.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
defOpts = append(defOpts, model.WithExternalBackend(k, v))
}
return opts
return append(defOpts, opts...)
}
func getSeed(c config.BackendConfig) int32 {
seed := int32(*c.Seed)
var seed int32 = config.RAND_SEED
if c.Seed != nil {
seed = int32(*c.Seed)
}
if seed == config.RAND_SEED {
seed = rand.Int31()
}
@ -44,11 +77,47 @@ func getSeed(c config.BackendConfig) int32 {
return seed
}
func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
b = c.Batch
}
f16 := false
if c.F16 != nil {
f16 = *c.F16
}
embeddings := false
if c.Embeddings != nil {
embeddings = *c.Embeddings
}
lowVRAM := false
if c.LowVRAM != nil {
lowVRAM = *c.LowVRAM
}
mmap := false
if c.MMap != nil {
mmap = *c.MMap
}
ctxSize := 1024
if c.ContextSize != nil {
ctxSize = *c.ContextSize
}
mmlock := false
if c.MMlock != nil {
mmlock = *c.MMlock
}
nGPULayers := 9999999
if c.NGPULayers != nil {
nGPULayers = *c.NGPULayers
}
return &pb.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
@ -56,14 +125,14 @@ func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
CFGScale: c.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
F16Memory: *c.F16,
F16Memory: f16,
LoraBase: c.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
ContextSize: int32(*c.ContextSize),
ContextSize: int32(ctxSize),
Seed: getSeed(c),
NBatch: int32(b),
NoMulMatQ: c.NoMulMatQ,
@ -85,16 +154,16 @@ func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
YarnBetaSlow: c.YarnBetaSlow,
NGQA: c.NGQA,
RMSNormEps: c.RMSNormEps,
MLock: *c.MMlock,
MLock: mmlock,
RopeFreqBase: c.RopeFreqBase,
RopeScaling: c.RopeScaling,
Type: c.ModelType,
RopeFreqScale: c.RopeFreqScale,
NUMA: c.NUMA,
Embeddings: *c.Embeddings,
LowVRAM: *c.LowVRAM,
NGPULayers: int32(*c.NGPULayers),
MMap: *c.MMap,
Embeddings: embeddings,
LowVRAM: lowVRAM,
NGPULayers: int32(nGPULayers),
MMap: mmap,
MainGPU: c.MainGPU,
Threads: int32(*c.Threads),
TensorSplit: c.TensorSplit,

View File

@ -9,21 +9,9 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func Rerank(backend, modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
bb := backend
if bb == "" {
return nil, fmt.Errorf("backend is required")
}
func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
grpcOpts := GRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
rerankModel, err := loader.BackendLoader(opts...)
if err != nil {
return nil, err

View File

@ -13,7 +13,6 @@ import (
)
func SoundGeneration(
backend string,
modelFile string,
text string,
duration *float32,
@ -25,18 +24,8 @@ func SoundGeneration(
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig,
) (string, *proto.Result, error) {
if backend == "" {
return "", nil, fmt.Errorf("backend is a required parameter")
}
grpcOpts := GRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(backend),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
soundGenModel, err := loader.BackendLoader(opts...)
if err != nil {

View File

@ -10,24 +10,13 @@ import (
)
func TokenMetrics(
backend,
modelFile string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig) (*proto.MetricsResponse, error) {
bb := backend
if bb == "" {
return nil, fmt.Errorf("backend is required")
}
grpcOpts := GRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
opts := ModelOptions(backendConfig, appConfig, []model.Option{
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
model, err := loader.BackendLoader(opts...)
if err != nil {

View File

@ -14,13 +14,11 @@ import (
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModel(backendConfig.Model),
model.WithContext(appConfig.Context),
model.WithThreads(uint32(*backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
})
if backendConfig.Backend == "" {
backendConfig.Backend = model.WhisperBackend
}
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
transcriptionModel, err := ml.BackendLoader(opts...)
if err != nil {

View File

@ -28,14 +28,9 @@ func ModelTTS(
bb = model.PiperBackend
}
grpcOpts := GRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
opts := ModelOptions(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
ttsModel, err := loader.BackendLoader(opts...)
if err != nil {

View File

@ -85,13 +85,14 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
options := config.BackendConfig{}
options.SetDefaults()
options.Backend = t.Backend
var inputFile *string
if t.InputFile != "" {
inputFile = &t.InputFile
}
filePath, _, err := backend.SoundGeneration(t.Backend, t.Model, text,
filePath, _, err := backend.SoundGeneration(t.Model, text,
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)

View File

@ -55,7 +55,7 @@ func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
}
// TODO: Support uploading files?
filePath, _, err := backend.SoundGeneration(cfg.Backend, modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
filePath, _, err := backend.SoundGeneration(modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@ -45,13 +45,13 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
@ -64,7 +64,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
Documents: req.Documents,
}
results, err := backend.Rerank(cfg.Backend, modelFile, request, ml, appConfig, *cfg)
results, err := backend.Rerank(modelFile, request, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@ -51,7 +51,7 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
}
log.Debug().Msgf("Token Metrics for model: %s", modelFile)
response, err := backend.TokenMetrics(cfg.Backend, modelFile, ml, appConfig, *cfg)
response, err := backend.TokenMetrics(modelFile, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@ -160,13 +160,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model)
grpcOpts := backend.GRPCModelOpts(*cfg)
o := []model.Option{
model.WithModel(cfg.Model),
model.WithAssetDir(options.AssetsDestination),
model.WithThreads(uint32(options.Threads)),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
}
o := backend.ModelOptions(*cfg, options, []model.Option{})
var backendErr error
if cfg.Backend != "" {

View File

@ -268,10 +268,10 @@ func selectGRPCProcess(backend, assetDir string, f16 bool) string {
// starts the grpcModelProcess for the backend, and returns a grpc client
// It also loads the model
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (*Model, error) {
return func(modelName, modelFile string) (*Model, error) {
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string, string) (*Model, error) {
return func(modelID, modelName, modelFile string) (*Model, error) {
log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o)
log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelID, modelFile, backend, *o)
var client *Model
@ -304,7 +304,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
}
// Make sure the process is executable
process, err := ml.startProcess(uri, o.model, serverAddress)
process, err := ml.startProcess(uri, modelID, serverAddress)
if err != nil {
log.Error().Err(err).Str("path", uri).Msg("failed to launch ")
return nil, err
@ -312,11 +312,11 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
log.Debug().Msgf("GRPC Service Started")
client = NewModel(modelName, serverAddress, process)
client = NewModel(modelID, serverAddress, process)
} else {
log.Debug().Msg("external backend is uri")
// address
client = NewModel(modelName, uri, nil)
client = NewModel(modelID, uri, nil)
}
} else {
grpcProcess := backendPath(o.assetDir, backend)
@ -347,14 +347,14 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
args, grpcProcess = library.LoadLDSO(o.assetDir, args, grpcProcess)
// Make sure the process is executable in any circumstance
process, err := ml.startProcess(grpcProcess, o.model, serverAddress, args...)
process, err := ml.startProcess(grpcProcess, modelID, serverAddress, args...)
if err != nil {
return nil, err
}
log.Debug().Msgf("GRPC Service Started")
client = NewModel(modelName, serverAddress, process)
client = NewModel(modelID, serverAddress, process)
}
log.Debug().Msgf("Wait for the service to start up")
@ -407,11 +407,7 @@ func (ml *ModelLoader) ListAvailableBackends(assetdir string) ([]string, error)
func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) {
o := NewOptions(opts...)
if o.model != "" {
log.Info().Msgf("Loading model '%s' with backend %s", o.model, o.backendString)
} else {
log.Info().Msgf("Loading model with backend %s", o.backendString)
}
log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString)
backend := strings.ToLower(o.backendString)
if realBackend, exists := Aliases[backend]; exists {
@ -420,10 +416,10 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
}
if o.singleActiveBackend {
log.Debug().Msgf("Stopping all backends except '%s'", o.model)
err := ml.StopGRPC(allExcept(o.model))
log.Debug().Msgf("Stopping all backends except '%s'", o.modelID)
err := ml.StopGRPC(allExcept(o.modelID))
if err != nil {
log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel")
log.Error().Err(err).Str("keptModel", o.modelID).Msg("error while shutting down all backends except for the keptModel")
}
}
@ -437,7 +433,7 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
backendToConsume = backend
}
model, err := ml.LoadModel(o.model, ml.grpcModel(backendToConsume, o))
model, err := ml.LoadModel(o.modelID, o.model, ml.grpcModel(backendToConsume, o))
if err != nil {
return nil, err
}
@ -450,18 +446,18 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
// Return earlier if we have a model already loaded
// (avoid looping through all the backends)
if m := ml.CheckIsLoaded(o.model); m != nil {
log.Debug().Msgf("Model '%s' already loaded", o.model)
if m := ml.CheckIsLoaded(o.modelID); m != nil {
log.Debug().Msgf("Model '%s' already loaded", o.modelID)
return m.GRPC(o.parallelRequests, ml.wd), nil
}
// If we can have only one backend active, kill all the others (except external backends)
if o.singleActiveBackend {
log.Debug().Msgf("Stopping all backends except '%s'", o.model)
err := ml.StopGRPC(allExcept(o.model))
log.Debug().Msgf("Stopping all backends except '%s'", o.modelID)
err := ml.StopGRPC(allExcept(o.modelID))
if err != nil {
log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel - greedyloader continuing")
log.Error().Err(err).Str("keptModel", o.modelID).Msg("error while shutting down all backends except for the keptModel - greedyloader continuing")
}
}
@ -480,23 +476,13 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
log.Debug().Msgf("Loading from the following backends (in order): %+v", autoLoadBackends)
if o.model != "" {
log.Info().Msgf("Trying to load the model '%s' with the backend '%s'", o.model, autoLoadBackends)
}
log.Info().Msgf("Trying to load the model '%s' with the backend '%s'", o.modelID, autoLoadBackends)
for _, key := range autoLoadBackends {
log.Info().Msgf("[%s] Attempting to load", key)
options := []Option{
options := append(opts, []Option{
WithBackendString(key),
WithModel(o.model),
WithLoadGRPCLoadModelOpts(o.gRPCOptions),
WithThreads(o.threads),
WithAssetDir(o.assetDir),
}
for k, v := range o.externalBackends {
options = append(options, WithExternalBackend(k, v))
}
}...)
model, modelerr := ml.BackendLoader(options...)
if modelerr == nil && model != nil {

View File

@ -114,9 +114,9 @@ func (ml *ModelLoader) ListModels() []Model {
return models
}
func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (*Model, error)) (*Model, error) {
func (ml *ModelLoader) LoadModel(modelID, modelName string, loader func(string, string, string) (*Model, error)) (*Model, error) {
// Check if we already have a loaded model
if model := ml.CheckIsLoaded(modelName); model != nil {
if model := ml.CheckIsLoaded(modelID); model != nil {
return model, nil
}
@ -126,7 +126,7 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
ml.mu.Lock()
defer ml.mu.Unlock()
model, err := loader(modelName, modelFile)
model, err := loader(modelID, modelName, modelFile)
if err != nil {
return nil, err
}
@ -135,7 +135,7 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
return nil, fmt.Errorf("loader didn't return a model")
}
ml.models[modelName] = model
ml.models[modelID] = model
return model, nil
}

View File

@ -65,22 +65,22 @@ var _ = Describe("ModelLoader", func() {
It("should load a model and keep it in memory", func() {
mockModel = model.NewModel("foo", "test.model", nil)
mockLoader := func(modelName, modelFile string) (*model.Model, error) {
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
return mockModel, nil
}
model, err := modelLoader.LoadModel("test.model", mockLoader)
model, err := modelLoader.LoadModel("foo", "test.model", mockLoader)
Expect(err).To(BeNil())
Expect(model).To(Equal(mockModel))
Expect(modelLoader.CheckIsLoaded("test.model")).To(Equal(mockModel))
Expect(modelLoader.CheckIsLoaded("foo")).To(Equal(mockModel))
})
It("should return an error if loading the model fails", func() {
mockLoader := func(modelName, modelFile string) (*model.Model, error) {
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
return nil, errors.New("failed to load model")
}
model, err := modelLoader.LoadModel("test.model", mockLoader)
model, err := modelLoader.LoadModel("foo", "test.model", mockLoader)
Expect(err).To(HaveOccurred())
Expect(model).To(BeNil())
})
@ -88,18 +88,16 @@ var _ = Describe("ModelLoader", func() {
Context("ShutdownModel", func() {
It("should shutdown a loaded model", func() {
mockModel = model.NewModel("foo", "test.model", nil)
mockLoader := func(modelName, modelFile string) (*model.Model, error) {
return mockModel, nil
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
return model.NewModel("foo", "test.model", nil), nil
}
_, err := modelLoader.LoadModel("test.model", mockLoader)
_, err := modelLoader.LoadModel("foo", "test.model", mockLoader)
Expect(err).To(BeNil())
err = modelLoader.ShutdownModel("test.model")
err = modelLoader.ShutdownModel("foo")
Expect(err).To(BeNil())
Expect(modelLoader.CheckIsLoaded("test.model")).To(BeNil())
Expect(modelLoader.CheckIsLoaded("foo")).To(BeNil())
})
})
})

View File

@ -9,7 +9,7 @@ import (
type Options struct {
backendString string
model string
threads uint32
modelID string
assetDir string
context context.Context
@ -68,12 +68,6 @@ func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option {
}
}
func WithThreads(threads uint32) Option {
return func(o *Options) {
o.threads = threads
}
}
func WithAssetDir(assetDir string) Option {
return func(o *Options) {
o.assetDir = assetDir
@ -92,6 +86,12 @@ func WithSingleActiveBackend() Option {
}
}
func WithModelID(id string) Option {
return func(o *Options) {
o.modelID = id
}
}
func NewOptions(opts ...Option) *Options {
o := &Options{
gRPCOptions: &pb.ModelOptions{},

View File

@ -16,16 +16,26 @@ import (
)
func (ml *ModelLoader) deleteProcess(s string) error {
if m, exists := ml.models[s]; exists {
defer delete(ml.models, s)
m, exists := ml.models[s]
if !exists {
// Nothing to do
return nil
}
process := m.Process()
if process != nil {
if err := process.Stop(); err != nil {
if process == nil {
// Nothing to do as there is no process
return nil
}
err := process.Stop()
if err != nil {
log.Error().Err(err).Msgf("(deleteProcess) error while deleting process %s", s)
}
}
}
delete(ml.models, s)
return nil
return err
}
func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) error {

View File

@ -260,11 +260,9 @@ var _ = Describe("E2E test", func() {
resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized))
Expect(err).To(BeNil())
Expect(resp).ToNot(BeNil())
Expect(resp.StatusCode).To(Equal(200))
body, err := io.ReadAll(resp.Body)
Expect(err).To(BeNil())
Expect(body).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200), fmt.Sprintf("body: %s, response: %+v", body, resp))
deserializedResponse := schema.JINARerankResponse{}
err = json.Unmarshal(body, &deserializedResponse)