From 607fd066f0b47cca0d14bd65a64a6385f4f98be3 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 30 Aug 2024 15:20:39 +0200 Subject: [PATCH] chore(model-loader): increase test coverage of model loader (#3433) chore(model-loader): increase coverage of model loader Signed-off-by: Ettore Di Giacinto --- pkg/model/loader.go | 33 +++++++++++- pkg/model/loader_test.go | 105 +++++++++++++++++++++++++++++++++++++++ pkg/model/model.go | 5 +- 3 files changed, 138 insertions(+), 5 deletions(-) create mode 100644 pkg/model/loader_test.go diff --git a/pkg/model/loader.go b/pkg/model/loader.go index c1ed01dc..90fda35f 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" "sync" + "time" "github.com/mudler/LocalAI/pkg/templates" @@ -102,6 +103,18 @@ FILE: return models, nil } +func (ml *ModelLoader) ListModels() []*Model { + ml.mu.Lock() + defer ml.mu.Unlock() + + models := []*Model{} + for _, model := range ml.models { + models = append(models, model) + } + + return models +} + func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (*Model, error)) (*Model, error) { ml.mu.Lock() defer ml.mu.Unlock() @@ -120,7 +133,12 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) ( return nil, err } + if model == nil { + return nil, fmt.Errorf("loader didn't return a model") + } + ml.models[modelName] = model + return model, nil } @@ -146,11 +164,22 @@ func (ml *ModelLoader) CheckIsLoaded(s string) *Model { } log.Debug().Msgf("Model already loaded in memory: %s", s) - alive, err := m.GRPC(false, ml.wd).HealthCheck(context.Background()) + client := m.GRPC(false, ml.wd) + + log.Debug().Msgf("Checking model availability (%s)", s) + cTimeout, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + alive, err := client.HealthCheck(cTimeout) if !alive { log.Warn().Msgf("GRPC Model not responding: %s", err.Error()) log.Warn().Msgf("Deleting the process in order to recreate it") - if !ml.grpcProcesses[s].IsAlive() { + process, exists := ml.grpcProcesses[s] + if !exists { + log.Error().Msgf("Process not found for '%s' and the model is not responding anymore !", s) + return m + } + if !process.IsAlive() { log.Debug().Msgf("GRPC Process is not responding: %s", s) // stop and delete the process, this forces to re-load the model and re-create again the service err := ml.deleteProcess(s) diff --git a/pkg/model/loader_test.go b/pkg/model/loader_test.go new file mode 100644 index 00000000..4621844e --- /dev/null +++ b/pkg/model/loader_test.go @@ -0,0 +1,105 @@ +package model_test + +import ( + "errors" + "os" + "path/filepath" + + "github.com/mudler/LocalAI/pkg/model" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ModelLoader", func() { + var ( + modelLoader *model.ModelLoader + modelPath string + mockModel *model.Model + ) + + BeforeEach(func() { + // Setup the model loader with a test directory + modelPath = "/tmp/test_model_path" + os.Mkdir(modelPath, 0755) + modelLoader = model.NewModelLoader(modelPath) + }) + + AfterEach(func() { + // Cleanup test directory + os.RemoveAll(modelPath) + }) + + Context("NewModelLoader", func() { + It("should create a new ModelLoader with an empty model map", func() { + Expect(modelLoader).ToNot(BeNil()) + Expect(modelLoader.ModelPath).To(Equal(modelPath)) + Expect(modelLoader.ListModels()).To(BeEmpty()) + }) + }) + + Context("ExistsInModelPath", func() { + It("should return true if a file exists in the model path", func() { + testFile := filepath.Join(modelPath, "test.model") + os.Create(testFile) + Expect(modelLoader.ExistsInModelPath("test.model")).To(BeTrue()) + }) + + It("should return false if a file does not exist in the model path", func() { + Expect(modelLoader.ExistsInModelPath("nonexistent.model")).To(BeFalse()) + }) + }) + + Context("ListFilesInModelPath", func() { + It("should list all valid model files in the model path", func() { + os.Create(filepath.Join(modelPath, "test.model")) + os.Create(filepath.Join(modelPath, "README.md")) + + files, err := modelLoader.ListFilesInModelPath() + Expect(err).To(BeNil()) + Expect(files).To(ContainElement("test.model")) + Expect(files).ToNot(ContainElement("README.md")) + }) + }) + + Context("LoadModel", func() { + It("should load a model and keep it in memory", func() { + mockModel = model.NewModel("test.model") + + mockLoader := func(modelName, modelFile string) (*model.Model, error) { + return mockModel, nil + } + + model, err := modelLoader.LoadModel("test.model", mockLoader) + Expect(err).To(BeNil()) + Expect(model).To(Equal(mockModel)) + Expect(modelLoader.CheckIsLoaded("test.model")).To(Equal(mockModel)) + }) + + It("should return an error if loading the model fails", func() { + mockLoader := func(modelName, modelFile string) (*model.Model, error) { + return nil, errors.New("failed to load model") + } + + model, err := modelLoader.LoadModel("test.model", mockLoader) + Expect(err).To(HaveOccurred()) + Expect(model).To(BeNil()) + }) + }) + + Context("ShutdownModel", func() { + It("should shutdown a loaded model", func() { + mockModel = model.NewModel("test.model") + + mockLoader := func(modelName, modelFile string) (*model.Model, error) { + return mockModel, nil + } + + _, err := modelLoader.LoadModel("test.model", mockLoader) + Expect(err).To(BeNil()) + + err = modelLoader.ShutdownModel("test.model") + Expect(err).To(BeNil()) + Expect(modelLoader.CheckIsLoaded("test.model")).To(BeNil()) + }) + }) +}) diff --git a/pkg/model/model.go b/pkg/model/model.go index 26ddb8cc..1927dc0c 100644 --- a/pkg/model/model.go +++ b/pkg/model/model.go @@ -23,7 +23,6 @@ func (m *Model) GRPC(parallel bool, wd *WatchDog) grpc.Backend { enableWD = true } - client := grpc.NewClient(m.address, parallel, wd, enableWD) - m.client = client - return client + m.client = grpc.NewClient(m.address, parallel, wd, enableWD) + return m.client }