LocalAI/pkg/downloader/huggingface.go
Ettore Di Giacinto a36b721ca6
fix: be consistent in downloading files, check for scanner errors (#3108)
* fix(downloader): be consistent in downloading files

This PR tidies up the downloader so that the same functions are reused
across several places.

This fixes an issue with URIs inside the model YAML file: they
would resolve to an MD5 hash rather than to the filename.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(scanner): only raise an error if unsafe files are found

Fixes: https://github.com/mudler/LocalAI/issues/3114

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-08-02 20:06:25 +02:00

50 lines
1.4 KiB
Go

package downloader
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
)
// HuggingFaceScanResult holds the JSON response decoded by HuggingFaceScan
// from https://huggingface.co/api/models/<owner>/<repo>/scan.
//
// Within this file only HasUnsafeFiles is inspected; the remaining fields are
// decoded as-is for callers. Their meanings follow the Hugging Face API's JSON
// keys (NOTE(review): semantics inferred from key names — confirm against the
// Hugging Face Hub API docs).
type HuggingFaceScanResult struct {
// RepositoryId is decoded from the "repositoryId" key.
RepositoryId string `json:"repositoryId"`
// Revision is decoded from the "revision" key.
Revision string `json:"revision"`
// HasUnsafeFiles is decoded from the singular "hasUnsafeFile" key.
// HuggingFaceScan returns ErrUnsafeFilesFound when it is true.
HasUnsafeFiles bool `json:"hasUnsafeFile"`
// ClamAVInfectedFiles is decoded from the "clamAVInfectedFiles" key
// (presumably files flagged by ClamAV — confirm).
ClamAVInfectedFiles []string `json:"clamAVInfectedFiles"`
// DangerousPickles is decoded from the "dangerousPickles" key
// (presumably pickle files flagged as dangerous — confirm).
DangerousPickles []string `json:"dangerousPickles"`
// ScansDone is decoded from the "scansDone" key; it is not consulted here.
ScansDone bool `json:"scansDone"`
}
// Sentinel errors returned by HuggingFaceScan. Compare against them with
// errors.Is rather than by error string.
var (
	// ErrNonHuggingFaceFile is returned when the resolved URI is not a
	// huggingface.co repository URL, so no scan can be requested.
	ErrNonHuggingFaceFile = errors.New("not a huggingface repo")

	// ErrUnsafeFilesFound is returned, together with the decoded scan
	// result, when the scan reports that the repository has unsafe files.
	ErrUnsafeFilesFound = errors.New("unsafe files found")
)
// HuggingFaceScan asks the Hugging Face scan API whether the repository that
// uri points to contains unsafe files.
//
// The resolved URL is expected to look like
// https://huggingface.co/<owner>/<repo>/...; only the <owner> and <repo> path
// segments are used, and the request is sent to
// https://huggingface.co/api/models/<owner>/<repo>/scan.
//
// Return values:
//   - (nil, ErrNonHuggingFaceFile) if the resolved URL is not a huggingface.co
//     repository URL (including an empty owner or repo segment);
//   - (nil, err) if the request fails, the server does not answer 200 OK, or
//     the body cannot be read or decoded; transport, read and decode errors
//     are wrapped with %w so errors.Is/errors.As still reach the cause;
//   - (result, ErrUnsafeFilesFound) if the scan reports unsafe files; the
//     decoded result is still returned so callers can inspect it;
//   - (result, nil) otherwise.
func HuggingFaceScan(uri URI) (*HuggingFaceScanResult, error) {
	// "https://huggingface.co/owner/repo/..." splits into
	// ["https:", "", "huggingface.co", "owner", "repo", ...].
	cleanParts := strings.Split(uri.ResolveURL(), "/")
	if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" ||
		cleanParts[3] == "" || cleanParts[4] == "" {
		// An empty owner or repo segment cannot name a repository, so there
		// is nothing meaningful to scan.
		return nil, ErrNonHuggingFaceFile
	}

	scanURL := fmt.Sprintf("https://huggingface.co/api/models/%s/%s/scan", cleanParts[3], cleanParts[4])
	resp, err := http.Get(scanURL)
	if err != nil {
		return nil, fmt.Errorf("requesting huggingface scan: %w", err)
	}
	// Close the body on every path (non-200, read and decode failures
	// included) so the underlying connection is released.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code during HuggingFaceScan: %d", resp.StatusCode)
	}

	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading huggingface scan response: %w", err)
	}

	scanResult := &HuggingFaceScanResult{}
	if err := json.Unmarshal(bodyBytes, scanResult); err != nil {
		return nil, fmt.Errorf("decoding huggingface scan response: %w", err)
	}

	// ErrUnsafeFilesFound is returned unwrapped, alongside the result, so
	// callers can match it directly and still inspect the scan details.
	if scanResult.HasUnsafeFiles {
		return scanResult, ErrUnsafeFilesFound
	}
	return scanResult, nil
}