diff --git a/core/http/app_test.go b/core/http/app_test.go index 5776b99a..6e9de246 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -73,7 +73,8 @@ func getModelStatus(url string) (response map[string]interface{}) { } func getModels(url string) (response []gallery.GalleryModel) { - downloader.GetURI(url, func(url string, i []byte) error { + // TODO: No tests currently seem to exercise file:// urls. Fix? + downloader.GetURI(url, "", func(url string, i []byte) error { // Unmarshal YAML data into a struct return json.Unmarshal(i, &response) }) diff --git a/core/services/gallery.go b/core/services/gallery.go index ed6f6165..e20e733a 100644 --- a/core/services/gallery.go +++ b/core/services/gallery.go @@ -32,7 +32,7 @@ func NewGalleryService(modelPath string) *GalleryService { func prepareModel(modelPath string, req gallery.GalleryModel, cl *config.BackendConfigLoader, downloadStatus func(string, string, string, float64)) error { - config, err := gallery.GetGalleryConfigFromURL(req.URL) + config, err := gallery.GetGalleryConfigFromURL(req.URL, modelPath) if err != nil { return err } diff --git a/embedded/embedded.go b/embedded/embedded.go index 438a1352..1fc59b4d 100644 --- a/embedded/embedded.go +++ b/embedded/embedded.go @@ -36,10 +36,10 @@ func init() { } } -func GetRemoteLibraryShorteners(url string) (map[string]string, error) { +func GetRemoteLibraryShorteners(url string, basePath string) (map[string]string, error) { remoteLibrary := map[string]string{} - err := downloader.GetURI(url, func(_ string, i []byte) error { + err := downloader.GetURI(url, basePath, func(_ string, i []byte) error { return yaml.Unmarshal(i, &remoteLibrary) }) if err != nil { diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 797a264b..0848a238 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -23,7 +23,7 @@ const ( GithubURI2 = "github://" ) -func GetURI(url string, f func(url string, i []byte) error) error { +func GetURI(url string, basePath string, f func(url string, i []byte) error) error { url = ConvertURL(url) if strings.HasPrefix(url, "file://") { @@ -33,6 +33,11 @@ func GetURI(url string, f func(url string, i []byte) error) error { if err != nil { return err } + // Check if the local file is rooted in basePath + err = utils.VerifyPath(resolvedFile, basePath) + if err != nil { + return err + } // Read the response body body, err := os.ReadFile(resolvedFile) if err != nil { diff --git a/pkg/downloader/uri_test.go b/pkg/downloader/uri_test.go index cd17b7ca..3ab04e56 100644 --- a/pkg/downloader/uri_test.go +++ b/pkg/downloader/uri_test.go @@ -10,7 +10,7 @@ var _ = Describe("Gallery API tests", func() { Context("URI", func() { It("parses github with a branch", func() { Expect( - GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml", func(url string, i []byte) error { + GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml", "", func(url string, i []byte) error { Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) return nil }), @@ -18,7 +18,7 @@ var _ = Describe("Gallery API tests", func() { }) It("parses github without a branch", func() { Expect( - GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml@main", func(url string, i []byte) error { + GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml@main", "", func(url string, i []byte) error { Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) return nil }), @@ -26,7 +26,7 @@ var _ = Describe("Gallery API tests", func() { }) It("parses github with urls", func() { Expect( - GetURI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", func(url string, i []byte) error { + GetURI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", "", func(url string, i []byte) error { Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) return nil }), diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go index 6202529a..0e9daa79 100644 --- a/pkg/gallery/gallery.go +++ b/pkg/gallery/gallery.go @@ -27,7 +27,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string, if len(model.URL) > 0 { var err error - config, err = GetGalleryConfigFromURL(model.URL) + config, err = GetGalleryConfigFromURL(model.URL, basePath) if err != nil { return err } @@ -142,9 +142,9 @@ func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryMod return models, nil } -func findGalleryURLFromReferenceURL(url string) (string, error) { +func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) { var refFile string - err := downloader.GetURI(url, func(url string, d []byte) error { + err := downloader.GetURI(url, basePath, func(url string, d []byte) error { refFile = string(d) if len(refFile) == 0 { return fmt.Errorf("invalid reference file at url %s: %s", url, d) @@ -161,13 +161,13 @@ func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) if strings.HasSuffix(gallery.URL, ".ref") { var err error - gallery.URL, err = findGalleryURLFromReferenceURL(gallery.URL) + gallery.URL, err = findGalleryURLFromReferenceURL(gallery.URL, basePath) if err != nil { return models, err } } - err := downloader.GetURI(gallery.URL, func(url string, d []byte) error { + err := downloader.GetURI(gallery.URL, basePath, func(url string, d []byte) error { return yaml.Unmarshal(d, &models) }) if err != nil { diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index e697fcd6..225097c0 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -63,9 +63,9 @@ type PromptTemplate struct { Content string `yaml:"content"` } -func GetGalleryConfigFromURL(url string) (Config, error) { +func GetGalleryConfigFromURL(url string, basePath string) (Config, error) { var config Config - err := downloader.GetURI(url, func(url string, d []byte) error { + err := downloader.GetURI(url, basePath, func(url string, d []byte) error { return yaml.Unmarshal(d, &config) }) if err != nil { diff --git a/pkg/gallery/request_test.go b/pkg/gallery/request_test.go index a9d54e32..af085e81 100644 --- a/pkg/gallery/request_test.go +++ b/pkg/gallery/request_test.go @@ -10,7 +10,7 @@ var _ = Describe("Gallery API tests", func() { Context("requests", func() { It("parses github with a branch", func() { req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"} - e, err := GetGalleryConfigFromURL(req.URL) + e, err := GetGalleryConfigFromURL(req.URL, "") Expect(err).ToNot(HaveOccurred()) Expect(e.Name).To(Equal("gpt4all-j")) }) diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go index d267d846..240fc6bd 100644 --- a/pkg/startup/model_preload.go +++ b/pkg/startup/model_preload.go @@ -20,7 +20,7 @@ func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, model // As a best effort, try to resolve the model from the remote library // if it's not resolved we try with the other method below if modelLibraryURL != "" { - lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL) + lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL, modelPath) if err == nil { if lib[url] != "" { log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url])