From a36b721ca63436d72d18db7c39df47b506fcaba5 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 2 Aug 2024 20:06:25 +0200 Subject: [PATCH] fix: be consistent in downloading files, check for scanner errors (#3108) * fix(downloader): be consistent in downloading files This PR puts some order in the downloader so that functions are re-used across several places. This fixes an issue with having URIs inside the model YAML file: they would resolve to the MD5 rather than using the filename Signed-off-by: Ettore Di Giacinto * fix(scanner): do raise error only if unsafeFiles are found Fixes: https://github.com/mudler/LocalAI/issues/3114 Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- core/cli/models.go | 4 +- core/cli/util.go | 4 +- core/config/backend_config.go | 27 +++-- core/config/backend_config_loader.go | 10 +- core/dependencies_manager/manager.go | 3 +- core/gallery/gallery.go | 10 +- core/gallery/models.go | 11 +- core/http/app_test.go | 3 +- embedded/embedded.go | 4 +- pkg/downloader/huggingface.go | 49 +++++++++ pkg/downloader/uri.go | 157 ++++++++++----------------- pkg/downloader/uri_test.go | 10 +- pkg/startup/model_preload.go | 52 +++------ 13 files changed, 173 insertions(+), 171 deletions(-) create mode 100644 pkg/downloader/huggingface.go diff --git a/core/cli/models.go b/core/cli/models.go index 03047018..56d13fc7 100644 --- a/core/cli/models.go +++ b/core/cli/models.go @@ -83,7 +83,9 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { return err } - if !downloader.LooksLikeOCI(modelName) { + modelURI := downloader.URI(modelName) + + if !modelURI.LooksLikeOCI() { model := gallery.FindModel(models, modelName, mi.ModelsPath) if model == nil { log.Error().Str("model", modelName).Msg("model not found") diff --git a/core/cli/util.go b/core/cli/util.go index a7204092..b3e545d8 100644 --- a/core/cli/util.go +++ b/core/cli/util.go @@ -86,8 +86,8 @@ func (hfscmd *HFScanCMD) Run(ctx *cliContext.Context) error { 
var errs error = nil for _, uri := range hfscmd.ToScan { log.Info().Str("uri", uri).Msg("scanning specific uri") - scanResults, err := downloader.HuggingFaceScan(uri) - if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) { + scanResults, err := downloader.HuggingFaceScan(downloader.URI(uri)) + if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) { log.Error().Err(err).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("! WARNING ! A known-vulnerable model is included in this repo!") errs = errors.Join(errs, err) } diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 383686cd..b83e1a98 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -8,7 +8,6 @@ import ( "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/functions" - "github.com/mudler/LocalAI/pkg/utils" ) const ( @@ -72,9 +71,9 @@ type BackendConfig struct { } type File struct { - Filename string `yaml:"filename" json:"filename"` - SHA256 string `yaml:"sha256" json:"sha256"` - URI string `yaml:"uri" json:"uri"` + Filename string `yaml:"filename" json:"filename"` + SHA256 string `yaml:"sha256" json:"sha256"` + URI downloader.URI `yaml:"uri" json:"uri"` } type VallE struct { @@ -213,28 +212,32 @@ func (c *BackendConfig) ShouldCallSpecificFunction() bool { // MMProjFileName returns the filename of the MMProj file // If the MMProj is a URL, it will return the MD5 of the URL which is the filename func (c *BackendConfig) MMProjFileName() string { - modelURL := downloader.ConvertURL(c.MMProj) - if downloader.LooksLikeURL(modelURL) { - return utils.MD5(modelURL) + uri := downloader.URI(c.MMProj) + if uri.LooksLikeURL() { + f, _ := uri.FilenameFromUrl() + return f } return c.MMProj } func (c *BackendConfig) IsMMProjURL() bool { - return downloader.LooksLikeURL(downloader.ConvertURL(c.MMProj)) + uri := downloader.URI(c.MMProj) + 
return uri.LooksLikeURL() } func (c *BackendConfig) IsModelURL() bool { - return downloader.LooksLikeURL(downloader.ConvertURL(c.Model)) + uri := downloader.URI(c.Model) + return uri.LooksLikeURL() } // ModelFileName returns the filename of the model // If the model is a URL, it will return the MD5 of the URL which is the filename func (c *BackendConfig) ModelFileName() string { - modelURL := downloader.ConvertURL(c.Model) - if downloader.LooksLikeURL(modelURL) { - return utils.MD5(modelURL) + uri := downloader.URI(c.Model) + if uri.LooksLikeURL() { + f, _ := uri.FilenameFromUrl() + return f } return c.Model diff --git a/core/config/backend_config_loader.go b/core/config/backend_config_loader.go index 283dac52..45fe259e 100644 --- a/core/config/backend_config_loader.go +++ b/core/config/backend_config_loader.go @@ -244,7 +244,7 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error { // Create file path filePath := filepath.Join(modelPath, file.Filename) - if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil { + if err := file.URI.DownloadFile(filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil { return err } } @@ -252,10 +252,10 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error { // If the model is an URL, expand it, and download the file if config.IsModelURL() { modelFileName := config.ModelFileName() - modelURL := downloader.ConvertURL(config.Model) + uri := downloader.URI(config.Model) // check if file exists if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) { - err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status) + err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status) if err != nil { return err } @@ -269,10 +269,10 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error { if config.IsMMProjURL() { modelFileName := 
config.MMProjFileName() - modelURL := downloader.ConvertURL(config.MMProj) + uri := downloader.URI(config.MMProj) // check if file exists if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) { - err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status) + err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status) if err != nil { return err } diff --git a/core/dependencies_manager/manager.go b/core/dependencies_manager/manager.go index b86139e0..8434f721 100644 --- a/core/dependencies_manager/manager.go +++ b/core/dependencies_manager/manager.go @@ -37,7 +37,8 @@ func main() { // download the assets for _, asset := range assets { - if err := downloader.DownloadFile(asset.URL, filepath.Join(destPath, asset.FileName), asset.SHA, 1, 1, utils.DisplayDownloadFunction); err != nil { + uri := downloader.URI(asset.URL) + if err := uri.DownloadFile(filepath.Join(destPath, asset.FileName), asset.SHA, 1, 1, utils.DisplayDownloadFunction); err != nil { panic(err) } } diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index 9288c44f..6ced6244 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -131,7 +131,8 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) ([]*Gal func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) { var refFile string - err := downloader.DownloadAndUnmarshal(url, basePath, func(url string, d []byte) error { + uri := downloader.URI(url) + err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error { refFile = string(d) if len(refFile) == 0 { return fmt.Errorf("invalid reference file at url %s: %s", url, d) @@ -153,8 +154,9 @@ func getGalleryModels(gallery config.Gallery, basePath string) ([]*GalleryModel, return models, err } } + uri := downloader.URI(gallery.URL) - err := downloader.DownloadAndUnmarshal(gallery.URL, basePath, func(url string, d []byte) error 
{ + err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error { return yaml.Unmarshal(d, &models) }) if err != nil { @@ -252,8 +254,8 @@ func SafetyScanGalleryModels(galleries []config.Gallery, basePath string) error func SafetyScanGalleryModel(galleryModel *GalleryModel) error { for _, file := range galleryModel.AdditionalFiles { - scanResults, err := downloader.HuggingFaceScan(file.URI) - if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) { + scanResults, err := downloader.HuggingFaceScan(downloader.URI(file.URI)) + if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) { log.Error().Str("model", galleryModel.Name).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("Contains unsafe file(s)!") return err } diff --git a/core/gallery/models.go b/core/gallery/models.go index 32460a9c..dec6312e 100644 --- a/core/gallery/models.go +++ b/core/gallery/models.go @@ -68,7 +68,8 @@ type PromptTemplate struct { func GetGalleryConfigFromURL(url string, basePath string) (Config, error) { var config Config - err := downloader.DownloadAndUnmarshal(url, basePath, func(url string, d []byte) error { + uri := downloader.URI(url) + err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error { return yaml.Unmarshal(d, &config) }) if err != nil { @@ -118,14 +119,14 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides filePath := filepath.Join(basePath, file.Filename) if enforceScan { - scanResults, err := downloader.HuggingFaceScan(file.URI) - if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) { + scanResults, err := downloader.HuggingFaceScan(downloader.URI(file.URI)) + if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) { log.Error().Str("model", config.Name).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("Contains unsafe file(s)!") return err } } - - if err := 
downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil { + uri := downloader.URI(file.URI) + if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil { return err } } diff --git a/core/http/app_test.go b/core/http/app_test.go index 3fb16581..b21ad25a 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -73,8 +73,9 @@ func getModelStatus(url string) (response map[string]interface{}) { } func getModels(url string) (response []gallery.GalleryModel) { + uri := downloader.URI(url) // TODO: No tests currently seem to exercise file:// urls. Fix? - downloader.DownloadAndUnmarshal(url, "", func(url string, i []byte) error { + uri.DownloadAndUnmarshal("", func(url string, i []byte) error { // Unmarshal YAML data into a struct return json.Unmarshal(i, &response) }) diff --git a/embedded/embedded.go b/embedded/embedded.go index d5fd72df..672c32ed 100644 --- a/embedded/embedded.go +++ b/embedded/embedded.go @@ -38,8 +38,8 @@ func init() { func GetRemoteLibraryShorteners(url string, basePath string) (map[string]string, error) { remoteLibrary := map[string]string{} - - err := downloader.DownloadAndUnmarshal(url, basePath, func(_ string, i []byte) error { + uri := downloader.URI(url) + err := uri.DownloadAndUnmarshal(basePath, func(_ string, i []byte) error { return yaml.Unmarshal(i, &remoteLibrary) }) if err != nil { diff --git a/pkg/downloader/huggingface.go b/pkg/downloader/huggingface.go new file mode 100644 index 00000000..34ba9bd9 --- /dev/null +++ b/pkg/downloader/huggingface.go @@ -0,0 +1,49 @@ +package downloader + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" +) + +type HuggingFaceScanResult struct { + RepositoryId string `json:"repositoryId"` + Revision string `json:"revision"` + HasUnsafeFiles bool `json:"hasUnsafeFile"` + ClamAVInfectedFiles []string `json:"clamAVInfectedFiles"` + DangerousPickles []string 
`json:"dangerousPickles"` + ScansDone bool `json:"scansDone"` +} + +var ErrNonHuggingFaceFile = errors.New("not a huggingface repo") +var ErrUnsafeFilesFound = errors.New("unsafe files found") + +func HuggingFaceScan(uri URI) (*HuggingFaceScanResult, error) { + cleanParts := strings.Split(uri.ResolveURL(), "/") + if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" { + return nil, ErrNonHuggingFaceFile + } + results, err := http.Get(fmt.Sprintf("https://huggingface.co/api/models/%s/%s/scan", cleanParts[3], cleanParts[4])) + if err != nil { + return nil, err + } + if results.StatusCode != 200 { + return nil, fmt.Errorf("unexpected status code during HuggingFaceScan: %d", results.StatusCode) + } + scanResult := &HuggingFaceScanResult{} + bodyBytes, err := io.ReadAll(results.Body) + if err != nil { + return nil, err + } + err = json.Unmarshal(bodyBytes, scanResult) + if err != nil { + return nil, err + } + if scanResult.HasUnsafeFiles { + return scanResult, ErrUnsafeFilesFound + } + return scanResult, nil +} diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 1f88bbb1..7fedd646 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -2,12 +2,10 @@ package downloader import ( "crypto/sha256" - "encoding/base64" - "encoding/json" - "errors" "fmt" "io" "net/http" + "net/url" "os" "path/filepath" "strconv" @@ -28,13 +26,16 @@ const ( HTTPSPrefix = "https://" GithubURI = "github:" GithubURI2 = "github://" + LocalPrefix = "file://" ) -func DownloadAndUnmarshal(url string, basePath string, f func(url string, i []byte) error) error { - url = ConvertURL(url) +type URI string - if strings.HasPrefix(url, "file://") { - rawURL := strings.TrimPrefix(url, "file://") +func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte) error) error { + url := uri.ResolveURL() + + if strings.HasPrefix(url, LocalPrefix) { + rawURL := strings.TrimPrefix(url, LocalPrefix) // checks if the file is symbolic, and resolve if so - otherwise, this 
function returns the path unmodified. resolvedFile, err := filepath.EvalSymlinks(rawURL) if err != nil { @@ -78,24 +79,54 @@ func DownloadAndUnmarshal(url string, basePath string, f func(url string, i []by return f(url, body) } -func LooksLikeURL(s string) bool { - return strings.HasPrefix(s, HTTPPrefix) || - strings.HasPrefix(s, HTTPSPrefix) || - strings.HasPrefix(s, HuggingFacePrefix) || - strings.HasPrefix(s, GithubURI) || - strings.HasPrefix(s, OllamaPrefix) || - strings.HasPrefix(s, OCIPrefix) || - strings.HasPrefix(s, GithubURI2) +func (u URI) FilenameFromUrl() (string, error) { + f, err := filenameFromUrl(string(u)) + if err != nil || f == "" { + f = utils.MD5(string(u)) + if strings.HasSuffix(string(u), ".yaml") || strings.HasSuffix(string(u), ".yml") { + f = f + ".yaml" + } + err = nil + } + + return f, err } -func LooksLikeOCI(s string) bool { - return strings.HasPrefix(s, OCIPrefix) || strings.HasPrefix(s, OllamaPrefix) +func filenameFromUrl(urlstr string) (string, error) { + // strip anything after @ + if strings.Contains(urlstr, "@") { + urlstr = strings.Split(urlstr, "@")[0] + } + + u, err := url.Parse(urlstr) + if err != nil { + return "", fmt.Errorf("error due to parsing url: %w", err) + } + x, err := url.QueryUnescape(u.EscapedPath()) + if err != nil { + return "", fmt.Errorf("error due to escaping: %w", err) + } + return filepath.Base(x), nil } -func ConvertURL(s string) string { +func (u URI) LooksLikeURL() bool { + return strings.HasPrefix(string(u), HTTPPrefix) || + strings.HasPrefix(string(u), HTTPSPrefix) || + strings.HasPrefix(string(u), HuggingFacePrefix) || + strings.HasPrefix(string(u), GithubURI) || + strings.HasPrefix(string(u), OllamaPrefix) || + strings.HasPrefix(string(u), OCIPrefix) || + strings.HasPrefix(string(u), GithubURI2) +} + +func (s URI) LooksLikeOCI() bool { + return strings.HasPrefix(string(s), OCIPrefix) || strings.HasPrefix(string(s), OllamaPrefix) +} + +func (s URI) ResolveURL() string { switch { - case 
strings.HasPrefix(s, GithubURI2): - repository := strings.Replace(s, GithubURI2, "", 1) + case strings.HasPrefix(string(s), GithubURI2): + repository := strings.Replace(string(s), GithubURI2, "", 1) repoParts := strings.Split(repository, "@") branch := "main" @@ -110,8 +141,8 @@ func ConvertURL(s string) string { projectPath := strings.Join(repoPath[2:], "/") return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) - case strings.HasPrefix(s, GithubURI): - parts := strings.Split(s, ":") + case strings.HasPrefix(string(s), GithubURI): + parts := strings.Split(string(s), ":") repoParts := strings.Split(parts[1], "@") branch := "main" @@ -125,8 +156,8 @@ func ConvertURL(s string) string { projectPath := strings.Join(repoPath[2:], "/") return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) - case strings.HasPrefix(s, HuggingFacePrefix): - repository := strings.Replace(s, HuggingFacePrefix, "", 1) + case strings.HasPrefix(string(s), HuggingFacePrefix): + repository := strings.Replace(string(s), HuggingFacePrefix, "", 1) // convert repository to a full URL. // e.g. 
TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf owner := strings.Split(repository, "/")[0] @@ -144,7 +175,7 @@ func ConvertURL(s string) string { return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath) } - return s + return string(s) } func removePartialFile(tmpFilePath string) error { @@ -161,9 +192,9 @@ func removePartialFile(tmpFilePath string) error { return nil } -func DownloadFile(url string, filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error { - url = ConvertURL(url) - if LooksLikeOCI(url) { +func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error { + url := uri.ResolveURL() + if uri.LooksLikeOCI() { progressStatus := func(desc ocispec.Descriptor) io.Writer { return &progressWriter{ fileName: filePath, @@ -298,37 +329,6 @@ func DownloadFile(url string, filePath, sha string, fileN, total int, downloadSt return nil } -// this function check if the string is an URL, if it's an URL downloads the image in memory -// encodes it in base64 and returns the base64 string -func GetBase64Image(s string) (string, error) { - if strings.HasPrefix(s, "http") { - // download the image - resp, err := http.Get(s) - if err != nil { - return "", err - } - defer resp.Body.Close() - - // read the image data into memory - data, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - // encode the image data in base64 - encoded := base64.StdEncoding.EncodeToString(data) - - // return the base64 string - return encoded, nil - } - - // if the string instead is prefixed with "data:image/jpeg;base64,", drop it - if strings.HasPrefix(s, "data:image/jpeg;base64,") { - return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil - } - return "", fmt.Errorf("not valid string") -} - func 
formatBytes(bytes int64) string { const unit = 1024 if bytes < unit { @@ -356,42 +356,3 @@ func calculateSHA(filePath string) (string, error) { return fmt.Sprintf("%x", hash.Sum(nil)), nil } - -type HuggingFaceScanResult struct { - RepositoryId string `json:"repositoryId"` - Revision string `json:"revision"` - HasUnsafeFiles bool `json:"hasUnsafeFile"` - ClamAVInfectedFiles []string `json:"clamAVInfectedFiles"` - DangerousPickles []string `json:"dangerousPickles"` - ScansDone bool `json:"scansDone"` -} - -var ErrNonHuggingFaceFile = errors.New("not a huggingface repo") -var ErrUnsafeFilesFound = errors.New("unsafe files found") - -func HuggingFaceScan(uri string) (*HuggingFaceScanResult, error) { - cleanParts := strings.Split(ConvertURL(uri), "/") - if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" { - return nil, ErrNonHuggingFaceFile - } - results, err := http.Get(fmt.Sprintf("https://huggingface.co/api/models/%s/%s/scan", cleanParts[3], cleanParts[4])) - if err != nil { - return nil, err - } - if results.StatusCode != 200 { - return nil, fmt.Errorf("unexpected status code during HuggingFaceScan: %d", results.StatusCode) - } - scanResult := &HuggingFaceScanResult{} - bodyBytes, err := io.ReadAll(results.Body) - if err != nil { - return nil, err - } - err = json.Unmarshal(bodyBytes, scanResult) - if err != nil { - return nil, err - } - if scanResult.HasUnsafeFiles { - return scanResult, ErrUnsafeFilesFound - } - return scanResult, nil -} diff --git a/pkg/downloader/uri_test.go b/pkg/downloader/uri_test.go index 66a4cb4e..21a093a9 100644 --- a/pkg/downloader/uri_test.go +++ b/pkg/downloader/uri_test.go @@ -9,24 +9,28 @@ import ( var _ = Describe("Gallery API tests", func() { Context("URI", func() { It("parses github with a branch", func() { + uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml") Expect( - DownloadAndUnmarshal("github:go-skynet/model-gallery/gpt4all-j.yaml", "", func(url string, i []byte) error { + uri.DownloadAndUnmarshal("", 
func(url string, i []byte) error { Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) return nil }), ).ToNot(HaveOccurred()) }) It("parses github without a branch", func() { + uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml@main") + Expect( - DownloadAndUnmarshal("github:go-skynet/model-gallery/gpt4all-j.yaml@main", "", func(url string, i []byte) error { + uri.DownloadAndUnmarshal("", func(url string, i []byte) error { Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) return nil }), ).ToNot(HaveOccurred()) }) It("parses github with urls", func() { + uri := URI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml") Expect( - DownloadAndUnmarshal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", "", func(url string, i []byte) error { + uri.DownloadAndUnmarshal("", func(url string, i []byte) error { Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) return nil }), diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go index 9fa890b0..a445b10e 100644 --- a/pkg/startup/model_preload.go +++ b/pkg/startup/model_preload.go @@ -3,7 +3,6 @@ package startup import ( "errors" "fmt" - "net/url" "os" "path/filepath" "strings" @@ -23,21 +22,21 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath // create an error that groups all errors var err error - for _, url := range models { + lib, _ := embedded.GetRemoteLibraryShorteners(modelLibraryURL, modelPath) + for _, url := range models { // As a best effort, try to resolve the model from the remote library // if it's not resolved we try with the other method below if modelLibraryURL != "" { - lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL, modelPath) - if err == nil { - if lib[url] != "" { - log.Debug().Msgf("[startup] model configuration is 
defined remotely: %s (%s)", url, lib[url]) - url = lib[url] - } + if lib[url] != "" { + log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url]) + url = lib[url] } } url = embedded.ModelShortURL(url) + uri := downloader.URI(url) + switch { case embedded.ExistsInModelsLibrary(url): modelYAML, e := embedded.ResolveContent(url) @@ -55,7 +54,7 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath log.Error().Err(e).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition") err = errors.Join(err, e) } - case downloader.LooksLikeOCI(url): + case uri.LooksLikeOCI(): log.Debug().Msgf("[startup] resolved OCI model to download: %s", url) // convert OCI image name to a file name. @@ -67,7 +66,7 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath // check if file exists if _, e := os.Stat(filepath.Join(modelPath, ociName)); errors.Is(e, os.ErrNotExist) { modelDefinitionFilePath := filepath.Join(modelPath, ociName) - e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) { + e := uri.DownloadFile(modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) { utils.DisplayDownloadFunction(fileName, current, total, percent) }) if e != nil { @@ -77,19 +76,15 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath } log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName) - case downloader.LooksLikeURL(url): + case uri.LooksLikeURL(): log.Debug().Msgf("[startup] downloading %s", url) // Extract filename from URL - fileName, e := filenameFromUrl(url) - if e != nil || fileName == "" { - fileName = utils.MD5(url) - if strings.HasSuffix(url, ".yaml") || strings.HasSuffix(url, ".yml") { - fileName = fileName + ".yaml" - } + fileName, e := uri.FilenameFromUrl() + if e != nil { log.Warn().Err(e).Str("url", url).Msg("error 
extracting filename from URL") - //err = errors.Join(err, e) - //continue + err = errors.Join(err, e) + continue } modelPath := filepath.Join(modelPath, fileName) @@ -102,7 +97,7 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath // check if file exists if _, e := os.Stat(modelPath); errors.Is(e, os.ErrNotExist) { - e := downloader.DownloadFile(url, modelPath, "", 0, 0, func(fileName, current, total string, percent float64) { + e := uri.DownloadFile(modelPath, "", 0, 0, func(fileName, current, total string, percent float64) { utils.DisplayDownloadFunction(fileName, current, total, percent) }) if e != nil { @@ -167,20 +162,3 @@ func installModel(galleries []config.Gallery, modelName, modelPath string, downl return nil, true } - -func filenameFromUrl(urlstr string) (string, error) { - // strip anything after @ - if strings.Contains(urlstr, "@") { - urlstr = strings.Split(urlstr, "@")[0] - } - - u, err := url.Parse(urlstr) - if err != nil { - return "", fmt.Errorf("error due to parsing url: %w", err) - } - x, err := url.QueryUnescape(u.EscapedPath()) - if err != nil { - return "", fmt.Errorf("error due to escaping: %w", err) - } - return filepath.Base(x), nil -}