diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go index 74a10e9e..9fa890b0 100644 --- a/pkg/startup/model_preload.go +++ b/pkg/startup/model_preload.go @@ -3,6 +3,7 @@ package startup import ( "errors" "fmt" + "net/url" "os" "path/filepath" "strings" @@ -77,19 +78,35 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName) case downloader.LooksLikeURL(url): - log.Debug().Msgf("[startup] resolved model to download: %s", url) + log.Debug().Msgf("[startup] downloading %s", url) - // md5 of model name - md5Name := utils.MD5(url) + // Extract filename from URL + fileName, e := filenameFromUrl(url) + if e != nil || fileName == "" { + fileName = utils.MD5(url) + if strings.HasSuffix(url, ".yaml") || strings.HasSuffix(url, ".yml") { + fileName = fileName + ".yaml" + } + log.Warn().Err(e).Str("url", url).Msg("error extracting filename from URL") + //err = errors.Join(err, e) + //continue + } + + modelPath := filepath.Join(modelPath, fileName) + + if e := utils.VerifyPath(fileName, modelPath); e != nil { + log.Error().Err(e).Str("filepath", modelPath).Msg("error verifying path") + err = errors.Join(err, e) + continue + } // check if file exists - if _, e := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(e, os.ErrNotExist) { - modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) { + if _, e := os.Stat(modelPath); errors.Is(e, os.ErrNotExist) { + e := downloader.DownloadFile(url, modelPath, "", 0, 0, func(fileName, current, total string, percent float64) { utils.DisplayDownloadFunction(fileName, current, total, percent) }) if e != nil { - log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") + log.Error().Err(e).Str("url", url).Str("filepath", modelPath).Msg("error downloading model") err = errors.Join(err, e) } } @@ -150,3 +167,20 @@ func installModel(galleries []config.Gallery, modelName, modelPath string, downl return nil, true } + +func filenameFromUrl(urlstr string) (string, error) { + // strip anything after @ + if strings.Contains(urlstr, "@") { + urlstr = strings.Split(urlstr, "@")[0] + } + + u, err := url.Parse(urlstr) + if err != nil { + return "", fmt.Errorf("error due to parsing url: %w", err) + } + x, err := url.QueryUnescape(u.EscapedPath()) + if err != nil { + return "", fmt.Errorf("error due to escaping: %w", err) + } + return filepath.Base(x), nil +} diff --git a/pkg/startup/model_preload_test.go b/pkg/startup/model_preload_test.go index 939ad1a2..869fcd3e 100644 --- a/pkg/startup/model_preload_test.go +++ b/pkg/startup/model_preload_test.go @@ -20,7 +20,7 @@ var _ = Describe("Preload test", func() { tmpdir, err := os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml" - fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719") + fileName := fmt.Sprintf("%s.yaml", "phi-2") InstallModels([]config.Gallery{}, libraryURL, tmpdir, true, nil, "phi-2") @@ -36,7 +36,7 @@ var _ = Describe("Preload test", func() { tmpdir, err := os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml" - fileName := fmt.Sprintf("%s.yaml", utils.MD5(url)) + fileName := fmt.Sprintf("%s.yaml", "phi-2") InstallModels([]config.Gallery{}, "", tmpdir, true, nil, url) @@ -79,5 +79,19 @@ var _ = Describe("Preload test", func() { Expect(string(content)).To(ContainSubstring("name: mistral-openorca")) }) + It("downloads from urls", func() { + tmpdir, err := os.MkdirTemp("", "") + Expect(err).ToNot(HaveOccurred()) + url := "huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf" + fileName := fmt.Sprintf("%s.gguf", "tinyllama-1.1b-chat-v0.3.Q2_K") + + err = InstallModels([]config.Gallery{}, "", tmpdir, false, nil, url) + Expect(err).ToNot(HaveOccurred()) + + resultFile := filepath.Join(tmpdir, fileName) + + _, err = os.Stat(resultFile) + Expect(err).ToNot(HaveOccurred()) + }) }) })