diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index 9f0f8be9..264d947b 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -10,20 +10,11 @@ import ( ) func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { - modelFile := backendConfig.Model - - grpcOpts := GRPCModelOpts(backendConfig) var inferenceModel interface{} var err error - opts := modelOpts(backendConfig, appConfig, []model.Option{ - model.WithLoadGRPCLoadModelOpts(grpcOpts), - model.WithThreads(uint32(*backendConfig.Threads)), - model.WithAssetDir(appConfig.AssetsDestination), - model.WithModel(modelFile), - model.WithContext(appConfig.Context), - }) + opts := ModelOptions(backendConfig, appConfig, []model.Option{}) if backendConfig.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) diff --git a/core/backend/image.go b/core/backend/image.go index 5c2a950c..72c0007c 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -8,19 +8,8 @@ import ( ) func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { - threads := backendConfig.Threads - if *threads == 0 && appConfig.Threads != 0 { - threads = &appConfig.Threads - } - gRPCOpts := GRPCModelOpts(backendConfig) - opts := modelOpts(backendConfig, appConfig, []model.Option{ - model.WithBackendString(backendConfig.Backend), - model.WithAssetDir(appConfig.AssetsDestination), - model.WithThreads(uint32(*threads)), - model.WithContext(appConfig.Context), - model.WithModel(backendConfig.Model), - model.WithLoadGRPCLoadModelOpts(gRPCOpts), - }) + + opts := ModelOptions(backendConfig, appConfig, []model.Option{}) inferenceModel, err := loader.BackendLoader( opts..., diff --git a/core/backend/llm.go b/core/backend/llm.go index cac9beba..d946d3f8 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -33,22 +33,11 @@ type TokenUsage struct { func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { modelFile := c.Model - threads := c.Threads - if *threads == 0 && o.Threads != 0 { - threads = &o.Threads - } - grpcOpts := GRPCModelOpts(c) var inferenceModel grpc.Backend var err error - opts := modelOpts(c, o, []model.Option{ - model.WithLoadGRPCLoadModelOpts(grpcOpts), - model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup - model.WithAssetDir(o.AssetsDestination), - model.WithModel(modelFile), - model.WithContext(o.Context), - }) + opts := ModelOptions(c, o, []model.Option{}) if c.Backend != "" { opts = append(opts, model.WithBackendString(c.Backend)) diff --git a/core/backend/options.go b/core/backend/options.go index d431aab6..90d563e0 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -11,32 +11,65 @@ import ( "github.com/rs/zerolog/log" ) -func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option { +func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option { + name := c.Name + if name == "" { + name = c.Model + } + + defOpts := []model.Option{ + model.WithBackendString(c.Backend), + model.WithModel(c.Model), + model.WithAssetDir(so.AssetsDestination), + model.WithContext(so.Context), + model.WithModelID(name), + } + + threads := 1 + + if c.Threads != nil { + threads = *c.Threads + } + + if so.Threads != 0 { + threads = so.Threads + } + + c.Threads = &threads + + grpcOpts := grpcModelOpts(c) + defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts)) + if so.SingleBackend { - opts = append(opts, model.WithSingleActiveBackend()) + defOpts = append(defOpts, model.WithSingleActiveBackend()) } if so.ParallelBackendRequests { - opts = append(opts, model.EnableParallelRequests) + defOpts = append(defOpts, model.EnableParallelRequests) } if c.GRPC.Attempts != 0 { - opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) + defOpts = append(defOpts, model.WithGRPCAttempts(c.GRPC.Attempts)) } if c.GRPC.AttemptsSleepTime != 0 { - opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) + defOpts = append(defOpts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) } for k, v := range so.ExternalGRPCBackends { - opts = append(opts, model.WithExternalBackend(k, v)) + defOpts = append(defOpts, model.WithExternalBackend(k, v)) } - return opts + return append(defOpts, opts...) } func getSeed(c config.BackendConfig) int32 { - seed := int32(*c.Seed) + var seed int32 = config.RAND_SEED + + if c.Seed != nil { + seed = int32(*c.Seed) + } + if seed == config.RAND_SEED { seed = rand.Int31() } @@ -44,11 +77,47 @@ func getSeed(c config.BackendConfig) int32 { return seed } -func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions { +func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions { b := 512 if c.Batch != 0 { b = c.Batch } + + f16 := false + if c.F16 != nil { + f16 = *c.F16 + } + + embeddings := false + if c.Embeddings != nil { + embeddings = *c.Embeddings + } + + lowVRAM := false + if c.LowVRAM != nil { + lowVRAM = *c.LowVRAM + } + + mmap := false + if c.MMap != nil { + mmap = *c.MMap + } + + ctxSize := 1024 + if c.ContextSize != nil { + ctxSize = *c.ContextSize + } + + mmlock := false + if c.MMlock != nil { + mmlock = *c.MMlock + } + + nGPULayers := 9999999 + if c.NGPULayers != nil { + nGPULayers = *c.NGPULayers + } + return &pb.ModelOptions{ CUDA: c.CUDA || c.Diffusers.CUDA, SchedulerType: c.Diffusers.SchedulerType, @@ -56,14 +125,14 @@ func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions { CFGScale: c.Diffusers.CFGScale, LoraAdapter: c.LoraAdapter, LoraScale: c.LoraScale, - F16Memory: *c.F16, + F16Memory: f16, LoraBase: c.LoraBase, IMG2IMG: c.Diffusers.IMG2IMG, CLIPModel: c.Diffusers.ClipModel, CLIPSubfolder: c.Diffusers.ClipSubFolder, CLIPSkip: int32(c.Diffusers.ClipSkip), ControlNet: c.Diffusers.ControlNet, - ContextSize: int32(*c.ContextSize), + ContextSize: int32(ctxSize), Seed: getSeed(c), NBatch: int32(b), NoMulMatQ: c.NoMulMatQ, @@ -85,16 +154,16 @@ func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions { YarnBetaSlow: c.YarnBetaSlow, NGQA: c.NGQA, RMSNormEps: c.RMSNormEps, - MLock: *c.MMlock, + MLock: mmlock, RopeFreqBase: c.RopeFreqBase, RopeScaling: c.RopeScaling, Type: c.ModelType, RopeFreqScale: c.RopeFreqScale, NUMA: c.NUMA, - Embeddings: *c.Embeddings, - LowVRAM: *c.LowVRAM, - NGPULayers: int32(*c.NGPULayers), - MMap: *c.MMap, + Embeddings: embeddings, + LowVRAM: lowVRAM, + NGPULayers: int32(nGPULayers), + MMap: mmap, MainGPU: c.MainGPU, Threads: int32(*c.Threads), TensorSplit: c.TensorSplit, diff --git a/core/backend/rerank.go b/core/backend/rerank.go index a7573ade..f600e2e6 100644 --- a/core/backend/rerank.go +++ b/core/backend/rerank.go @@ -9,21 +9,9 @@ import ( model "github.com/mudler/LocalAI/pkg/model" ) -func Rerank(backend, modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) { - bb := backend - if bb == "" { - return nil, fmt.Errorf("backend is required") - } +func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) { - grpcOpts := GRPCModelOpts(backendConfig) - - opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{ - model.WithBackendString(bb), - model.WithModel(modelFile), - model.WithContext(appConfig.Context), - model.WithAssetDir(appConfig.AssetsDestination), - model.WithLoadGRPCLoadModelOpts(grpcOpts), - }) + opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)}) rerankModel, err := loader.BackendLoader(opts...) if err != nil { return nil, err diff --git a/core/backend/soundgeneration.go b/core/backend/soundgeneration.go index b6a1c827..b1b458b4 100644 --- a/core/backend/soundgeneration.go +++ b/core/backend/soundgeneration.go @@ -13,7 +13,6 @@ import ( ) func SoundGeneration( - backend string, modelFile string, text string, duration *float32, @@ -25,18 +24,8 @@ func SoundGeneration( appConfig *config.ApplicationConfig, backendConfig config.BackendConfig, ) (string, *proto.Result, error) { - if backend == "" { - return "", nil, fmt.Errorf("backend is a required parameter") - } - grpcOpts := GRPCModelOpts(backendConfig) - opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{ - model.WithBackendString(backend), - model.WithModel(modelFile), - model.WithContext(appConfig.Context), - model.WithAssetDir(appConfig.AssetsDestination), - model.WithLoadGRPCLoadModelOpts(grpcOpts), - }) + opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)}) soundGenModel, err := loader.BackendLoader(opts...) if err != nil { diff --git a/core/backend/token_metrics.go b/core/backend/token_metrics.go index cd715108..acd25663 100644 --- a/core/backend/token_metrics.go +++ b/core/backend/token_metrics.go @@ -10,24 +10,13 @@ import ( ) func TokenMetrics( - backend, modelFile string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.MetricsResponse, error) { - bb := backend - if bb == "" { - return nil, fmt.Errorf("backend is required") - } - grpcOpts := GRPCModelOpts(backendConfig) - - opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{ - model.WithBackendString(bb), + opts := ModelOptions(backendConfig, appConfig, []model.Option{ model.WithModel(modelFile), - model.WithContext(appConfig.Context), - model.WithAssetDir(appConfig.AssetsDestination), - model.WithLoadGRPCLoadModelOpts(grpcOpts), }) model, err := loader.BackendLoader(opts...) if err != nil { diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 6ebc7c10..c6ad9b59 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -14,13 +14,11 @@ import ( func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { - opts := modelOpts(backendConfig, appConfig, []model.Option{ - model.WithBackendString(model.WhisperBackend), - model.WithModel(backendConfig.Model), - model.WithContext(appConfig.Context), - model.WithThreads(uint32(*backendConfig.Threads)), - model.WithAssetDir(appConfig.AssetsDestination), - }) + if backendConfig.Backend == "" { + backendConfig.Backend = model.WhisperBackend + } + + opts := ModelOptions(backendConfig, appConfig, []model.Option{}) transcriptionModel, err := ml.BackendLoader(opts...) if err != nil { diff --git a/core/backend/tts.go b/core/backend/tts.go index 2401748c..bac2e900 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -28,14 +28,9 @@ func ModelTTS( bb = model.PiperBackend } - grpcOpts := GRPCModelOpts(backendConfig) - - opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{ + opts := ModelOptions(config.BackendConfig{}, appConfig, []model.Option{ model.WithBackendString(bb), model.WithModel(modelFile), - model.WithContext(appConfig.Context), - model.WithAssetDir(appConfig.AssetsDestination), - model.WithLoadGRPCLoadModelOpts(grpcOpts), }) ttsModel, err := loader.BackendLoader(opts...) if err != nil { diff --git a/core/cli/soundgeneration.go b/core/cli/soundgeneration.go index 5711b199..82bc0346 100644 --- a/core/cli/soundgeneration.go +++ b/core/cli/soundgeneration.go @@ -85,13 +85,14 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error { options := config.BackendConfig{} options.SetDefaults() + options.Backend = t.Backend var inputFile *string if t.InputFile != "" { inputFile = &t.InputFile } - filePath, _, err := backend.SoundGeneration(t.Backend, t.Model, text, + filePath, _, err := backend.SoundGeneration(t.Model, text, parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample, inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options) diff --git a/core/http/endpoints/elevenlabs/soundgeneration.go b/core/http/endpoints/elevenlabs/soundgeneration.go index 619544d8..345df35b 100644 --- a/core/http/endpoints/elevenlabs/soundgeneration.go +++ b/core/http/endpoints/elevenlabs/soundgeneration.go @@ -55,7 +55,7 @@ func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad } // TODO: Support uploading files? - filePath, _, err := backend.SoundGeneration(cfg.Backend, modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg) + filePath, _, err := backend.SoundGeneration(modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/http/endpoints/jina/rerank.go b/core/http/endpoints/jina/rerank.go index 04fdf031..58c3972d 100644 --- a/core/http/endpoints/jina/rerank.go +++ b/core/http/endpoints/jina/rerank.go @@ -45,13 +45,13 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a config.LoadOptionContextSize(appConfig.ContextSize), config.LoadOptionF16(appConfig.F16), ) - if err != nil { modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) } else { modelFile = cfg.Model } + log.Debug().Msgf("Request for model: %s", modelFile) if input.Backend != "" { @@ -64,7 +64,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a Documents: req.Documents, } - results, err := backend.Rerank(cfg.Backend, modelFile, request, ml, appConfig, *cfg) + results, err := backend.Rerank(modelFile, request, ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/http/endpoints/localai/get_token_metrics.go b/core/http/endpoints/localai/get_token_metrics.go index 95e79bac..e0e6943f 100644 --- a/core/http/endpoints/localai/get_token_metrics.go +++ b/core/http/endpoints/localai/get_token_metrics.go @@ -51,7 +51,7 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, } log.Debug().Msgf("Token Metrics for model: %s", modelFile) - response, err := backend.TokenMetrics(cfg.Backend, modelFile, ml, appConfig, *cfg) + response, err := backend.TokenMetrics(modelFile, ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/startup/startup.go b/core/startup/startup.go index b7b9ce8f..17e54bc0 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -160,13 +160,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model) - grpcOpts := backend.GRPCModelOpts(*cfg) - o := []model.Option{ - model.WithModel(cfg.Model), - model.WithAssetDir(options.AssetsDestination), - model.WithThreads(uint32(options.Threads)), - model.WithLoadGRPCLoadModelOpts(grpcOpts), - } + o := backend.ModelOptions(*cfg, options, []model.Option{}) var backendErr error if cfg.Backend != "" { diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index d0f47373..6f56b453 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -268,10 +268,10 @@ func selectGRPCProcess(backend, assetDir string, f16 bool) string { // starts the grpcModelProcess for the backend, and returns a grpc client // It also loads the model -func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (*Model, error) { - return func(modelName, modelFile string) (*Model, error) { +func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string, string) (*Model, error) { + return func(modelID, modelName, modelFile string) (*Model, error) { - log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o) + log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelID, modelFile, backend, *o) var client *Model @@ -304,7 +304,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string return nil, fmt.Errorf("failed allocating free ports: %s", err.Error()) } // Make sure the process is executable - process, err := ml.startProcess(uri, o.model, serverAddress) + process, err := ml.startProcess(uri, modelID, serverAddress) if err != nil { log.Error().Err(err).Str("path", uri).Msg("failed to launch ") return nil, err @@ -312,11 +312,11 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string log.Debug().Msgf("GRPC Service Started") - client = NewModel(modelName, serverAddress, process) + client = NewModel(modelID, serverAddress, process) } else { log.Debug().Msg("external backend is uri") // address - client = NewModel(modelName, uri, nil) + client = NewModel(modelID, uri, nil) } } else { grpcProcess := backendPath(o.assetDir, backend) @@ -347,14 +347,14 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string args, grpcProcess = library.LoadLDSO(o.assetDir, args, grpcProcess) // Make sure the process is executable in any circumstance - process, err := ml.startProcess(grpcProcess, o.model, serverAddress, args...) + process, err := ml.startProcess(grpcProcess, modelID, serverAddress, args...) if err != nil { return nil, err } log.Debug().Msgf("GRPC Service Started") - client = NewModel(modelName, serverAddress, process) + client = NewModel(modelID, serverAddress, process) } log.Debug().Msgf("Wait for the service to start up") @@ -407,11 +407,7 @@ func (ml *ModelLoader) ListAvailableBackends(assetdir string) ([]string, error) func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) { o := NewOptions(opts...) - if o.model != "" { - log.Info().Msgf("Loading model '%s' with backend %s", o.model, o.backendString) - } else { - log.Info().Msgf("Loading model with backend %s", o.backendString) - } + log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString) backend := strings.ToLower(o.backendString) if realBackend, exists := Aliases[backend]; exists { @@ -420,10 +416,10 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e } if o.singleActiveBackend { - log.Debug().Msgf("Stopping all backends except '%s'", o.model) - err := ml.StopGRPC(allExcept(o.model)) + log.Debug().Msgf("Stopping all backends except '%s'", o.modelID) + err := ml.StopGRPC(allExcept(o.modelID)) if err != nil { - log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel") + log.Error().Err(err).Str("keptModel", o.modelID).Msg("error while shutting down all backends except for the keptModel") } } @@ -437,7 +433,7 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e backendToConsume = backend } - model, err := ml.LoadModel(o.model, ml.grpcModel(backendToConsume, o)) + model, err := ml.LoadModel(o.modelID, o.model, ml.grpcModel(backendToConsume, o)) if err != nil { return nil, err } @@ -450,18 +446,18 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { // Return earlier if we have a model already loaded // (avoid looping through all the backends) - if m := ml.CheckIsLoaded(o.model); m != nil { - log.Debug().Msgf("Model '%s' already loaded", o.model) + if m := ml.CheckIsLoaded(o.modelID); m != nil { + log.Debug().Msgf("Model '%s' already loaded", o.modelID) return m.GRPC(o.parallelRequests, ml.wd), nil } // If we can have only one backend active, kill all the others (except external backends) if o.singleActiveBackend { - log.Debug().Msgf("Stopping all backends except '%s'", o.model) - err := ml.StopGRPC(allExcept(o.model)) + log.Debug().Msgf("Stopping all backends except '%s'", o.modelID) + err := ml.StopGRPC(allExcept(o.modelID)) if err != nil { - log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel - greedyloader continuing") + log.Error().Err(err).Str("keptModel", o.modelID).Msg("error while shutting down all backends except for the keptModel - greedyloader continuing") } } @@ -480,23 +476,13 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { log.Debug().Msgf("Loading from the following backends (in order): %+v", autoLoadBackends) - if o.model != "" { - log.Info().Msgf("Trying to load the model '%s' with the backend '%s'", o.model, autoLoadBackends) - } + log.Info().Msgf("Trying to load the model '%s' with the backend '%s'", o.modelID, autoLoadBackends) for _, key := range autoLoadBackends { log.Info().Msgf("[%s] Attempting to load", key) - options := []Option{ + options := append(opts, []Option{ WithBackendString(key), - WithModel(o.model), - WithLoadGRPCLoadModelOpts(o.gRPCOptions), - WithThreads(o.threads), - WithAssetDir(o.assetDir), - } - - for k, v := range o.externalBackends { - options = append(options, WithExternalBackend(k, v)) - } + }...) model, modelerr := ml.BackendLoader(options...) if modelerr == nil && model != nil { diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 68ac1a31..97e62fe4 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -114,9 +114,9 @@ func (ml *ModelLoader) ListModels() []Model { return models } -func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (*Model, error)) (*Model, error) { +func (ml *ModelLoader) LoadModel(modelID, modelName string, loader func(string, string, string) (*Model, error)) (*Model, error) { // Check if we already have a loaded model - if model := ml.CheckIsLoaded(modelName); model != nil { + if model := ml.CheckIsLoaded(modelID); model != nil { return model, nil } @@ -126,7 +126,7 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) ( ml.mu.Lock() defer ml.mu.Unlock() - model, err := loader(modelName, modelFile) + model, err := loader(modelID, modelName, modelFile) if err != nil { return nil, err } @@ -135,7 +135,7 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) ( return nil, fmt.Errorf("loader didn't return a model") } - ml.models[modelName] = model + ml.models[modelID] = model return model, nil } diff --git a/pkg/model/loader_test.go b/pkg/model/loader_test.go index d0ad4e0c..83e47ec6 100644 --- a/pkg/model/loader_test.go +++ b/pkg/model/loader_test.go @@ -65,22 +65,22 @@ var _ = Describe("ModelLoader", func() { It("should load a model and keep it in memory", func() { mockModel = model.NewModel("foo", "test.model", nil) - mockLoader := func(modelName, modelFile string) (*model.Model, error) { + mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) { return mockModel, nil } - model, err := modelLoader.LoadModel("test.model", mockLoader) + model, err := modelLoader.LoadModel("foo", "test.model", mockLoader) Expect(err).To(BeNil()) Expect(model).To(Equal(mockModel)) - Expect(modelLoader.CheckIsLoaded("test.model")).To(Equal(mockModel)) + Expect(modelLoader.CheckIsLoaded("foo")).To(Equal(mockModel)) }) It("should return an error if loading the model fails", func() { - mockLoader := func(modelName, modelFile string) (*model.Model, error) { + mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) { return nil, errors.New("failed to load model") } - model, err := modelLoader.LoadModel("test.model", mockLoader) + model, err := modelLoader.LoadModel("foo", "test.model", mockLoader) Expect(err).To(HaveOccurred()) Expect(model).To(BeNil()) }) @@ -88,18 +88,16 @@ var _ = Describe("ModelLoader", func() { Context("ShutdownModel", func() { It("should shutdown a loaded model", func() { - mockModel = model.NewModel("foo", "test.model", nil) - - mockLoader := func(modelName, modelFile string) (*model.Model, error) { - return mockModel, nil + mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) { + return model.NewModel("foo", "test.model", nil), nil } - _, err := modelLoader.LoadModel("test.model", mockLoader) + _, err := modelLoader.LoadModel("foo", "test.model", mockLoader) Expect(err).To(BeNil()) - err = modelLoader.ShutdownModel("test.model") + err = modelLoader.ShutdownModel("foo") Expect(err).To(BeNil()) - Expect(modelLoader.CheckIsLoaded("test.model")).To(BeNil()) + Expect(modelLoader.CheckIsLoaded("foo")).To(BeNil()) }) }) }) diff --git a/pkg/model/options.go b/pkg/model/options.go index a3f4c855..e7fd06de 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -9,7 +9,7 @@ import ( type Options struct { backendString string model string - threads uint32 + modelID string assetDir string context context.Context @@ -68,12 +68,6 @@ func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option { } } -func WithThreads(threads uint32) Option { - return func(o *Options) { - o.threads = threads - } -} - func WithAssetDir(assetDir string) Option { return func(o *Options) { o.assetDir = assetDir @@ -92,6 +86,12 @@ func WithSingleActiveBackend() Option { } } +func WithModelID(id string) Option { + return func(o *Options) { + o.modelID = id + } +} + func NewOptions(opts ...Option) *Options { o := &Options{ gRPCOptions: &pb.ModelOptions{}, diff --git a/pkg/model/process.go b/pkg/model/process.go index 48631d79..3e16ddaf 100644 --- a/pkg/model/process.go +++ b/pkg/model/process.go @@ -16,16 +16,26 @@ import ( ) func (ml *ModelLoader) deleteProcess(s string) error { - if m, exists := ml.models[s]; exists { - process := m.Process() - if process != nil { - if err := process.Stop(); err != nil { - log.Error().Err(err).Msgf("(deleteProcess) error while deleting process %s", s) - } - } + defer delete(ml.models, s) + + m, exists := ml.models[s] + if !exists { + // Nothing to do + return nil } - delete(ml.models, s) - return nil + + process := m.Process() + if process == nil { + // Nothing to do as there is no process + return nil + } + + err := process.Stop() + if err != nil { + log.Error().Err(err).Msgf("(deleteProcess) error while deleting process %s", s) + } + + return err } func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) error { diff --git a/tests/e2e-aio/e2e_test.go b/tests/e2e-aio/e2e_test.go index 36d127d2..a9c55497 100644 --- a/tests/e2e-aio/e2e_test.go +++ b/tests/e2e-aio/e2e_test.go @@ -260,11 +260,9 @@ var _ = Describe("E2E test", func() { resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized)) Expect(err).To(BeNil()) Expect(resp).ToNot(BeNil()) - Expect(resp.StatusCode).To(Equal(200)) - body, err := io.ReadAll(resp.Body) - Expect(err).To(BeNil()) - Expect(body).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200), fmt.Sprintf("body: %s, response: %+v", body, resp)) deserializedResponse := schema.JINARerankResponse{} err = json.Unmarshal(body, &deserializedResponse)