Saarthak Verma 6765b17acd
feat(downloader): resume partial downloads (#4537)
* feat(resume downloads): add basic tests

Signed-off-by: Saarthak Verma <saarthakverma739@gmail.com>

* test(resume downloads): implement file download tc

Signed-off-by: Saarthak Verma <saarthakverma739@gmail.com>

* test(resume downloads): add resume partial download test

Signed-off-by: Saarthak Verma <saarthakverma739@gmail.com>

* feat(resume downloads): implement resumable downloads for interrupted transfers

- Adds support for resuming partially downloaded files
- Uses HTTP Range header to continue from last byte position
- Maintains download progress across interruptions
- Preserves partial downloads with .partial extension
- Validates SHA256 checksum after completion

Signed-off-by: Saarthak Verma <saarthakverma739@gmail.com>

* fix(resume downloads): incorrect download percent on front end

Signed-off-by: Saarthak Verma <saarthakverma739@gmail.com>

* feat(resume download): add range header check tc

Signed-off-by: Saarthak Verma <saarthakverma739@gmail.com>

* feat(resume download): implement range header check

Signed-off-by: Saarthak Verma <saarthakverma739@gmail.com>

---------

Signed-off-by: Saarthak Verma <saarthakverma739@gmail.com>
2025-01-09 09:22:52 +01:00

412 lines
11 KiB
Go

package downloader
import (
"crypto/sha256"
"errors"
"fmt"
"hash"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/mudler/LocalAI/pkg/oci"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
// URI scheme prefixes recognised by this package.
const (
// HuggingFacePrefix marks a Hugging Face shorthand URI; URI.ResolveURL expands it to an https URL.
HuggingFacePrefix = "huggingface://"
// OCIPrefix marks an OCI image reference; URI.DownloadFile hands these to the oci package.
OCIPrefix = "oci://"
// OllamaPrefix marks an Ollama model reference; URI.DownloadFile hands these to the oci package.
OllamaPrefix = "ollama://"
// HTTPPrefix is the plain HTTP scheme prefix.
HTTPPrefix = "http://"
// HTTPSPrefix is the HTTPS scheme prefix.
HTTPSPrefix = "https://"
// GithubURI marks the "github:org/project/path" shorthand; URI.ResolveURL expands it.
GithubURI = "github:"
// GithubURI2 marks the "github://org/project/path" shorthand; URI.ResolveURL expands it.
GithubURI2 = "github://"
// LocalPrefix marks a local file URL; URI.DownloadWithAuthorizationAndCallback reads these from disk.
LocalPrefix = "file://"
)
// URI is a download location expressed as a string: either a plain URL or one
// of the shorthand forms identified by the prefix constants of this package.
type URI string
// DownloadWithCallback retrieves the content behind uri and passes it to f.
// It delegates to DownloadWithAuthorizationAndCallback with an empty
// authorization value.
func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error {
	const noAuthorization = ""
	return uri.DownloadWithAuthorizationAndCallback(basePath, noAuthorization, f)
}
// DownloadWithAuthorizationAndCallback resolves uri, reads the whole resource
// into memory and calls f with the resolved URL and the content.
//
// URLs starting with LocalPrefix are read from disk. Symlinks in both the
// target path and basePath are resolved first, and the read is refused unless
// utils.InTrustedRoot accepts the resolved file for the resolved basePath.
//
// Every other URL is fetched with an HTTP GET through http.DefaultClient. A
// non-empty authorization is sent as the Authorization header. Responses with
// a status code of 400 or above are reported as an error instead of being
// passed to f.
//
// The result is the first error encountered, or whatever f returns.
func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorization string, f func(url string, i []byte) error) error {
	url := uri.ResolveURL()

	if strings.HasPrefix(url, LocalPrefix) {
		rawURL := strings.TrimPrefix(url, LocalPrefix)
		// Follow symbolic links so the trust check below runs against the
		// real location of the file; non-symlink paths come back unmodified.
		resolvedFile, err := filepath.EvalSymlinks(rawURL)
		if err != nil {
			return err
		}
		resolvedBasePath, err := filepath.EvalSymlinks(basePath)
		if err != nil {
			return err
		}
		// Refuse to read local files that are not rooted in basePath.
		err = utils.InTrustedRoot(resolvedFile, resolvedBasePath)
		if err != nil {
			log.Debug().Str("resolvedFile", resolvedFile).Str("basePath", basePath).Msg("downloader.GetURI blocked an attempt to ready a file url outside of basePath")
			return err
		}
		body, err := os.ReadFile(resolvedFile)
		if err != nil {
			return err
		}
		// Hand the raw file content to the caller.
		return f(url, body)
	}

	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return err
	}
	if authorization != "" {
		req.Header.Add("Authorization", authorization)
	}

	response, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	// An error page must not be mistaken for the requested content.
	if response.StatusCode >= 400 {
		return fmt.Errorf("failed to download url %q, invalid status code %d", url, response.StatusCode)
	}

	body, err := io.ReadAll(response.Body)
	if err != nil {
		return err
	}
	// Hand the raw response body to the caller.
	return f(url, body)
}
// FilenameFromUrl derives a local file name for u.
//
// The name produced by filenameFromUrl is used when that helper succeeds with
// a non-empty result. Otherwise the name is utils.MD5 of the whole URI, with
// ".yaml" appended when the URI ends in ".yaml" or ".yml". The returned error
// is always nil.
func (u URI) FilenameFromUrl() (string, error) {
	raw := string(u)

	if name, err := filenameFromUrl(raw); err == nil && name != "" {
		return name, nil
	}

	fallback := utils.MD5(raw)
	switch {
	case strings.HasSuffix(raw, ".yaml"), strings.HasSuffix(raw, ".yml"):
		fallback += ".yaml"
	}
	return fallback, nil
}
// filenameFromUrl extracts the file name (the last path element) from urlstr.
//
// Everything from the first "@" onwards is dropped before parsing, so a
// trailing "@branch" qualifier does not end up in the name. The path is
// percent-decoded with url.PathUnescape, which keeps a literal "+" intact
// (query-style unescaping would turn it into a space).
//
// An empty name with a nil error is returned when the URL has no usable file
// name: filepath.Base yields ".", ".." or a bare separator for such paths, and
// none of those may be used as a file name. A non-nil error is returned when
// the URL cannot be parsed or its path cannot be unescaped.
func filenameFromUrl(urlstr string) (string, error) {
	// strip anything after @
	if at := strings.Index(urlstr, "@"); at >= 0 {
		urlstr = urlstr[:at]
	}

	u, err := url.Parse(urlstr)
	if err != nil {
		return "", fmt.Errorf("error due to parsing url: %w", err)
	}

	x, err := url.PathUnescape(u.EscapedPath())
	if err != nil {
		return "", fmt.Errorf("error due to escaping: %w", err)
	}

	name := filepath.Base(x)
	switch name {
	case ".", "..", string(filepath.Separator):
		// Not a file name: report "nothing found" instead of a directory reference.
		return "", nil
	}
	return name, nil
}
// LooksLikeURL reports whether u starts with one of the remote prefixes known
// to this package: http, https, huggingface, github (both spellings), ollama
// or oci. LocalPrefix is not among them.
func (u URI) LooksLikeURL() bool {
	candidate := string(u)
	remotePrefixes := []string{
		HTTPPrefix,
		HTTPSPrefix,
		HuggingFacePrefix,
		GithubURI,
		OllamaPrefix,
		OCIPrefix,
		GithubURI2,
	}
	for _, prefix := range remotePrefixes {
		if strings.HasPrefix(candidate, prefix) {
			return true
		}
	}
	return false
}
// LooksLikeOCI reports whether s starts with OCIPrefix or OllamaPrefix.
func (s URI) LooksLikeOCI() bool {
	raw := string(s)
	if strings.HasPrefix(raw, OCIPrefix) {
		return true
	}
	return strings.HasPrefix(raw, OllamaPrefix)
}
// ResolveURL expands the shorthand schemes understood by this package into a
// plain https URL:
//
//   - "github:org/project/path[@branch]" and "github://org/project/path[@branch]"
//     become https://raw.githubusercontent.com/org/project/branch/path
//   - "huggingface://owner/repo/path[@branch]" (or "owner/repo@branch/path")
//     becomes https://huggingface.co/owner/repo/resolve/branch/path
//
// The branch defaults to "main". Any other URI is returned unchanged, and so
// is a shorthand URI that lacks the required path segments.
func (s URI) ResolveURL() string {
	raw := string(s)
	switch {
	case strings.HasPrefix(raw, GithubURI2):
		if resolved, ok := githubRawURL(strings.TrimPrefix(raw, GithubURI2)); ok {
			return resolved
		}
	case strings.HasPrefix(raw, GithubURI):
		if resolved, ok := githubRawURL(strings.TrimPrefix(raw, GithubURI)); ok {
			return resolved
		}
	case strings.HasPrefix(raw, HuggingFacePrefix):
		if resolved, ok := huggingFaceResolveURL(strings.TrimPrefix(raw, HuggingFacePrefix)); ok {
			return resolved
		}
	}
	return raw
}

// githubRawURL turns "org/project/path[@branch]" into a raw.githubusercontent.com
// URL. It reports false when org or project is missing.
func githubRawURL(spec string) (string, bool) {
	refParts := strings.Split(spec, "@")
	branch := "main"
	if len(refParts) > 1 {
		branch = refParts[1]
	}

	segments := strings.Split(refParts[0], "/")
	if len(segments) < 2 {
		return "", false
	}
	org, project := segments[0], segments[1]
	projectPath := strings.Join(segments[2:], "/")
	return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath), true
}

// huggingFaceResolveURL turns "owner/repo/path[@branch]" into a huggingface.co
// "resolve" URL, e.g.
// TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main ->
// https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf
// It reports false when owner, repo or path is missing.
func huggingFaceResolveURL(spec string) (string, bool) {
	segments := strings.Split(spec, "/")
	if len(segments) < 3 {
		return "", false
	}
	owner, repo := segments[0], segments[1]
	// Keep sub-directories: everything after owner/repo is the file path.
	file := strings.Join(segments[2:], "/")
	branch := "main"

	// Form "owner/repo@branch/path".
	if at := strings.Index(repo, "@"); at >= 0 {
		if rev := repo[at+1:]; rev != "" {
			branch = rev
		}
		repo = repo[:at]
	}
	// Form "owner/repo/path@branch"; an explicit suffix on the file wins.
	if at := strings.Index(file, "@"); at >= 0 {
		if rev := file[at+1:]; rev != "" {
			branch = rev
		}
		file = file[:at]
	}

	return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, file), true
}
// removePartialFile deletes tmpFilePath when os.Stat can find it.
//
// If Stat fails for any reason nothing is removed and nil is returned. A
// failed removal is logged as a warning and returned to the caller.
func removePartialFile(tmpFilePath string) error {
	if _, statErr := os.Stat(tmpFilePath); statErr != nil {
		return nil
	}

	log.Debug().Msgf("Removing temporary file %s", tmpFilePath)
	removeErr := os.Remove(tmpFilePath)
	if removeErr == nil {
		return nil
	}

	wrapped := fmt.Errorf("failed to remove temporary download file %s: %v", tmpFilePath, removeErr)
	log.Warn().Msg(wrapped.Error())
	return wrapped
}
// calculateHashForPartialFile feeds everything readable from file, starting at
// its current read offset, into a new SHA-256 hash. The hash itself (not a
// digest) is returned so the caller can keep writing further data into it.
// On a read error the returned hash is nil.
func calculateHashForPartialFile(file *os.File) (hash.Hash, error) {
	digest := sha256.New()
	if _, copyErr := io.Copy(digest, file); copyErr != nil {
		return nil, copyErr
	}
	return digest, nil
}
// checkSeverSupportsRangeHeader sends an HTTP HEAD request to the resolved URL
// and reports whether the response's Accept-Ranges header is exactly "bytes".
// A failed request yields false together with the request error.
func (uri URI) checkSeverSupportsRangeHeader() (bool, error) {
	headResp, err := http.Head(uri.ResolveURL())
	if err != nil {
		return false, err
	}
	defer headResp.Body.Close()

	acceptRanges := headResp.Header.Get("Accept-Ranges")
	return acceptRanges == "bytes", nil
}
// DownloadFile downloads the resource identified by uri to filePath.
//
// OCI and Ollama URIs are handed to the oci package. Every other URI is
// fetched with an HTTP GET through http.DefaultClient:
//
//   - If filePath already exists it is kept when sha is empty or equals the
//     file's SHA256; otherwise the file is deleted and downloaded again.
//   - Data is written to filePath+".partial". When that file is already
//     present and the server advertises byte-range support, the transfer
//     resumes at the partial file's size using a Range header. If the server
//     does not support ranges, or answers the ranged request with the full
//     body, the partial file is discarded and the download starts over.
//   - When sha is non-empty, the SHA256 of the complete content is compared
//     with it before the partial file is renamed to filePath. On mismatch the
//     partial file is removed and an error is returned, so a corrupt file
//     never appears under the final name.
//   - Once in place, the file is extracted into its own directory when
//     utils.IsArchive reports it as an archive.
//
// fileN, total and downloadStatus are forwarded to the progressWriter that
// observes the transfer.
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
	url := uri.ResolveURL()
	if uri.LooksLikeOCI() {
		progressStatus := func(desc ocispec.Descriptor) io.Writer {
			return &progressWriter{
				fileName:       filePath,
				total:          desc.Size,
				hash:           sha256.New(),
				fileNo:         fileN,
				totalFiles:     total,
				downloadStatus: downloadStatus,
			}
		}

		if strings.HasPrefix(url, OllamaPrefix) {
			url = strings.TrimPrefix(url, OllamaPrefix)
			return oci.OllamaFetchModel(url, filePath, progressStatus)
		}

		url = strings.TrimPrefix(url, OCIPrefix)
		img, err := oci.GetImage(url, "", nil, nil)
		if err != nil {
			return fmt.Errorf("failed to get image %q: %w", url, err)
		}

		return oci.ExtractOCIImage(img, filepath.Dir(filePath))
	}

	// Check if the file already exists
	_, err := os.Stat(filePath)
	if err == nil {
		if sha == "" {
			// SHA is missing, skip downloading
			log.Debug().Msgf("File %q already exists. Skipping download", filePath)
			return nil
		}
		// File exists, verify SHA
		calculatedSHA, err := calculateSHA(filePath)
		if err != nil {
			return fmt.Errorf("failed to calculate SHA for file %q: %w", filePath, err)
		}
		if calculatedSHA == sha {
			// SHA matches, skip downloading
			log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", filePath)
			return nil
		}
		// SHA doesn't match, delete the file and download again
		err = os.Remove(filePath)
		if err != nil {
			return fmt.Errorf("failed to remove existing file %q: %w", filePath, err)
		}
		log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath)
	} else if !os.IsNotExist(err) {
		// Error occurred while checking file existence
		return fmt.Errorf("failed to check file %q existence: %w", filePath, err)
	}

	log.Info().Msgf("Downloading %q", url)

	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return fmt.Errorf("failed to create request for %q: %w", filePath, err)
	}

	// save partial download to dedicated file
	tmpFilePath := filePath + ".partial"
	// resuming records that a Range header was sent, so the response status
	// can be checked against what a ranged request must produce.
	resuming := false
	tmpFileInfo, err := os.Stat(tmpFilePath)
	if err == nil {
		support, err := uri.checkSeverSupportsRangeHeader()
		if err != nil {
			return fmt.Errorf("failed to check if uri server supports range header: %w", err)
		}
		if support {
			req.Header.Set("Range", fmt.Sprintf("bytes=%d-", tmpFileInfo.Size()))
			resuming = true
		} else if err := removePartialFile(tmpFilePath); err != nil {
			return err
		}
	} else if !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("failed to check file %q existence: %w", tmpFilePath, err)
	}

	// Start the request
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to download file %q: %w", filePath, err)
	}
	defer resp.Body.Close()

	if resuming && resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
		// The partial file does not match what the server offers (e.g. it is
		// already as large as the resource). Drop it so the next attempt
		// starts clean instead of failing the same way forever.
		if err := removePartialFile(tmpFilePath); err != nil {
			return err
		}
		return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode)
	}
	if resp.StatusCode >= 400 {
		return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode)
	}
	if resuming && resp.StatusCode != http.StatusPartialContent {
		// The server ignored the Range header and is sending the whole
		// resource; appending it to the partial data would corrupt the file.
		log.Debug().Msgf("Server did not honour range request for %q, restarting download", url)
		if err := removePartialFile(tmpFilePath); err != nil {
			return err
		}
	}

	// Create parent directory
	err = os.MkdirAll(filepath.Dir(filePath), 0750)
	if err != nil {
		return fmt.Errorf("failed to create parent directory for file %q: %w", filePath, err)
	}

	// Create and write file
	outFile, err := os.OpenFile(tmpFilePath, os.O_APPEND|os.O_RDWR|os.O_CREATE, 0644)
	if err != nil {
		return fmt.Errorf("failed to create / open file %q: %w", tmpFilePath, err)
	}
	// Safety net for the early returns below; the success path closes explicitly.
	defer outFile.Close()

	// Seed the hash with the bytes already on disk so the final digest covers
	// the whole file, not only the part transferred in this call.
	partialHash, err := calculateHashForPartialFile(outFile)
	if err != nil {
		return fmt.Errorf("failed to calculate hash for partial file %q: %w", tmpFilePath, err)
	}
	progress := &progressWriter{
		fileName:       tmpFilePath,
		total:          resp.ContentLength,
		hash:           partialHash,
		fileNo:         fileN,
		totalFiles:     total,
		downloadStatus: downloadStatus,
	}
	_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body)
	if err != nil {
		// The partial file is kept on purpose so a later call can resume.
		return fmt.Errorf("failed to write file %q: %w", filePath, err)
	}

	// Close before verifying and renaming: a failed close can mean lost data,
	// and some platforms (notably Windows) refuse to rename an open file.
	if err := outFile.Close(); err != nil {
		return fmt.Errorf("failed to close file %q: %w", tmpFilePath, err)
	}

	if sha != "" {
		// Verify SHA
		calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
		if calculatedSHA != sha {
			log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha)
			// Best effort: removePartialFile logs its own failure, and the
			// mismatch is the error the caller needs to see.
			_ = removePartialFile(tmpFilePath)
			return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha)
		}
	} else {
		log.Debug().Msgf("SHA missing for %q. Skipping validation", filePath)
	}

	err = os.Rename(tmpFilePath, filePath)
	if err != nil {
		return fmt.Errorf("failed to rename temporary file %s -> %s: %w", tmpFilePath, filePath, err)
	}

	log.Info().Msgf("File %q downloaded and verified", filePath)
	if utils.IsArchive(filePath) {
		basePath := filepath.Dir(filePath)
		log.Info().Msgf("File %q is an archive, uncompressing to %s", filePath, basePath)
		if err := utils.ExtractArchive(filePath, basePath); err != nil {
			log.Debug().Msgf("Failed decompressing %q: %s", filePath, err.Error())
			return err
		}
	}

	return nil
}
// formatBytes renders a byte count for display using binary (1024-based)
// units. Values below 1024 are printed as a plain integer followed by " B";
// larger values are printed with one decimal place and a KiB..EiB suffix.
func formatBytes(bytes int64) string {
	const unit = 1024
	const suffixes = "KMGTPE"

	if bytes < unit {
		return strconv.FormatInt(bytes, 10) + " B"
	}

	divisor := int64(unit)
	idx := 0
	remaining := bytes / unit
	// Each pass moves up one unit while the value still needs four or more digits.
	for remaining >= unit {
		remaining /= unit
		divisor *= unit
		idx++
	}

	scaled := float64(bytes) / float64(divisor)
	return fmt.Sprintf("%.1f %ciB", scaled, suffixes[idx])
}
// calculateSHA returns the lowercase hexadecimal SHA-256 digest of the file at
// filePath. An error from opening or reading the file is returned as is,
// together with an empty string.
func calculateSHA(filePath string) (string, error) {
	src, openErr := os.Open(filePath)
	if openErr != nil {
		return "", openErr
	}
	defer src.Close()

	digest := sha256.New()
	_, copyErr := io.Copy(digest, src)
	if copyErr != nil {
		return "", copyErr
	}

	sum := digest.Sum(nil)
	return fmt.Sprintf("%x", sum), nil
}