mirror of
https://github.com/mudler/LocalAI.git
synced 2025-01-22 12:28:14 +00:00
106 lines
2.9 KiB
Go
106 lines
2.9 KiB
Go
|
package model_test
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
|
||
|
"github.com/mudler/LocalAI/pkg/model"
|
||
|
. "github.com/onsi/ginkgo/v2"
|
||
|
. "github.com/onsi/gomega"
|
||
|
)
|
||
|
|
||
|
var _ = Describe("ModelLoader", func() {
|
||
|
var (
|
||
|
modelLoader *model.ModelLoader
|
||
|
modelPath string
|
||
|
mockModel *model.Model
|
||
|
)
|
||
|
|
||
|
BeforeEach(func() {
|
||
|
// Setup the model loader with a test directory
|
||
|
modelPath = "/tmp/test_model_path"
|
||
|
os.Mkdir(modelPath, 0755)
|
||
|
modelLoader = model.NewModelLoader(modelPath)
|
||
|
})
|
||
|
|
||
|
AfterEach(func() {
|
||
|
// Cleanup test directory
|
||
|
os.RemoveAll(modelPath)
|
||
|
})
|
||
|
|
||
|
Context("NewModelLoader", func() {
|
||
|
It("should create a new ModelLoader with an empty model map", func() {
|
||
|
Expect(modelLoader).ToNot(BeNil())
|
||
|
Expect(modelLoader.ModelPath).To(Equal(modelPath))
|
||
|
Expect(modelLoader.ListModels()).To(BeEmpty())
|
||
|
})
|
||
|
})
|
||
|
|
||
|
Context("ExistsInModelPath", func() {
|
||
|
It("should return true if a file exists in the model path", func() {
|
||
|
testFile := filepath.Join(modelPath, "test.model")
|
||
|
os.Create(testFile)
|
||
|
Expect(modelLoader.ExistsInModelPath("test.model")).To(BeTrue())
|
||
|
})
|
||
|
|
||
|
It("should return false if a file does not exist in the model path", func() {
|
||
|
Expect(modelLoader.ExistsInModelPath("nonexistent.model")).To(BeFalse())
|
||
|
})
|
||
|
})
|
||
|
|
||
|
Context("ListFilesInModelPath", func() {
|
||
|
It("should list all valid model files in the model path", func() {
|
||
|
os.Create(filepath.Join(modelPath, "test.model"))
|
||
|
os.Create(filepath.Join(modelPath, "README.md"))
|
||
|
|
||
|
files, err := modelLoader.ListFilesInModelPath()
|
||
|
Expect(err).To(BeNil())
|
||
|
Expect(files).To(ContainElement("test.model"))
|
||
|
Expect(files).ToNot(ContainElement("README.md"))
|
||
|
})
|
||
|
})
|
||
|
|
||
|
Context("LoadModel", func() {
|
||
|
It("should load a model and keep it in memory", func() {
|
||
|
mockModel = model.NewModel("test.model")
|
||
|
|
||
|
mockLoader := func(modelName, modelFile string) (*model.Model, error) {
|
||
|
return mockModel, nil
|
||
|
}
|
||
|
|
||
|
model, err := modelLoader.LoadModel("test.model", mockLoader)
|
||
|
Expect(err).To(BeNil())
|
||
|
Expect(model).To(Equal(mockModel))
|
||
|
Expect(modelLoader.CheckIsLoaded("test.model")).To(Equal(mockModel))
|
||
|
})
|
||
|
|
||
|
It("should return an error if loading the model fails", func() {
|
||
|
mockLoader := func(modelName, modelFile string) (*model.Model, error) {
|
||
|
return nil, errors.New("failed to load model")
|
||
|
}
|
||
|
|
||
|
model, err := modelLoader.LoadModel("test.model", mockLoader)
|
||
|
Expect(err).To(HaveOccurred())
|
||
|
Expect(model).To(BeNil())
|
||
|
})
|
||
|
})
|
||
|
|
||
|
Context("ShutdownModel", func() {
|
||
|
It("should shutdown a loaded model", func() {
|
||
|
mockModel = model.NewModel("test.model")
|
||
|
|
||
|
mockLoader := func(modelName, modelFile string) (*model.Model, error) {
|
||
|
return mockModel, nil
|
||
|
}
|
||
|
|
||
|
_, err := modelLoader.LoadModel("test.model", mockLoader)
|
||
|
Expect(err).To(BeNil())
|
||
|
|
||
|
err = modelLoader.ShutdownModel("test.model")
|
||
|
Expect(err).To(BeNil())
|
||
|
Expect(modelLoader.CheckIsLoaded("test.model")).To(BeNil())
|
||
|
})
|
||
|
})
|
||
|
})
|