From 522659eb59bd9d1855653ed1d68eed551832cd67 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Mon, 1 Jan 2024 08:39:13 -0500 Subject: [PATCH] feat(prepare): allow to specify additional files to download (#1526) --- api/config/config.go | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/api/config/config.go b/api/config/config.go index bfcc7a6b..4e2f4b98 100644 --- a/api/config/config.go +++ b/api/config/config.go @@ -52,6 +52,14 @@ type Config struct { // CUDA // Explicitly enable CUDA or not (some backends might need it) CUDA bool `yaml:"cuda"` + + DownloadFiles []File `yaml:"download_files"` +} + +type File struct { + Filename string `yaml:"filename" json:"filename"` + SHA256 string `yaml:"sha256" json:"sha256"` + URI string `yaml:"uri" json:"uri"` } type VallE struct { @@ -272,10 +280,29 @@ func (cm *ConfigLoader) Preload(modelPath string) error { cm.Lock() defer cm.Unlock() + status := func(fileName, current, total string, percent float64) { + utils.DisplayDownloadFunction(fileName, current, total, percent) + } + log.Info().Msgf("Preloading models from %s", modelPath) for i, config := range cm.configs { + // Download files and verify their SHA + for _, file := range config.DownloadFiles { + log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) + + if err := utils.VerifyPath(file.Filename, modelPath); err != nil { + return err + } + // Create file path + filePath := filepath.Join(modelPath, file.Filename) + + if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil { + return err + } + } + modelURL := config.PredictionOptions.Model modelURL = utils.ConvertURL(modelURL) @@ -285,9 +312,7 @@ func (cm *ConfigLoader) Preload(modelPath string) error { // check if file exists if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { - err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) { - utils.DisplayDownloadFunction(fileName, current, total, percent) - }) + err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status) if err != nil { return err }