package downloader import ( "crypto/sha256" "errors" "fmt" "hash" "io" "net/http" "net/url" "os" "path/filepath" "strconv" "strings" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/mudler/LocalAI/pkg/oci" "github.com/mudler/LocalAI/pkg/utils" "github.com/rs/zerolog/log" ) const ( HuggingFacePrefix = "huggingface://" OCIPrefix = "oci://" OllamaPrefix = "ollama://" HTTPPrefix = "http://" HTTPSPrefix = "https://" GithubURI = "github:" GithubURI2 = "github://" LocalPrefix = "file://" ) type URI string func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error { return uri.DownloadWithAuthorizationAndCallback(basePath, "", f) } func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorization string, f func(url string, i []byte) error) error { url := uri.ResolveURL() if strings.HasPrefix(url, LocalPrefix) { rawURL := strings.TrimPrefix(url, LocalPrefix) // checks if the file is symbolic, and resolve if so - otherwise, this function returns the path unmodified. resolvedFile, err := filepath.EvalSymlinks(rawURL) if err != nil { return err } resolvedBasePath, err := filepath.EvalSymlinks(basePath) if err != nil { return err } // Check if the local file is rooted in basePath err = utils.InTrustedRoot(resolvedFile, resolvedBasePath) if err != nil { log.Debug().Str("resolvedFile", resolvedFile).Str("basePath", basePath).Msg("downloader.GetURI blocked an attempt to ready a file url outside of basePath") return err } // Read the response body body, err := os.ReadFile(resolvedFile) if err != nil { return err } // Unmarshal YAML data into a struct return f(url, body) } // Send a GET request to the URL req, err := http.NewRequest("GET", url, nil) if err != nil { return err } if authorization != "" { req.Header.Add("Authorization", authorization) } response, err := http.DefaultClient.Do(req) if err != nil { return err } defer response.Body.Close() // Read the response body body, err := io.ReadAll(response.Body) if err != nil { return err } // Unmarshal YAML data into a struct return f(url, body) } func (u URI) FilenameFromUrl() (string, error) { f, err := filenameFromUrl(string(u)) if err != nil || f == "" { f = utils.MD5(string(u)) if strings.HasSuffix(string(u), ".yaml") || strings.HasSuffix(string(u), ".yml") { f = f + ".yaml" } err = nil } return f, err } func filenameFromUrl(urlstr string) (string, error) { // strip anything after @ if strings.Contains(urlstr, "@") { urlstr = strings.Split(urlstr, "@")[0] } u, err := url.Parse(urlstr) if err != nil { return "", fmt.Errorf("error due to parsing url: %w", err) } x, err := url.QueryUnescape(u.EscapedPath()) if err != nil { return "", fmt.Errorf("error due to escaping: %w", err) } return filepath.Base(x), nil } func (u URI) LooksLikeURL() bool { return strings.HasPrefix(string(u), HTTPPrefix) || strings.HasPrefix(string(u), HTTPSPrefix) || strings.HasPrefix(string(u), HuggingFacePrefix) || strings.HasPrefix(string(u), GithubURI) || strings.HasPrefix(string(u), OllamaPrefix) || strings.HasPrefix(string(u), OCIPrefix) || strings.HasPrefix(string(u), GithubURI2) } func (s URI) LooksLikeOCI() bool { return strings.HasPrefix(string(s), OCIPrefix) || strings.HasPrefix(string(s), OllamaPrefix) } func (s URI) ResolveURL() string { switch { case strings.HasPrefix(string(s), GithubURI2): repository := strings.Replace(string(s), GithubURI2, "", 1) repoParts := strings.Split(repository, "@") branch := "main" if len(repoParts) > 1 { branch = repoParts[1] } repoPath := strings.Split(repoParts[0], "/") org := repoPath[0] project := repoPath[1] projectPath := strings.Join(repoPath[2:], "/") return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) case strings.HasPrefix(string(s), GithubURI): parts := strings.Split(string(s), ":") repoParts := strings.Split(parts[1], "@") branch := "main" if len(repoParts) > 1 { branch = repoParts[1] } repoPath := strings.Split(repoParts[0], "/") org := repoPath[0] project := repoPath[1] projectPath := strings.Join(repoPath[2:], "/") return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) case strings.HasPrefix(string(s), HuggingFacePrefix): repository := strings.Replace(string(s), HuggingFacePrefix, "", 1) // convert repository to a full URL. // e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf owner := strings.Split(repository, "/")[0] repo := strings.Split(repository, "/")[1] branch := "main" if strings.Contains(repo, "@") { branch = strings.Split(repository, "@")[1] } filepath := strings.Split(repository, "/")[2] if strings.Contains(filepath, "@") { filepath = strings.Split(filepath, "@")[0] } return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath) } return string(s) } func removePartialFile(tmpFilePath string) error { _, err := os.Stat(tmpFilePath) if err == nil { log.Debug().Msgf("Removing temporary file %s", tmpFilePath) err = os.Remove(tmpFilePath) if err != nil { err1 := fmt.Errorf("failed to remove temporary download file %s: %v", tmpFilePath, err) log.Warn().Msg(err1.Error()) return err1 } } return nil } func calculateHashForPartialFile(file *os.File) (hash.Hash, error) { hash := sha256.New() _, err := io.Copy(hash, file) if err != nil { return nil, err } return hash, nil } func (uri URI) checkSeverSupportsRangeHeader() (bool, error) { url := uri.ResolveURL() resp, err := http.Head(url) if err != nil { return false, err } defer resp.Body.Close() return resp.Header.Get("Accept-Ranges") == "bytes", nil } func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error { url := uri.ResolveURL() if uri.LooksLikeOCI() { progressStatus := func(desc ocispec.Descriptor) io.Writer { return &progressWriter{ fileName: filePath, total: desc.Size, hash: sha256.New(), fileNo: fileN, totalFiles: total, downloadStatus: downloadStatus, } } if strings.HasPrefix(url, OllamaPrefix) { url = strings.TrimPrefix(url, OllamaPrefix) return oci.OllamaFetchModel(url, filePath, progressStatus) } url = strings.TrimPrefix(url, OCIPrefix) img, err := oci.GetImage(url, "", nil, nil) if err != nil { return fmt.Errorf("failed to get image %q: %v", url, err) } return oci.ExtractOCIImage(img, filepath.Dir(filePath)) } // Check if the file already exists _, err := os.Stat(filePath) if err == nil { // File exists, check SHA if sha != "" { // Verify SHA calculatedSHA, err := calculateSHA(filePath) if err != nil { return fmt.Errorf("failed to calculate SHA for file %q: %v", filePath, err) } if calculatedSHA == sha { // SHA matches, skip downloading log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", filePath) return nil } // SHA doesn't match, delete the file and download again err = os.Remove(filePath) if err != nil { return fmt.Errorf("failed to remove existing file %q: %v", filePath, err) } log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath) } else { // SHA is missing, skip downloading log.Debug().Msgf("File %q already exists. Skipping download", filePath) return nil } } else if !os.IsNotExist(err) { // Error occurred while checking file existence return fmt.Errorf("failed to check file %q existence: %v", filePath, err) } log.Info().Msgf("Downloading %q", url) req, err := http.NewRequest("GET", url, nil) if err != nil { return fmt.Errorf("failed to create request for %q: %v", filePath, err) } // save partial download to dedicated file tmpFilePath := filePath + ".partial" tmpFileInfo, err := os.Stat(tmpFilePath) if err == nil { support, err := uri.checkSeverSupportsRangeHeader() if err != nil { return fmt.Errorf("failed to check if uri server supports range header: %v", err) } if support { startPos := tmpFileInfo.Size() req.Header.Set("Range", fmt.Sprintf("bytes=%d-", startPos)) } else { err := removePartialFile(tmpFilePath) if err != nil { return err } } } else if !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("failed to check file %q existence: %v", filePath, err) } // Start the request resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("failed to download file %q: %v", filePath, err) } defer resp.Body.Close() if resp.StatusCode >= 400 { return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode) } // Create parent directory err = os.MkdirAll(filepath.Dir(filePath), 0750) if err != nil { return fmt.Errorf("failed to create parent directory for file %q: %v", filePath, err) } // Create and write file outFile, err := os.OpenFile(tmpFilePath, os.O_APPEND|os.O_RDWR|os.O_CREATE, 0644) if err != nil { return fmt.Errorf("failed to create / open file %q: %v", tmpFilePath, err) } defer outFile.Close() hash, err := calculateHashForPartialFile(outFile) if err != nil { return fmt.Errorf("failed to calculate hash for partial file") } progress := &progressWriter{ fileName: tmpFilePath, total: resp.ContentLength, hash: hash, fileNo: fileN, totalFiles: total, downloadStatus: downloadStatus, } _, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body) if err != nil { return fmt.Errorf("failed to write file %q: %v", filePath, err) } err = os.Rename(tmpFilePath, filePath) if err != nil { return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err) } if sha != "" { // Verify SHA calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) if calculatedSHA != sha { log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha) return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha) } } else { log.Debug().Msgf("SHA missing for %q. Skipping validation", filePath) } log.Info().Msgf("File %q downloaded and verified", filePath) if utils.IsArchive(filePath) { basePath := filepath.Dir(filePath) log.Info().Msgf("File %q is an archive, uncompressing to %s", filePath, basePath) if err := utils.ExtractArchive(filePath, basePath); err != nil { log.Debug().Msgf("Failed decompressing %q: %s", filePath, err.Error()) return err } } return nil } func formatBytes(bytes int64) string { const unit = 1024 if bytes < unit { return strconv.FormatInt(bytes, 10) + " B" } div, exp := int64(unit), 0 for n := bytes / unit; n >= unit; n /= unit { div *= unit exp++ } return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) } func calculateSHA(filePath string) (string, error) { file, err := os.Open(filePath) if err != nil { return "", err } defer file.Close() hash := sha256.New() if _, err := io.Copy(hash, file); err != nil { return "", err } return fmt.Sprintf("%x", hash.Sum(nil)), nil }