2024-01-05 22:16:33 +00:00
package downloader
2023-06-24 06:18:17 +00:00
import (
2023-12-18 17:58:44 +00:00
"crypto/sha256"
2024-01-05 22:16:33 +00:00
"encoding/base64"
2024-07-10 11:18:32 +00:00
"encoding/json"
"errors"
2023-06-26 10:25:38 +00:00
"fmt"
2023-07-30 07:47:22 +00:00
"io"
2023-06-24 06:18:17 +00:00
"net/http"
2023-07-30 07:47:22 +00:00
"os"
2023-09-02 07:00:44 +00:00
"path/filepath"
2023-12-18 17:58:44 +00:00
"strconv"
2023-06-24 06:18:17 +00:00
"strings"
2023-12-18 17:58:44 +00:00
2024-06-22 06:17:41 +00:00
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
2024-06-23 08:24:36 +00:00
"github.com/mudler/LocalAI/pkg/oci"
"github.com/mudler/LocalAI/pkg/utils"
2023-12-18 17:58:44 +00:00
"github.com/rs/zerolog/log"
2023-06-24 06:18:17 +00:00
)
2024-01-05 22:16:33 +00:00
// URI scheme prefixes recognised by the helpers in this package
// (see LooksLikeURL, LooksLikeOCI and ConvertURL).
const (
	// Plain web schemes.
	HTTPPrefix  = "http://"
	HTTPSPrefix = "https://"

	// Shorthand schemes that ConvertURL rewrites into HTTPS URLs.
	// GithubURI2 must be tested before GithubURI because the latter is a
	// prefix of the former.
	HuggingFacePrefix = "huggingface://"
	GithubURI         = "github:"
	GithubURI2        = "github://"

	// Schemes that DownloadFile hands over to the oci package.
	OCIPrefix    = "oci://"
	OllamaPrefix = "ollama://"
)
2024-06-22 06:17:41 +00:00
func DownloadAndUnmarshal ( url string , basePath string , f func ( url string , i [ ] byte ) error ) error {
2024-01-01 09:31:03 +00:00
url = ConvertURL ( url )
2023-06-26 10:25:38 +00:00
2023-06-24 06:18:17 +00:00
if strings . HasPrefix ( url , "file://" ) {
rawURL := strings . TrimPrefix ( url , "file://" )
2023-09-02 07:00:44 +00:00
// checks if the file is symbolic, and resolve if so - otherwise, this function returns the path unmodified.
resolvedFile , err := filepath . EvalSymlinks ( rawURL )
if err != nil {
return err
}
2024-06-05 06:45:24 +00:00
// ???
resolvedBasePath , err := filepath . EvalSymlinks ( basePath )
if err != nil {
return err
}
2024-06-04 14:32:47 +00:00
// Check if the local file is rooted in basePath
2024-06-05 06:45:24 +00:00
err = utils . InTrustedRoot ( resolvedFile , resolvedBasePath )
2024-06-04 14:32:47 +00:00
if err != nil {
2024-06-05 06:45:24 +00:00
log . Debug ( ) . Str ( "resolvedFile" , resolvedFile ) . Str ( "basePath" , basePath ) . Msg ( "downloader.GetURI blocked an attempt to ready a file url outside of basePath" )
2024-06-04 14:32:47 +00:00
return err
}
2023-06-24 06:18:17 +00:00
// Read the response body
2023-09-02 07:00:44 +00:00
body , err := os . ReadFile ( resolvedFile )
2023-06-24 06:18:17 +00:00
if err != nil {
return err
}
// Unmarshal YAML data into a struct
2023-06-26 10:25:38 +00:00
return f ( url , body )
2023-06-24 06:18:17 +00:00
}
// Send a GET request to the URL
response , err := http . Get ( url )
if err != nil {
return err
}
defer response . Body . Close ( )
// Read the response body
2023-07-30 07:47:22 +00:00
body , err := io . ReadAll ( response . Body )
2023-06-24 06:18:17 +00:00
if err != nil {
return err
}
// Unmarshal YAML data into a struct
2023-06-26 10:25:38 +00:00
return f ( url , body )
2023-06-24 06:18:17 +00:00
}
2023-12-18 17:58:44 +00:00
2023-12-30 14:36:46 +00:00
func LooksLikeURL ( s string ) bool {
2024-01-05 17:04:46 +00:00
return strings . HasPrefix ( s , HTTPPrefix ) ||
strings . HasPrefix ( s , HTTPSPrefix ) ||
strings . HasPrefix ( s , HuggingFacePrefix ) ||
strings . HasPrefix ( s , GithubURI ) ||
2024-06-22 06:17:41 +00:00
strings . HasPrefix ( s , OllamaPrefix ) ||
strings . HasPrefix ( s , OCIPrefix ) ||
2024-01-05 17:04:46 +00:00
strings . HasPrefix ( s , GithubURI2 )
2023-12-30 14:36:46 +00:00
}
2024-06-22 06:17:41 +00:00
func LooksLikeOCI ( s string ) bool {
return strings . HasPrefix ( s , OCIPrefix ) || strings . HasPrefix ( s , OllamaPrefix )
}
2023-12-18 17:58:44 +00:00
func ConvertURL ( s string ) string {
switch {
2024-01-01 09:31:03 +00:00
case strings . HasPrefix ( s , GithubURI2 ) :
repository := strings . Replace ( s , GithubURI2 , "" , 1 )
repoParts := strings . Split ( repository , "@" )
branch := "main"
if len ( repoParts ) > 1 {
branch = repoParts [ 1 ]
}
repoPath := strings . Split ( repoParts [ 0 ] , "/" )
org := repoPath [ 0 ]
project := repoPath [ 1 ]
projectPath := strings . Join ( repoPath [ 2 : ] , "/" )
return fmt . Sprintf ( "https://raw.githubusercontent.com/%s/%s/%s/%s" , org , project , branch , projectPath )
case strings . HasPrefix ( s , GithubURI ) :
parts := strings . Split ( s , ":" )
repoParts := strings . Split ( parts [ 1 ] , "@" )
branch := "main"
if len ( repoParts ) > 1 {
branch = repoParts [ 1 ]
}
repoPath := strings . Split ( repoParts [ 0 ] , "/" )
org := repoPath [ 0 ]
project := repoPath [ 1 ]
projectPath := strings . Join ( repoPath [ 2 : ] , "/" )
return fmt . Sprintf ( "https://raw.githubusercontent.com/%s/%s/%s/%s" , org , project , branch , projectPath )
2023-12-30 14:36:46 +00:00
case strings . HasPrefix ( s , HuggingFacePrefix ) :
repository := strings . Replace ( s , HuggingFacePrefix , "" , 1 )
2023-12-18 17:58:44 +00:00
// convert repository to a full URL.
// e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf
owner := strings . Split ( repository , "/" ) [ 0 ]
repo := strings . Split ( repository , "/" ) [ 1 ]
2024-07-10 11:18:32 +00:00
2023-12-18 17:58:44 +00:00
branch := "main"
if strings . Contains ( repo , "@" ) {
branch = strings . Split ( repository , "@" ) [ 1 ]
}
filepath := strings . Split ( repository , "/" ) [ 2 ]
if strings . Contains ( filepath , "@" ) {
filepath = strings . Split ( filepath , "@" ) [ 0 ]
}
return fmt . Sprintf ( "https://huggingface.co/%s/%s/resolve/%s/%s" , owner , repo , branch , filepath )
}
return s
}
2023-12-24 18:39:33 +00:00
// removePartialFile deletes tmpFilePath when os.Stat succeeds on it.
// If the stat fails for any reason the function does nothing and returns nil.
// A failed removal is logged as a warning and returned to the caller.
func removePartialFile(tmpFilePath string) error {
	if _, statErr := os.Stat(tmpFilePath); statErr != nil {
		// Nothing we can see at that path, so nothing to clean up.
		return nil
	}

	log.Debug().Msgf("Removing temporary file %s", tmpFilePath)
	if rmErr := os.Remove(tmpFilePath); rmErr != nil {
		wrapped := fmt.Errorf("failed to remove temporary download file %s: %v", tmpFilePath, rmErr)
		log.Warn().Msg(wrapped.Error())
		return wrapped
	}
	return nil
}
2024-04-23 07:22:58 +00:00
func DownloadFile ( url string , filePath , sha string , fileN , total int , downloadStatus func ( string , string , string , float64 ) ) error {
2023-12-18 17:58:44 +00:00
url = ConvertURL ( url )
2024-06-22 06:17:41 +00:00
if LooksLikeOCI ( url ) {
progressStatus := func ( desc ocispec . Descriptor ) io . Writer {
return & progressWriter {
fileName : filePath ,
total : desc . Size ,
hash : sha256 . New ( ) ,
fileNo : fileN ,
totalFiles : total ,
downloadStatus : downloadStatus ,
}
}
if strings . HasPrefix ( url , OllamaPrefix ) {
url = strings . TrimPrefix ( url , OllamaPrefix )
return oci . OllamaFetchModel ( url , filePath , progressStatus )
}
url = strings . TrimPrefix ( url , OCIPrefix )
img , err := oci . GetImage ( url , "" , nil , nil )
if err != nil {
return fmt . Errorf ( "failed to get image %q: %v" , url , err )
}
return oci . ExtractOCIImage ( img , filepath . Dir ( filePath ) )
}
2023-12-18 17:58:44 +00:00
// Check if the file already exists
_ , err := os . Stat ( filePath )
if err == nil {
// File exists, check SHA
if sha != "" {
// Verify SHA
calculatedSHA , err := calculateSHA ( filePath )
if err != nil {
return fmt . Errorf ( "failed to calculate SHA for file %q: %v" , filePath , err )
}
if calculatedSHA == sha {
// SHA matches, skip downloading
log . Debug ( ) . Msgf ( "File %q already exists and matches the SHA. Skipping download" , filePath )
return nil
}
// SHA doesn't match, delete the file and download again
err = os . Remove ( filePath )
if err != nil {
return fmt . Errorf ( "failed to remove existing file %q: %v" , filePath , err )
}
log . Debug ( ) . Msgf ( "Removed %q (SHA doesn't match)" , filePath )
} else {
// SHA is missing, skip downloading
log . Debug ( ) . Msgf ( "File %q already exists. Skipping download" , filePath )
return nil
}
} else if ! os . IsNotExist ( err ) {
// Error occurred while checking file existence
return fmt . Errorf ( "failed to check file %q existence: %v" , filePath , err )
}
log . Info ( ) . Msgf ( "Downloading %q" , url )
// Download file
resp , err := http . Get ( url )
if err != nil {
return fmt . Errorf ( "failed to download file %q: %v" , filePath , err )
}
defer resp . Body . Close ( )
2024-03-01 15:19:53 +00:00
if resp . StatusCode >= 400 {
return fmt . Errorf ( "failed to download url %q, invalid status code %d" , url , resp . StatusCode )
}
2023-12-18 17:58:44 +00:00
// Create parent directory
2024-04-25 22:47:06 +00:00
err = os . MkdirAll ( filepath . Dir ( filePath ) , 0750 )
2023-12-18 17:58:44 +00:00
if err != nil {
return fmt . Errorf ( "failed to create parent directory for file %q: %v" , filePath , err )
}
2023-12-24 18:39:33 +00:00
// save partial download to dedicated file
tmpFilePath := filePath + ".partial"
// remove tmp file
err = removePartialFile ( tmpFilePath )
if err != nil {
return err
}
2023-12-18 17:58:44 +00:00
// Create and write file content
2023-12-24 18:39:33 +00:00
outFile , err := os . Create ( tmpFilePath )
2023-12-18 17:58:44 +00:00
if err != nil {
2023-12-24 18:39:33 +00:00
return fmt . Errorf ( "failed to create file %q: %v" , tmpFilePath , err )
2023-12-18 17:58:44 +00:00
}
defer outFile . Close ( )
progress := & progressWriter {
2023-12-24 18:39:33 +00:00
fileName : tmpFilePath ,
2023-12-18 17:58:44 +00:00
total : resp . ContentLength ,
hash : sha256 . New ( ) ,
2024-04-23 07:22:58 +00:00
fileNo : fileN ,
totalFiles : total ,
2023-12-18 17:58:44 +00:00
downloadStatus : downloadStatus ,
}
_ , err = io . Copy ( io . MultiWriter ( outFile , progress ) , resp . Body )
if err != nil {
return fmt . Errorf ( "failed to write file %q: %v" , filePath , err )
}
2023-12-24 18:39:33 +00:00
err = os . Rename ( tmpFilePath , filePath )
if err != nil {
return fmt . Errorf ( "failed to rename temporary file %s -> %s: %v" , tmpFilePath , filePath , err )
}
2023-12-18 17:58:44 +00:00
if sha != "" {
// Verify SHA
calculatedSHA := fmt . Sprintf ( "%x" , progress . hash . Sum ( nil ) )
if calculatedSHA != sha {
log . Debug ( ) . Msgf ( "SHA mismatch for file %q ( calculated: %s != metadata: %s )" , filePath , calculatedSHA , sha )
return fmt . Errorf ( "SHA mismatch for file %q ( calculated: %s != metadata: %s )" , filePath , calculatedSHA , sha )
}
} else {
log . Debug ( ) . Msgf ( "SHA missing for %q. Skipping validation" , filePath )
}
log . Info ( ) . Msgf ( "File %q downloaded and verified" , filePath )
2024-01-05 22:16:33 +00:00
if utils . IsArchive ( filePath ) {
2023-12-18 17:58:44 +00:00
basePath := filepath . Dir ( filePath )
log . Info ( ) . Msgf ( "File %q is an archive, uncompressing to %s" , filePath , basePath )
2024-01-05 22:16:33 +00:00
if err := utils . ExtractArchive ( filePath , basePath ) ; err != nil {
2023-12-18 17:58:44 +00:00
log . Debug ( ) . Msgf ( "Failed decompressing %q: %s" , filePath , err . Error ( ) )
return err
}
}
return nil
}
2024-01-05 22:16:33 +00:00
// GetBase64Image returns the base64 payload for the image referenced by s.
//
// Two input forms are accepted:
//   - a string starting with "http": the resource is downloaded into memory
//     with an HTTP GET and its bytes are returned in standard base64
//     encoding. Responses with a status code of 400 or above are reported as
//     errors rather than encoded;
//   - a base64 data URI ("data:<media type>;base64,<payload>", for any media
//     type): the header is dropped and the payload is returned as-is. The
//     payload itself is not validated.
//
// Any other input yields a "not valid string" error.
func GetBase64Image(s string) (string, error) {
	if strings.HasPrefix(s, "http") {
		// download the image
		resp, err := http.Get(s)
		if err != nil {
			return "", err
		}
		defer resp.Body.Close()

		// http.Get does not fail on HTTP error statuses; without this check
		// an error page would be encoded and returned as if it were an image.
		if resp.StatusCode >= 400 {
			return "", fmt.Errorf("failed to download url %q, invalid status code %d", s, resp.StatusCode)
		}

		// read the image data into memory
		data, err := io.ReadAll(resp.Body)
		if err != nil {
			return "", err
		}

		// encode the image data in base64
		return base64.StdEncoding.EncodeToString(data), nil
	}

	// A data URI is "data:" + header + "," + payload; accept it when the
	// header declares base64 encoding, whatever the media type.
	if strings.HasPrefix(s, "data:") {
		if comma := strings.Index(s, ","); comma >= 0 && strings.HasSuffix(s[:comma], ";base64") {
			return s[comma+1:], nil
		}
	}

	return "", fmt.Errorf("not valid string")
}
// formatBytes renders a byte count for display using binary (1024-based)
// units. Values below 1024 are printed as a plain integer followed by " B";
// larger values are printed with one decimal and a KiB/MiB/GiB/TiB/PiB/EiB
// suffix.
func formatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return strconv.FormatInt(bytes, 10) + " B"
	}

	const prefixes = "KMGTPE"
	divisor := int64(unit)
	idx := 0
	// Each further factor of 1024 in the value moves to the next prefix.
	remaining := bytes / unit
	for remaining >= unit {
		divisor *= unit
		idx++
		remaining /= unit
	}

	scaled := float64(bytes) / float64(divisor)
	return fmt.Sprintf("%.1f %ciB", scaled, prefixes[idx])
}
// calculateSHA streams the file at filePath through SHA-256 and returns the
// digest as a lowercase hexadecimal string. Errors from opening or reading
// the file are returned unchanged.
func calculateSHA(filePath string) (string, error) {
	f, openErr := os.Open(filePath)
	if openErr != nil {
		return "", openErr
	}
	defer f.Close()

	digest := sha256.New()
	_, copyErr := io.Copy(digest, f)
	if copyErr != nil {
		return "", copyErr
	}

	sum := digest.Sum(nil)
	return fmt.Sprintf("%x", sum), nil
}
2024-07-10 11:18:32 +00:00
// HuggingFaceScanResult is the JSON document that HuggingFaceScan decodes
// from the https://huggingface.co/api/models/<owner>/<repo>/scan endpoint.
type HuggingFaceScanResult struct {
	// RepositoryId and Revision are copied from the response as-is.
	RepositoryId string `json:"repositoryId"`
	Revision     string `json:"revision"`
	// HasUnsafeFiles makes HuggingFaceScan return ErrUnsafeFilesFound when
	// true. Note that the JSON key is the singular "hasUnsafeFile".
	HasUnsafeFiles bool `json:"hasUnsafeFile"`
	// ClamAVInfectedFiles and DangerousPickles presumably list the offending
	// file names reported by the respective scanners - confirm against the
	// Hugging Face API.
	ClamAVInfectedFiles []string `json:"clamAVInfectedFiles"`
	DangerousPickles    []string `json:"dangerousPickles"`
	// ScansDone is copied from the response; it is not inspected by
	// HuggingFaceScan.
	ScansDone bool `json:"scansDone"`
}
// Sentinel errors returned by HuggingFaceScan.
var (
	// ErrNonHuggingFaceFile is returned when the converted URI does not
	// point at huggingface.co.
	ErrNonHuggingFaceFile = errors.New("not a huggingface repo")
	// ErrUnsafeFilesFound is returned, together with the scan result, when
	// the scan reports unsafe files.
	ErrUnsafeFilesFound = errors.New("unsafe files found")
)
// HuggingFaceScan queries the Hugging Face scan endpoint for the repository
// that uri points at.
//
// uri is normalised with ConvertURL and split on "/". Unless the host
// component is "huggingface.co" and both an owner and a repository segment
// are present, ErrNonHuggingFaceFile is returned without any network access.
//
// On success the decoded HuggingFaceScanResult is returned. When the result
// flags unsafe files, the result is returned together with
// ErrUnsafeFilesFound so callers can still inspect the details. Transport
// errors, non-200 responses and JSON decoding failures yield a nil result.
func HuggingFaceScan(uri string) (*HuggingFaceScanResult, error) {
	cleanParts := strings.Split(ConvertURL(uri), "/")
	// For "https://huggingface.co/<owner>/<repo>/..." the split yields
	// ["https:", "", "huggingface.co", "<owner>", "<repo>", ...].
	if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" {
		return nil, ErrNonHuggingFaceFile
	}

	results, err := http.Get(fmt.Sprintf("https://huggingface.co/api/models/%s/%s/scan", cleanParts[3], cleanParts[4]))
	if err != nil {
		return nil, err
	}
	// The body must be closed on every path, otherwise the underlying
	// connection is leaked.
	defer results.Body.Close()

	if results.StatusCode != 200 {
		return nil, fmt.Errorf("unexpected status code during HuggingFaceScan: %d", results.StatusCode)
	}

	bodyBytes, err := io.ReadAll(results.Body)
	if err != nil {
		return nil, err
	}

	scanResult := &HuggingFaceScanResult{}
	if err := json.Unmarshal(bodyBytes, scanResult); err != nil {
		return nil, err
	}

	if scanResult.HasUnsafeFiles {
		return scanResult, ErrUnsafeFilesFound
	}
	return scanResult, nil
}