diff --git a/core/backend/llm.go b/core/backend/llm.go index 87bdbe36..a6f7fe56 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -12,7 +12,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/pkg/gallery" + "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" diff --git a/core/cli/models.go b/core/cli/models.go index 1a9ac8a8..d62ad318 100644 --- a/core/cli/models.go +++ b/core/cli/models.go @@ -5,9 +5,10 @@ import ( "fmt" cliContext "github.com/mudler/LocalAI/core/cli/context" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/downloader" - "github.com/mudler/LocalAI/pkg/gallery" "github.com/mudler/LocalAI/pkg/startup" "github.com/rs/zerolog/log" "github.com/schollz/progressbar/v3" @@ -34,7 +35,7 @@ type ModelsCMD struct { } func (ml *ModelsList) Run(ctx *cliContext.Context) error { - var galleries []gallery.Gallery + var galleries []config.Gallery if err := json.Unmarshal([]byte(ml.Galleries), &galleries); err != nil { log.Error().Err(err).Msg("unable to load galleries") } @@ -54,7 +55,7 @@ func (ml *ModelsList) Run(ctx *cliContext.Context) error { } func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { - var galleries []gallery.Gallery + var galleries []config.Gallery if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil { log.Error().Err(err).Msg("unable to load galleries") } diff --git a/core/config/application_config.go b/core/config/application_config.go index 24672e6b..65c716f8 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -6,7 +6,6 @@ import ( "encoding/json" "time" - "github.com/mudler/LocalAI/pkg/gallery" "github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/rs/zerolog/log" ) @@ -36,7 +35,7 @@ type ApplicationConfig struct { ModelLibraryURL string - Galleries []gallery.Gallery + Galleries []Gallery BackendAssets embed.FS AssetsDestination string @@ -180,10 +179,10 @@ func WithBackendAssets(f embed.FS) AppOption { func WithStringGalleries(galls string) AppOption { return func(o *ApplicationConfig) { if galls == "" { - o.Galleries = []gallery.Gallery{} + o.Galleries = []Gallery{} return } - var galleries []gallery.Gallery + var galleries []Gallery if err := json.Unmarshal([]byte(galls), &galleries); err != nil { log.Error().Err(err).Msg("failed loading galleries") } @@ -191,7 +190,7 @@ func WithStringGalleries(galls string) AppOption { } } -func WithGalleries(galleries []gallery.Gallery) AppOption { +func WithGalleries(galleries []Gallery) AppOption { return func(o *ApplicationConfig) { o.Galleries = append(o.Galleries, galleries...) } diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 040b6e78..1e647ceb 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -390,10 +390,6 @@ func (c *BackendConfig) Validate() bool { } } - if c.Name == "" { - return false - } - if c.Backend != "" { // a regex that checks that is a string name with no special characters, except '-' and '_' re := regexp.MustCompile(`^[a-zA-Z0-9-_]+$`) diff --git a/core/config/backend_config_test.go b/core/config/backend_config_test.go index 48bcfa9c..da245933 100644 --- a/core/config/backend_config_test.go +++ b/core/config/backend_config_test.go @@ -16,7 +16,8 @@ var _ = Describe("Test cases for config related functions", func() { Expect(err).To(BeNil()) defer os.Remove(tmp.Name()) _, err = tmp.WriteString( - `backend: "foo-bar" + `backend: "../foo-bar" +name: "foo" parameters: model: "foo-bar"`) Expect(err).ToNot(HaveOccurred()) diff --git a/core/config/gallery.go b/core/config/gallery.go new file mode 100644 index 00000000..002100be --- /dev/null +++ b/core/config/gallery.go @@ -0,0 +1,6 @@ +package config + +type Gallery struct { + URL string `json:"url" yaml:"url"` + Name string `json:"name" yaml:"name"` +} diff --git a/pkg/gallery/gallery.go b/core/gallery/gallery.go similarity index 93% rename from pkg/gallery/gallery.go rename to core/gallery/gallery.go index 49cfd054..be167755 100644 --- a/pkg/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -8,18 +8,14 @@ import ( "strings" "github.com/imdario/mergo" + "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/downloader" "github.com/rs/zerolog/log" "gopkg.in/yaml.v2" ) -type Gallery struct { - URL string `json:"url" yaml:"url"` - Name string `json:"name" yaml:"name"` -} - // Installs a model from the gallery -func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { +func InstallModelFromGallery(galleries []config.Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { applyModel := func(model *GalleryModel) error { name = strings.ReplaceAll(name, string(os.PathSeparator), "__") @@ -117,7 +113,7 @@ func FindModel(models []*GalleryModel, name string, basePath string) *GalleryMod // List available models // Models galleries are a list of yaml files that are hosted on a remote server (for example github). // Each yaml file contains a list of models that can be downloaded and optionally overrides to define a new model setting. -func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryModel, error) { +func AvailableGalleryModels(galleries []config.Gallery, basePath string) ([]*GalleryModel, error) { var models []*GalleryModel // Get models from galleries @@ -146,7 +142,7 @@ func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) return refFile, err } -func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) { +func getGalleryModels(gallery config.Gallery, basePath string) ([]*GalleryModel, error) { var models []*GalleryModel = []*GalleryModel{} if strings.HasSuffix(gallery.URL, ".ref") { diff --git a/pkg/gallery/gallery_suite_test.go b/core/gallery/gallery_suite_test.go similarity index 100% rename from pkg/gallery/gallery_suite_test.go rename to core/gallery/gallery_suite_test.go diff --git a/pkg/gallery/models.go b/core/gallery/models.go similarity index 93% rename from pkg/gallery/models.go rename to core/gallery/models.go index 5819c617..8d020ff5 100644 --- a/pkg/gallery/models.go +++ b/core/gallery/models.go @@ -6,8 +6,10 @@ import ( "path/filepath" "github.com/imdario/mergo" + lconfig "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" "gopkg.in/yaml.v2" ) @@ -172,6 +174,15 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides return fmt.Errorf("failed to marshal updated config YAML: %v", err) } + backendConfig := lconfig.BackendConfig{} + err = yaml.Unmarshal(updatedConfigYAML, &backendConfig) + if err != nil { + return fmt.Errorf("failed to unmarshal updated config YAML: %v", err) + } + if !backendConfig.Validate() { + return fmt.Errorf("failed to validate updated config YAML") + } + err = os.WriteFile(configFilePath, updatedConfigYAML, 0600) if err != nil { return fmt.Errorf("failed to write updated config file: %v", err) diff --git a/pkg/gallery/models_test.go b/core/gallery/models_test.go similarity index 97% rename from pkg/gallery/models_test.go rename to core/gallery/models_test.go index 3f1a68b1..17a30911 100644 --- a/pkg/gallery/models_test.go +++ b/core/gallery/models_test.go @@ -5,7 +5,8 @@ import ( "os" "path/filepath" - . "github.com/mudler/LocalAI/pkg/gallery" + "github.com/mudler/LocalAI/core/config" + . "github.com/mudler/LocalAI/core/gallery" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "gopkg.in/yaml.v3" @@ -54,7 +55,7 @@ var _ = Describe("Model test", func() { err = os.WriteFile(galleryFilePath, out, 0600) Expect(err).ToNot(HaveOccurred()) Expect(filepath.IsAbs(galleryFilePath)).To(BeTrue(), galleryFilePath) - galleries := []Gallery{ + galleries := []config.Gallery{ { Name: "test", URL: "file://" + galleryFilePath, diff --git a/pkg/gallery/op.go b/core/gallery/op.go similarity index 87% rename from pkg/gallery/op.go rename to core/gallery/op.go index ad1bcbaf..d3795a00 100644 --- a/pkg/gallery/op.go +++ b/core/gallery/op.go @@ -1,5 +1,7 @@ package gallery +import "github.com/mudler/LocalAI/core/config" + type GalleryOp struct { Id string GalleryModelName string @@ -7,7 +9,7 @@ type GalleryOp struct { Delete bool Req GalleryModel - Galleries []Gallery + Galleries []config.Gallery } type GalleryOpStatus struct { diff --git a/pkg/gallery/request.go b/core/gallery/request.go similarity index 94% rename from pkg/gallery/request.go rename to core/gallery/request.go index 61a25912..eec764c1 100644 --- a/pkg/gallery/request.go +++ b/core/gallery/request.go @@ -3,6 +3,8 @@ package gallery import ( "fmt" "strings" + + "github.com/mudler/LocalAI/core/config" ) // GalleryModel is the struct used to represent a model in the gallery returned by the endpoint. @@ -23,7 +25,7 @@ type GalleryModel struct { // AdditionalFiles are used to add additional files to the model AdditionalFiles []File `json:"files,omitempty" yaml:"files,omitempty"` // Gallery is a reference to the gallery which contains the model - Gallery Gallery `json:"gallery,omitempty" yaml:"gallery,omitempty"` + Gallery config.Gallery `json:"gallery,omitempty" yaml:"gallery,omitempty"` // Installed is used to indicate if the model is installed or not Installed bool `json:"installed,omitempty" yaml:"installed,omitempty"` } diff --git a/pkg/gallery/request_test.go b/core/gallery/request_test.go similarity index 90% rename from pkg/gallery/request_test.go rename to core/gallery/request_test.go index 6600f494..23281cc6 100644 --- a/pkg/gallery/request_test.go +++ b/core/gallery/request_test.go @@ -1,7 +1,7 @@ package gallery_test import ( - . "github.com/mudler/LocalAI/pkg/gallery" + . "github.com/mudler/LocalAI/core/gallery" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) diff --git a/core/http/app_test.go b/core/http/app_test.go index 6b5e531b..3fb16581 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -19,8 +19,8 @@ import ( "github.com/mudler/LocalAI/core/startup" "github.com/gofiber/fiber/v2" + "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/downloader" - "github.com/mudler/LocalAI/pkg/gallery" "github.com/mudler/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -247,7 +247,7 @@ var _ = Describe("API test", func() { err = os.WriteFile(filepath.Join(modelDir, "gallery_simple.yaml"), out, 0600) Expect(err).ToNot(HaveOccurred()) - galleries := []gallery.Gallery{ + galleries := []config.Gallery{ { Name: "test", URL: "file://" + filepath.Join(modelDir, "gallery_simple.yaml"), @@ -603,7 +603,7 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) - galleries := []gallery.Gallery{ + galleries := []config.Gallery{ { Name: "model-gallery", URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/index.yaml", diff --git a/core/http/elements/gallery.go b/core/http/elements/gallery.go index 1a92ee12..373de038 100644 --- a/core/http/elements/gallery.go +++ b/core/http/elements/gallery.go @@ -6,8 +6,8 @@ import ( "github.com/chasefleming/elem-go" "github.com/chasefleming/elem-go/attrs" + "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/services" - "github.com/mudler/LocalAI/pkg/gallery" "github.com/mudler/LocalAI/pkg/xsync" ) diff --git a/core/http/endpoints/localai/gallery.go b/core/http/endpoints/localai/gallery.go index 3fd03d43..9c49d641 100644 --- a/core/http/endpoints/localai/gallery.go +++ b/core/http/endpoints/localai/gallery.go @@ -7,13 +7,14 @@ import ( "github.com/gofiber/fiber/v2" "github.com/google/uuid" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/services" - "github.com/mudler/LocalAI/pkg/gallery" "github.com/rs/zerolog/log" ) type ModelGalleryEndpointService struct { - galleries []gallery.Gallery + galleries []config.Gallery modelPath string galleryApplier *services.GalleryService } @@ -24,7 +25,7 @@ type GalleryModel struct { gallery.GalleryModel } -func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService { +func CreateModelGalleryEndpointService(galleries []config.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService { return ModelGalleryEndpointService{ galleries: galleries, modelPath: modelPath, @@ -129,12 +130,12 @@ func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fib func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - input := new(gallery.Gallery) + input := new(config.Gallery) // Get input data from the request body if err := c.BodyParser(input); err != nil { return err } - if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { + if slices.ContainsFunc(mgs.galleries, func(gallery config.Gallery) bool { return gallery.Name == input.Name }) { return fmt.Errorf("%s already exists", input.Name) @@ -151,17 +152,17 @@ func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber. func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - input := new(gallery.Gallery) + input := new(config.Gallery) // Get input data from the request body if err := c.BodyParser(input); err != nil { return err } - if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { + if !slices.ContainsFunc(mgs.galleries, func(gallery config.Gallery) bool { return gallery.Name == input.Name }) { return fmt.Errorf("%s is not currently registered", input.Name) } - mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool { + mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery config.Gallery) bool { return gallery.Name == input.Name }) return c.Send(nil) diff --git a/core/http/endpoints/localai/welcome.go b/core/http/endpoints/localai/welcome.go index 7e2a7938..fa00e900 100644 --- a/core/http/endpoints/localai/welcome.go +++ b/core/http/endpoints/localai/welcome.go @@ -3,8 +3,8 @@ package localai import ( "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/internal" - "github.com/mudler/LocalAI/pkg/gallery" "github.com/mudler/LocalAI/pkg/model" ) diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go index 7f7cc3a2..009de4a0 100644 --- a/core/http/endpoints/openai/request.go +++ b/core/http/endpoints/openai/request.go @@ -257,5 +257,9 @@ func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *c // Set the parameters for the language model prediction updateRequestConfig(cfg, input) + if !cfg.Validate() { + return nil, nil, fmt.Errorf("failed to validate config") + } + return cfg, input, err } diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index 1624cad3..eddcc6fc 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -32,7 +32,7 @@ func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + return fmt.Errorf("failed reading parameters from request: %w", err) } // retrieve the file data from the request file, err := c.FormFile("file") diff --git a/core/http/routes/ui.go b/core/http/routes/ui.go index c13ee745..3c805422 100644 --- a/core/http/routes/ui.go +++ b/core/http/routes/ui.go @@ -7,11 +7,11 @@ import ( "strings" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/http/elements" "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" - "github.com/mudler/LocalAI/pkg/gallery" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/xsync" "github.com/rs/zerolog/log" diff --git a/core/services/gallery.go b/core/services/gallery.go index 6382e595..2c0ed435 100644 --- a/core/services/gallery.go +++ b/core/services/gallery.go @@ -9,7 +9,7 @@ import ( "sync" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/pkg/gallery" + "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/startup" "github.com/mudler/LocalAI/pkg/utils" "gopkg.in/yaml.v2" @@ -96,6 +96,7 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader // delete a model if op.Delete { modelConfig := &config.BackendConfig{} + // Galleryname is the name of the model in this case dat, err := os.ReadFile(filepath.Join(g.appConfig.ModelPath, op.GalleryModelName+".yaml")) if err != nil { @@ -174,7 +175,7 @@ type galleryModel struct { ID string `json:"id"` } -func processRequests(modelPath string, galleries []gallery.Gallery, requests []galleryModel) error { +func processRequests(modelPath string, galleries []config.Gallery, requests []galleryModel) error { var err error for _, r := range requests { utils.ResetDownloadTimers() @@ -189,7 +190,7 @@ func processRequests(modelPath string, galleries []gallery.Gallery, requests []g return err } -func ApplyGalleryFromFile(modelPath, s string, galleries []gallery.Gallery) error { +func ApplyGalleryFromFile(modelPath, s string, galleries []config.Gallery) error { dat, err := os.ReadFile(s) if err != nil { return err @@ -203,7 +204,7 @@ func ApplyGalleryFromFile(modelPath, s string, galleries []gallery.Gallery) erro return processRequests(modelPath, galleries, requests) } -func ApplyGalleryFromString(modelPath, s string, galleries []gallery.Gallery) error { +func ApplyGalleryFromString(modelPath, s string, galleries []config.Gallery) error { var requests []galleryModel err := json.Unmarshal([]byte(s), &requests) if err != nil { diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index cfc263da..c1676708 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -13,6 +13,7 @@ import ( "github.com/klauspost/cpuid/v2" grpc "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/library" + "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/phayes/freeport" "github.com/rs/zerolog/log" @@ -309,6 +310,9 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string } } else { grpcProcess := backendPath(o.assetDir, backend) + if err := utils.VerifyPath(grpcProcess, o.assetDir); err != nil { + return "", fmt.Errorf("grpc process not found in assetdir: %s", err.Error()) + } if autoDetect { // autoDetect GRPC process to start based on system capabilities diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go index ba869166..d678f283 100644 --- a/pkg/startup/model_preload.go +++ b/pkg/startup/model_preload.go @@ -7,9 +7,10 @@ import ( "path/filepath" "strings" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/embedded" "github.com/mudler/LocalAI/pkg/downloader" - "github.com/mudler/LocalAI/pkg/gallery" "github.com/mudler/LocalAI/pkg/utils" "github.com/rs/zerolog/log" ) @@ -17,7 +18,7 @@ import ( // InstallModels will preload models from the given list of URLs and galleries // It will download the model if it is not already present in the model path // It will also try to resolve if the model is an embedded model YAML configuration -func InstallModels(galleries []gallery.Gallery, modelLibraryURL string, modelPath string, downloadStatus func(string, string, string, float64), models ...string) error { +func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath string, downloadStatus func(string, string, string, float64), models ...string) error { // create an error that groups all errors var err error @@ -126,7 +127,7 @@ func InstallModels(galleries []gallery.Gallery, modelLibraryURL string, modelPat return err } -func installModel(galleries []gallery.Gallery, modelName, modelPath string, downloadStatus func(string, string, string, float64)) (error, bool) { +func installModel(galleries []config.Gallery, modelName, modelPath string, downloadStatus func(string, string, string, float64)) (error, bool) { models, err := gallery.AvailableGalleryModels(galleries, modelPath) if err != nil { return err, false diff --git a/pkg/startup/model_preload_test.go b/pkg/startup/model_preload_test.go index f99d2a3c..e3d7d979 100644 --- a/pkg/startup/model_preload_test.go +++ b/pkg/startup/model_preload_test.go @@ -5,7 +5,7 @@ import ( "os" "path/filepath" - "github.com/mudler/LocalAI/pkg/gallery" + "github.com/mudler/LocalAI/core/config" . "github.com/mudler/LocalAI/pkg/startup" "github.com/mudler/LocalAI/pkg/utils" @@ -22,7 +22,7 @@ var _ = Describe("Preload test", func() { libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml" fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719") - InstallModels([]gallery.Gallery{}, libraryURL, tmpdir, nil, "phi-2") + InstallModels([]config.Gallery{}, libraryURL, tmpdir, nil, "phi-2") resultFile := filepath.Join(tmpdir, fileName) @@ -38,7 +38,7 @@ var _ = Describe("Preload test", func() { url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml" fileName := fmt.Sprintf("%s.yaml", utils.MD5(url)) - InstallModels([]gallery.Gallery{}, "", tmpdir, nil, url) + InstallModels([]config.Gallery{}, "", tmpdir, nil, url) resultFile := filepath.Join(tmpdir, fileName) @@ -52,7 +52,7 @@ var _ = Describe("Preload test", func() { Expect(err).ToNot(HaveOccurred()) url := "phi-2" - InstallModels([]gallery.Gallery{}, "", tmpdir, nil, url) + InstallModels([]config.Gallery{}, "", tmpdir, nil, url) entry, err := os.ReadDir(tmpdir) Expect(err).ToNot(HaveOccurred()) @@ -70,7 +70,7 @@ var _ = Describe("Preload test", func() { url := "mistral-openorca" fileName := fmt.Sprintf("%s.yaml", utils.MD5(url)) - InstallModels([]gallery.Gallery{}, "", tmpdir, nil, url) + InstallModels([]config.Gallery{}, "", tmpdir, nil, url) resultFile := filepath.Join(tmpdir, fileName) diff --git a/pkg/utils/untar.go b/pkg/utils/untar.go index 782b2d17..ed6c6cb2 100644 --- a/pkg/utils/untar.go +++ b/pkg/utils/untar.go @@ -2,6 +2,7 @@ package utils import ( "fmt" + "os" "github.com/mholt/archiver/v3" ) @@ -52,5 +53,17 @@ func ExtractArchive(archive, dst string) error { case *archiver.TarZstd: v.Tar = mytar } + + err = archiver.Walk(archive, func(f archiver.File) error { + if f.FileInfo.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("archive contains a symlink") + } + return nil + }) + + if err != nil { + return err + } + return un.Unarchive(archive, dst) }