2024-01-05 22:16:33 +00:00
package downloader
2023-06-24 06:18:17 +00:00
import (
2023-12-18 17:58:44 +00:00
"crypto/sha256"
2023-06-26 10:25:38 +00:00
"fmt"
2023-07-30 07:47:22 +00:00
"io"
2023-06-24 06:18:17 +00:00
"net/http"
2024-08-02 18:06:25 +00:00
"net/url"
2023-07-30 07:47:22 +00:00
"os"
2023-09-02 07:00:44 +00:00
"path/filepath"
2023-12-18 17:58:44 +00:00
"strconv"
2023-06-24 06:18:17 +00:00
"strings"
2023-12-18 17:58:44 +00:00
2024-06-22 06:17:41 +00:00
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
2024-06-23 08:24:36 +00:00
"github.com/mudler/LocalAI/pkg/oci"
"github.com/mudler/LocalAI/pkg/utils"
2023-12-18 17:58:44 +00:00
"github.com/rs/zerolog/log"
2023-06-24 06:18:17 +00:00
)
2024-01-05 22:16:33 +00:00
// URI scheme prefixes understood by the downloader.
const (
	// Registry-style schemes; DownloadFile routes these through the oci package.
	HuggingFacePrefix = "huggingface://"
	OCIPrefix         = "oci://"
	OllamaPrefix      = "ollama://"

	// Plain web URLs.
	HTTPPrefix  = "http://"
	HTTPSPrefix = "https://"

	// GitHub shorthands; both spellings are accepted. Note that GithubURI2
	// also starts with GithubURI.
	GithubURI  = "github:"
	GithubURI2 = "github://"

	// Local files.
	LocalPrefix = "file://"
)

// URI references a resource the downloader can fetch: a plain http(s) URL,
// a file:// path, or one of the shorthand schemes declared above.
type URI string
2023-06-26 10:25:38 +00:00
2024-09-24 07:32:48 +00:00
// DownloadWithCallback fetches the resource referenced by u without sending
// an Authorization header, then passes its URL and raw contents to f.
// It is DownloadWithAuthorizationAndCallback with an empty authorization.
func (u URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error {
	const noAuthorization = ""
	return u.DownloadWithAuthorizationAndCallback(basePath, noAuthorization, f)
}
func ( uri URI ) DownloadWithAuthorizationAndCallback ( basePath string , authorization string , f func ( url string , i [ ] byte ) error ) error {
2024-08-02 18:06:25 +00:00
url := uri . ResolveURL ( )
if strings . HasPrefix ( url , LocalPrefix ) {
rawURL := strings . TrimPrefix ( url , LocalPrefix )
2023-09-02 07:00:44 +00:00
// checks if the file is symbolic, and resolve if so - otherwise, this function returns the path unmodified.
resolvedFile , err := filepath . EvalSymlinks ( rawURL )
if err != nil {
return err
}
2024-06-05 06:45:24 +00:00
resolvedBasePath , err := filepath . EvalSymlinks ( basePath )
if err != nil {
return err
}
2024-06-04 14:32:47 +00:00
// Check if the local file is rooted in basePath
2024-06-05 06:45:24 +00:00
err = utils . InTrustedRoot ( resolvedFile , resolvedBasePath )
2024-06-04 14:32:47 +00:00
if err != nil {
2024-06-05 06:45:24 +00:00
log . Debug ( ) . Str ( "resolvedFile" , resolvedFile ) . Str ( "basePath" , basePath ) . Msg ( "downloader.GetURI blocked an attempt to ready a file url outside of basePath" )
2024-06-04 14:32:47 +00:00
return err
}
2023-06-24 06:18:17 +00:00
// Read the response body
2023-09-02 07:00:44 +00:00
body , err := os . ReadFile ( resolvedFile )
2023-06-24 06:18:17 +00:00
if err != nil {
return err
}
// Unmarshal YAML data into a struct
2023-06-26 10:25:38 +00:00
return f ( url , body )
2023-06-24 06:18:17 +00:00
}
// Send a GET request to the URL
2024-09-24 07:32:48 +00:00
req , err := http . NewRequest ( "GET" , url , nil )
if err != nil {
return err
}
if authorization != "" {
req . Header . Add ( "Authorization" , authorization )
}
response , err := http . DefaultClient . Do ( req )
2023-06-24 06:18:17 +00:00
if err != nil {
return err
}
defer response . Body . Close ( )
// Read the response body
2023-07-30 07:47:22 +00:00
body , err := io . ReadAll ( response . Body )
2023-06-24 06:18:17 +00:00
if err != nil {
return err
}
// Unmarshal YAML data into a struct
2023-06-26 10:25:38 +00:00
return f ( url , body )
2023-06-24 06:18:17 +00:00
}
2023-12-18 17:58:44 +00:00
2024-08-02 18:06:25 +00:00
// FilenameFromUrl derives a local file name for the resource referenced by u.
//
// It prefers the last path segment of the URL. If that cannot be extracted,
// or comes back empty, it falls back to utils.MD5 of the whole URI. In that
// case ".yaml" is appended when the URI ends in ".yaml" or ".yml". The
// returned error is always nil.
func (u URI) FilenameFromUrl() (string, error) {
	raw := string(u)
	if name, err := filenameFromUrl(raw); err == nil && name != "" {
		return name, nil
	}

	fallback := utils.MD5(raw)
	isYAML := strings.HasSuffix(raw, ".yaml") || strings.HasSuffix(raw, ".yml")
	if isYAML {
		fallback += ".yaml"
	}
	return fallback, nil
}
// filenameFromUrl extracts the last path segment of urlstr for use as a local
// file name.
//
// Everything from the first "@" onwards is dropped first, so a reference such
// as "huggingface://owner/repo/file.gguf@main" yields "file.gguf".
//
// It returns "" with a nil error when the URL has no usable final segment:
// an empty path, a bare "/", or a "." / ".." component. This lets callers fall
// back to a generated name instead of writing to (or above) a directory.
// Errors are returned only when the URL cannot be parsed or unescaped.
func filenameFromUrl(urlstr string) (string, error) {
	// strip anything after @
	urlstr, _, _ = strings.Cut(urlstr, "@")

	u, err := url.Parse(urlstr)
	if err != nil {
		return "", fmt.Errorf("error due to parsing url: %w", err)
	}

	// PathUnescape, not QueryUnescape: in a URL path "+" is a literal plus
	// sign, not an encoded space.
	x, err := url.PathUnescape(u.EscapedPath())
	if err != nil {
		return "", fmt.Errorf("error due to escaping: %w", err)
	}

	base := filepath.Base(x)
	switch base {
	case ".", "..", string(filepath.Separator):
		// filepath.Base never returns "": an empty path becomes "." and a bare
		// root becomes the separator. Map these (and "..") to "" so callers
		// can tell there is no real file name.
		return "", nil
	}
	return base, nil
}
// LooksLikeURL reports whether u starts with one of the remote schemes the
// downloader understands: http://, https://, huggingface://, github: (which
// also covers github://), ollama:// or oci://. Local file:// references are
// not included.
func (u URI) LooksLikeURL() bool {
	raw := string(u)
	remotePrefixes := []string{
		HTTPPrefix,
		HTTPSPrefix,
		HuggingFacePrefix,
		GithubURI,
		OllamaPrefix,
		OCIPrefix,
		GithubURI2,
	}
	for _, prefix := range remotePrefixes {
		if strings.HasPrefix(raw, prefix) {
			return true
		}
	}
	return false
}
2024-08-02 18:06:25 +00:00
// LooksLikeOCI reports whether u uses one of the registry schemes that are
// fetched through the oci package: "oci://" or "ollama://".
func (u URI) LooksLikeOCI() bool {
	raw := string(u)
	isOCI := strings.HasPrefix(raw, OCIPrefix)
	isOllama := strings.HasPrefix(raw, OllamaPrefix)
	return isOCI || isOllama
}
2024-08-02 18:06:25 +00:00
func ( s URI ) ResolveURL ( ) string {
2023-12-18 17:58:44 +00:00
switch {
2024-08-02 18:06:25 +00:00
case strings . HasPrefix ( string ( s ) , GithubURI2 ) :
repository := strings . Replace ( string ( s ) , GithubURI2 , "" , 1 )
2024-01-01 09:31:03 +00:00
repoParts := strings . Split ( repository , "@" )
branch := "main"
if len ( repoParts ) > 1 {
branch = repoParts [ 1 ]
}
repoPath := strings . Split ( repoParts [ 0 ] , "/" )
org := repoPath [ 0 ]
project := repoPath [ 1 ]
projectPath := strings . Join ( repoPath [ 2 : ] , "/" )
return fmt . Sprintf ( "https://raw.githubusercontent.com/%s/%s/%s/%s" , org , project , branch , projectPath )
2024-08-02 18:06:25 +00:00
case strings . HasPrefix ( string ( s ) , GithubURI ) :
parts := strings . Split ( string ( s ) , ":" )
2024-01-01 09:31:03 +00:00
repoParts := strings . Split ( parts [ 1 ] , "@" )
branch := "main"
if len ( repoParts ) > 1 {
branch = repoParts [ 1 ]
}
repoPath := strings . Split ( repoParts [ 0 ] , "/" )
org := repoPath [ 0 ]
project := repoPath [ 1 ]
projectPath := strings . Join ( repoPath [ 2 : ] , "/" )
return fmt . Sprintf ( "https://raw.githubusercontent.com/%s/%s/%s/%s" , org , project , branch , projectPath )
2024-08-02 18:06:25 +00:00
case strings . HasPrefix ( string ( s ) , HuggingFacePrefix ) :
repository := strings . Replace ( string ( s ) , HuggingFacePrefix , "" , 1 )
2023-12-18 17:58:44 +00:00
// convert repository to a full URL.
// e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf
owner := strings . Split ( repository , "/" ) [ 0 ]
repo := strings . Split ( repository , "/" ) [ 1 ]
2024-07-10 11:18:32 +00:00
2023-12-18 17:58:44 +00:00
branch := "main"
if strings . Contains ( repo , "@" ) {
branch = strings . Split ( repository , "@" ) [ 1 ]
}
filepath := strings . Split ( repository , "/" ) [ 2 ]
if strings . Contains ( filepath , "@" ) {
filepath = strings . Split ( filepath , "@" ) [ 0 ]
}
return fmt . Sprintf ( "https://huggingface.co/%s/%s/resolve/%s/%s" , owner , repo , branch , filepath )
}
2024-08-02 18:06:25 +00:00
return string ( s )
2023-12-18 17:58:44 +00:00
}
2023-12-24 18:39:33 +00:00
// removePartialFile deletes a leftover partial download at tmpFilePath, if one
// exists, so a new download starts from an empty file.
//
// If the path cannot be stat'ed, it is treated as absent and nil is returned.
// A file that vanishes between the check and the removal is also treated as
// success. Any other removal failure is logged as a warning and returned,
// wrapped with %w so callers can still inspect the underlying error.
func removePartialFile(tmpFilePath string) error {
	if _, err := os.Stat(tmpFilePath); err != nil {
		// Nothing to remove (or nothing we can inspect): not an error here.
		return nil
	}

	log.Debug().Msgf("Removing temporary file %s", tmpFilePath)
	if err := os.Remove(tmpFilePath); err != nil && !os.IsNotExist(err) {
		wrapped := fmt.Errorf("failed to remove temporary download file %s: %w", tmpFilePath, err)
		log.Warn().Msg(wrapped.Error())
		return wrapped
	}
	return nil
}
2024-08-02 18:06:25 +00:00
// DownloadFile downloads the resource referenced by uri to filePath.
//
// OCI-style URIs (see LooksLikeOCI) are delegated to the oci package, and sha
// is not consulted for them:
//   - ollama:// models go to oci.OllamaFetchModel with filePath as the target.
//   - oci:// images are extracted by oci.ExtractOCIImage into the directory
//     containing filePath.
//
// Any other URI is resolved with ResolveURL and fetched with an HTTP GET:
//   - If filePath already exists and sha is empty, the download is skipped.
//     If sha is set and equals the file's SHA-256, it is skipped too.
//     Otherwise the existing file is removed and downloaded again.
//   - A response status >= 400 is returned as an error.
//   - The body is streamed to "<filePath>.partial". It is renamed to filePath
//     only after it was fully written and closed and, when sha is set, its
//     SHA-256 matched. On a mismatch the partial file is removed and an error
//     is returned, so a corrupt file never lands at filePath.
//   - If the final file is an archive (utils.IsArchive), it is extracted into
//     the directory containing filePath.
//
// fileN, total and downloadStatus are handed to the progressWriter that
// tracks the transfer.
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
	// Named resolvedURL so it does not shadow the net/url package.
	resolvedURL := uri.ResolveURL()
	if uri.LooksLikeOCI() {
		progressStatus := func(desc ocispec.Descriptor) io.Writer {
			return &progressWriter{
				fileName:       filePath,
				total:          desc.Size,
				hash:           sha256.New(),
				fileNo:         fileN,
				totalFiles:     total,
				downloadStatus: downloadStatus,
			}
		}
		if strings.HasPrefix(resolvedURL, OllamaPrefix) {
			resolvedURL = strings.TrimPrefix(resolvedURL, OllamaPrefix)
			return oci.OllamaFetchModel(resolvedURL, filePath, progressStatus)
		}
		resolvedURL = strings.TrimPrefix(resolvedURL, OCIPrefix)
		img, err := oci.GetImage(resolvedURL, "", nil, nil)
		if err != nil {
			return fmt.Errorf("failed to get image %q: %w", resolvedURL, err)
		}
		return oci.ExtractOCIImage(img, filepath.Dir(filePath))
	}

	// Reuse an existing file when possible.
	if _, err := os.Stat(filePath); err == nil {
		if sha == "" {
			log.Debug().Msgf("File %q already exists. Skipping download", filePath)
			return nil
		}
		calculatedSHA, err := calculateSHA(filePath)
		if err != nil {
			return fmt.Errorf("failed to calculate SHA for file %q: %w", filePath, err)
		}
		if calculatedSHA == sha {
			log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", filePath)
			return nil
		}
		// SHA doesn't match: delete the file and download it again.
		if err := os.Remove(filePath); err != nil {
			return fmt.Errorf("failed to remove existing file %q: %w", filePath, err)
		}
		log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath)
	} else if !os.IsNotExist(err) {
		return fmt.Errorf("failed to check file %q existence: %w", filePath, err)
	}

	log.Info().Msgf("Downloading %q", resolvedURL)
	resp, err := http.Get(resolvedURL)
	if err != nil {
		return fmt.Errorf("failed to download file %q: %w", filePath, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode >= 400 {
		return fmt.Errorf("failed to download url %q, invalid status code %d", resolvedURL, resp.StatusCode)
	}

	if err := os.MkdirAll(filepath.Dir(filePath), 0750); err != nil {
		return fmt.Errorf("failed to create parent directory for file %q: %w", filePath, err)
	}

	// Stream into a dedicated ".partial" file, discarding any leftover one.
	tmpFilePath := filePath + ".partial"
	if err := removePartialFile(tmpFilePath); err != nil {
		return err
	}
	outFile, err := os.Create(tmpFilePath)
	if err != nil {
		return fmt.Errorf("failed to create file %q: %w", tmpFilePath, err)
	}
	// Safety net for the early returns below; the success path closes
	// explicitly (a second Close just returns an ignored error).
	defer outFile.Close()

	progress := &progressWriter{
		fileName:       tmpFilePath,
		total:          resp.ContentLength,
		hash:           sha256.New(),
		fileNo:         fileN,
		totalFiles:     total,
		downloadStatus: downloadStatus,
	}
	if _, err := io.Copy(io.MultiWriter(outFile, progress), resp.Body); err != nil {
		return fmt.Errorf("failed to write file %q: %w", filePath, err)
	}
	// Close before verifying/renaming and check the result: the deferred
	// Close would silently discard an error reported only at close time.
	if err := outFile.Close(); err != nil {
		return fmt.Errorf("failed to write file %q: %w", filePath, err)
	}

	// Verify while the data is still in the .partial file, so a corrupt
	// download never replaces filePath.
	if sha != "" {
		calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
		if calculatedSHA != sha {
			log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha)
			// Best effort: removePartialFile logs its own failures, and a
			// leftover .partial is removed before the next attempt anyway.
			_ = removePartialFile(tmpFilePath)
			return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha)
		}
	} else {
		log.Debug().Msgf("SHA missing for %q. Skipping validation", filePath)
	}

	if err := os.Rename(tmpFilePath, filePath); err != nil {
		return fmt.Errorf("failed to rename temporary file %s -> %s: %w", tmpFilePath, filePath, err)
	}
	log.Info().Msgf("File %q downloaded and verified", filePath)

	if utils.IsArchive(filePath) {
		basePath := filepath.Dir(filePath)
		log.Info().Msgf("File %q is an archive, uncompressing to %s", filePath, basePath)
		if err := utils.ExtractArchive(filePath, basePath); err != nil {
			log.Debug().Msgf("Failed decompressing %q: %s", filePath, err.Error())
			return err
		}
	}
	return nil
}
// formatBytes renders a byte count using binary (IEC) units with one decimal
// place. For example, 512 -> "512 B", 1536 -> "1.5 KiB" and 1<<20 -> "1.0 MiB".
// Values below 1024 (including negative ones) are printed as plain bytes.
func formatBytes(bytes int64) string {
	const base = 1024
	if bytes < base {
		return strconv.FormatInt(bytes, 10) + " B"
	}

	suffixes := [...]byte{'K', 'M', 'G', 'T', 'P', 'E'}
	divisor := int64(base)
	idx := 0
	// Climb one unit for every further factor of 1024 the value contains.
	scaled := bytes / base
	for scaled >= base {
		scaled /= base
		divisor *= base
		idx++
	}
	return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(divisor), suffixes[idx])
}
// calculateSHA returns the lowercase hex-encoded SHA-256 digest of the file at
// filePath. The file is streamed through the hash rather than read into
// memory at once. Open and read errors are returned unchanged with "".
func calculateSHA(filePath string) (string, error) {
	f, err := os.Open(filePath)
	if err != nil {
		return "", err
	}
	defer f.Close()

	digest := sha256.New()
	if _, copyErr := io.Copy(digest, f); copyErr != nil {
		return "", copyErr
	}
	sum := digest.Sum(nil)
	return fmt.Sprintf("%x", sum), nil
}