mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
fix: be consistent in downloading files, check for scanner errors (#3108)
* fix(downloader): be consistent in downloading files This PR puts some order in the downloader so that functions are re-used across several places. This fixes an issue with having URIs inside the model YAML file: they would resolve to the MD5 rather than using the filename Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(scanner): raise an error only if unsafe files are found Fixes: https://github.com/mudler/LocalAI/issues/3114 Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
fc50a90f6a
commit
a36b721ca6
@ -83,7 +83,9 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !downloader.LooksLikeOCI(modelName) {
|
modelURI := downloader.URI(modelName)
|
||||||
|
|
||||||
|
if !modelURI.LooksLikeOCI() {
|
||||||
model := gallery.FindModel(models, modelName, mi.ModelsPath)
|
model := gallery.FindModel(models, modelName, mi.ModelsPath)
|
||||||
if model == nil {
|
if model == nil {
|
||||||
log.Error().Str("model", modelName).Msg("model not found")
|
log.Error().Str("model", modelName).Msg("model not found")
|
||||||
|
@ -86,8 +86,8 @@ func (hfscmd *HFScanCMD) Run(ctx *cliContext.Context) error {
|
|||||||
var errs error = nil
|
var errs error = nil
|
||||||
for _, uri := range hfscmd.ToScan {
|
for _, uri := range hfscmd.ToScan {
|
||||||
log.Info().Str("uri", uri).Msg("scanning specific uri")
|
log.Info().Str("uri", uri).Msg("scanning specific uri")
|
||||||
scanResults, err := downloader.HuggingFaceScan(uri)
|
scanResults, err := downloader.HuggingFaceScan(downloader.URI(uri))
|
||||||
if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) {
|
if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) {
|
||||||
log.Error().Err(err).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("! WARNING ! A known-vulnerable model is included in this repo!")
|
log.Error().Err(err).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("! WARNING ! A known-vulnerable model is included in this repo!")
|
||||||
errs = errors.Join(errs, err)
|
errs = errors.Join(errs, err)
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/pkg/downloader"
|
"github.com/mudler/LocalAI/pkg/downloader"
|
||||||
"github.com/mudler/LocalAI/pkg/functions"
|
"github.com/mudler/LocalAI/pkg/functions"
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -74,7 +73,7 @@ type BackendConfig struct {
|
|||||||
type File struct {
|
type File struct {
|
||||||
Filename string `yaml:"filename" json:"filename"`
|
Filename string `yaml:"filename" json:"filename"`
|
||||||
SHA256 string `yaml:"sha256" json:"sha256"`
|
SHA256 string `yaml:"sha256" json:"sha256"`
|
||||||
URI string `yaml:"uri" json:"uri"`
|
URI downloader.URI `yaml:"uri" json:"uri"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type VallE struct {
|
type VallE struct {
|
||||||
@ -213,28 +212,32 @@ func (c *BackendConfig) ShouldCallSpecificFunction() bool {
|
|||||||
// MMProjFileName returns the filename of the MMProj file
|
// MMProjFileName returns the filename of the MMProj file
|
||||||
// If the MMProj is a URL, it will return the MD5 of the URL which is the filename
|
// If the MMProj is a URL, it will return the filename extracted from the URL, or the MD5 of the URL if no filename can be extracted
|
||||||
func (c *BackendConfig) MMProjFileName() string {
|
func (c *BackendConfig) MMProjFileName() string {
|
||||||
modelURL := downloader.ConvertURL(c.MMProj)
|
uri := downloader.URI(c.MMProj)
|
||||||
if downloader.LooksLikeURL(modelURL) {
|
if uri.LooksLikeURL() {
|
||||||
return utils.MD5(modelURL)
|
f, _ := uri.FilenameFromUrl()
|
||||||
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.MMProj
|
return c.MMProj
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BackendConfig) IsMMProjURL() bool {
|
func (c *BackendConfig) IsMMProjURL() bool {
|
||||||
return downloader.LooksLikeURL(downloader.ConvertURL(c.MMProj))
|
uri := downloader.URI(c.MMProj)
|
||||||
|
return uri.LooksLikeURL()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BackendConfig) IsModelURL() bool {
|
func (c *BackendConfig) IsModelURL() bool {
|
||||||
return downloader.LooksLikeURL(downloader.ConvertURL(c.Model))
|
uri := downloader.URI(c.Model)
|
||||||
|
return uri.LooksLikeURL()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelFileName returns the filename of the model
|
// ModelFileName returns the filename of the model
|
||||||
// If the model is a URL, it will return the MD5 of the URL which is the filename
|
// If the model is a URL, it will return the filename extracted from the URL, or the MD5 of the URL if no filename can be extracted
|
||||||
func (c *BackendConfig) ModelFileName() string {
|
func (c *BackendConfig) ModelFileName() string {
|
||||||
modelURL := downloader.ConvertURL(c.Model)
|
uri := downloader.URI(c.Model)
|
||||||
if downloader.LooksLikeURL(modelURL) {
|
if uri.LooksLikeURL() {
|
||||||
return utils.MD5(modelURL)
|
f, _ := uri.FilenameFromUrl()
|
||||||
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.Model
|
return c.Model
|
||||||
|
@ -244,7 +244,7 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {
|
|||||||
// Create file path
|
// Create file path
|
||||||
filePath := filepath.Join(modelPath, file.Filename)
|
filePath := filepath.Join(modelPath, file.Filename)
|
||||||
|
|
||||||
if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil {
|
if err := file.URI.DownloadFile(filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -252,10 +252,10 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {
|
|||||||
// If the model is an URL, expand it, and download the file
|
// If the model is an URL, expand it, and download the file
|
||||||
if config.IsModelURL() {
|
if config.IsModelURL() {
|
||||||
modelFileName := config.ModelFileName()
|
modelFileName := config.ModelFileName()
|
||||||
modelURL := downloader.ConvertURL(config.Model)
|
uri := downloader.URI(config.Model)
|
||||||
// check if file exists
|
// check if file exists
|
||||||
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
|
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
|
||||||
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
|
err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -269,10 +269,10 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {
|
|||||||
|
|
||||||
if config.IsMMProjURL() {
|
if config.IsMMProjURL() {
|
||||||
modelFileName := config.MMProjFileName()
|
modelFileName := config.MMProjFileName()
|
||||||
modelURL := downloader.ConvertURL(config.MMProj)
|
uri := downloader.URI(config.MMProj)
|
||||||
// check if file exists
|
// check if file exists
|
||||||
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
|
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
|
||||||
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
|
err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -37,7 +37,8 @@ func main() {
|
|||||||
|
|
||||||
// download the assets
|
// download the assets
|
||||||
for _, asset := range assets {
|
for _, asset := range assets {
|
||||||
if err := downloader.DownloadFile(asset.URL, filepath.Join(destPath, asset.FileName), asset.SHA, 1, 1, utils.DisplayDownloadFunction); err != nil {
|
uri := downloader.URI(asset.URL)
|
||||||
|
if err := uri.DownloadFile(filepath.Join(destPath, asset.FileName), asset.SHA, 1, 1, utils.DisplayDownloadFunction); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -131,7 +131,8 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) ([]*Gal
|
|||||||
|
|
||||||
func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
|
func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
|
||||||
var refFile string
|
var refFile string
|
||||||
err := downloader.DownloadAndUnmarshal(url, basePath, func(url string, d []byte) error {
|
uri := downloader.URI(url)
|
||||||
|
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
|
||||||
refFile = string(d)
|
refFile = string(d)
|
||||||
if len(refFile) == 0 {
|
if len(refFile) == 0 {
|
||||||
return fmt.Errorf("invalid reference file at url %s: %s", url, d)
|
return fmt.Errorf("invalid reference file at url %s: %s", url, d)
|
||||||
@ -153,8 +154,9 @@ func getGalleryModels(gallery config.Gallery, basePath string) ([]*GalleryModel,
|
|||||||
return models, err
|
return models, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
uri := downloader.URI(gallery.URL)
|
||||||
|
|
||||||
err := downloader.DownloadAndUnmarshal(gallery.URL, basePath, func(url string, d []byte) error {
|
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
|
||||||
return yaml.Unmarshal(d, &models)
|
return yaml.Unmarshal(d, &models)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -252,8 +254,8 @@ func SafetyScanGalleryModels(galleries []config.Gallery, basePath string) error
|
|||||||
|
|
||||||
func SafetyScanGalleryModel(galleryModel *GalleryModel) error {
|
func SafetyScanGalleryModel(galleryModel *GalleryModel) error {
|
||||||
for _, file := range galleryModel.AdditionalFiles {
|
for _, file := range galleryModel.AdditionalFiles {
|
||||||
scanResults, err := downloader.HuggingFaceScan(file.URI)
|
scanResults, err := downloader.HuggingFaceScan(downloader.URI(file.URI))
|
||||||
if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) {
|
if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) {
|
||||||
log.Error().Str("model", galleryModel.Name).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("Contains unsafe file(s)!")
|
log.Error().Str("model", galleryModel.Name).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("Contains unsafe file(s)!")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -68,7 +68,8 @@ type PromptTemplate struct {
|
|||||||
|
|
||||||
func GetGalleryConfigFromURL(url string, basePath string) (Config, error) {
|
func GetGalleryConfigFromURL(url string, basePath string) (Config, error) {
|
||||||
var config Config
|
var config Config
|
||||||
err := downloader.DownloadAndUnmarshal(url, basePath, func(url string, d []byte) error {
|
uri := downloader.URI(url)
|
||||||
|
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
|
||||||
return yaml.Unmarshal(d, &config)
|
return yaml.Unmarshal(d, &config)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -118,14 +119,14 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
|
|||||||
filePath := filepath.Join(basePath, file.Filename)
|
filePath := filepath.Join(basePath, file.Filename)
|
||||||
|
|
||||||
if enforceScan {
|
if enforceScan {
|
||||||
scanResults, err := downloader.HuggingFaceScan(file.URI)
|
scanResults, err := downloader.HuggingFaceScan(downloader.URI(file.URI))
|
||||||
if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) {
|
if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) {
|
||||||
log.Error().Str("model", config.Name).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("Contains unsafe file(s)!")
|
log.Error().Str("model", config.Name).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("Contains unsafe file(s)!")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
uri := downloader.URI(file.URI)
|
||||||
if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
|
if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -73,8 +73,9 @@ func getModelStatus(url string) (response map[string]interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getModels(url string) (response []gallery.GalleryModel) {
|
func getModels(url string) (response []gallery.GalleryModel) {
|
||||||
|
uri := downloader.URI(url)
|
||||||
// TODO: No tests currently seem to exercise file:// urls. Fix?
|
// TODO: No tests currently seem to exercise file:// urls. Fix?
|
||||||
downloader.DownloadAndUnmarshal(url, "", func(url string, i []byte) error {
|
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
|
||||||
// Unmarshal YAML data into a struct
|
// Unmarshal YAML data into a struct
|
||||||
return json.Unmarshal(i, &response)
|
return json.Unmarshal(i, &response)
|
||||||
})
|
})
|
||||||
|
@ -38,8 +38,8 @@ func init() {
|
|||||||
|
|
||||||
func GetRemoteLibraryShorteners(url string, basePath string) (map[string]string, error) {
|
func GetRemoteLibraryShorteners(url string, basePath string) (map[string]string, error) {
|
||||||
remoteLibrary := map[string]string{}
|
remoteLibrary := map[string]string{}
|
||||||
|
uri := downloader.URI(url)
|
||||||
err := downloader.DownloadAndUnmarshal(url, basePath, func(_ string, i []byte) error {
|
err := uri.DownloadAndUnmarshal(basePath, func(_ string, i []byte) error {
|
||||||
return yaml.Unmarshal(i, &remoteLibrary)
|
return yaml.Unmarshal(i, &remoteLibrary)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
49
pkg/downloader/huggingface.go
Normal file
49
pkg/downloader/huggingface.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package downloader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HuggingFaceScanResult struct {
|
||||||
|
RepositoryId string `json:"repositoryId"`
|
||||||
|
Revision string `json:"revision"`
|
||||||
|
HasUnsafeFiles bool `json:"hasUnsafeFile"`
|
||||||
|
ClamAVInfectedFiles []string `json:"clamAVInfectedFiles"`
|
||||||
|
DangerousPickles []string `json:"dangerousPickles"`
|
||||||
|
ScansDone bool `json:"scansDone"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrNonHuggingFaceFile = errors.New("not a huggingface repo")
|
||||||
|
var ErrUnsafeFilesFound = errors.New("unsafe files found")
|
||||||
|
|
||||||
|
func HuggingFaceScan(uri URI) (*HuggingFaceScanResult, error) {
|
||||||
|
cleanParts := strings.Split(uri.ResolveURL(), "/")
|
||||||
|
if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" {
|
||||||
|
return nil, ErrNonHuggingFaceFile
|
||||||
|
}
|
||||||
|
results, err := http.Get(fmt.Sprintf("https://huggingface.co/api/models/%s/%s/scan", cleanParts[3], cleanParts[4]))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if results.StatusCode != 200 {
|
||||||
|
return nil, fmt.Errorf("unexpected status code during HuggingFaceScan: %d", results.StatusCode)
|
||||||
|
}
|
||||||
|
scanResult := &HuggingFaceScanResult{}
|
||||||
|
bodyBytes, err := io.ReadAll(results.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(bodyBytes, scanResult)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if scanResult.HasUnsafeFiles {
|
||||||
|
return scanResult, ErrUnsafeFilesFound
|
||||||
|
}
|
||||||
|
return scanResult, nil
|
||||||
|
}
|
@ -2,12 +2,10 @@ package downloader
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -28,13 +26,16 @@ const (
|
|||||||
HTTPSPrefix = "https://"
|
HTTPSPrefix = "https://"
|
||||||
GithubURI = "github:"
|
GithubURI = "github:"
|
||||||
GithubURI2 = "github://"
|
GithubURI2 = "github://"
|
||||||
|
LocalPrefix = "file://"
|
||||||
)
|
)
|
||||||
|
|
||||||
func DownloadAndUnmarshal(url string, basePath string, f func(url string, i []byte) error) error {
|
type URI string
|
||||||
url = ConvertURL(url)
|
|
||||||
|
|
||||||
if strings.HasPrefix(url, "file://") {
|
func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte) error) error {
|
||||||
rawURL := strings.TrimPrefix(url, "file://")
|
url := uri.ResolveURL()
|
||||||
|
|
||||||
|
if strings.HasPrefix(url, LocalPrefix) {
|
||||||
|
rawURL := strings.TrimPrefix(url, LocalPrefix)
|
||||||
// checks if the file is symbolic, and resolve if so - otherwise, this function returns the path unmodified.
|
// checks if the file is symbolic, and resolve if so - otherwise, this function returns the path unmodified.
|
||||||
resolvedFile, err := filepath.EvalSymlinks(rawURL)
|
resolvedFile, err := filepath.EvalSymlinks(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -78,24 +79,54 @@ func DownloadAndUnmarshal(url string, basePath string, f func(url string, i []by
|
|||||||
return f(url, body)
|
return f(url, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LooksLikeURL(s string) bool {
|
func (u URI) FilenameFromUrl() (string, error) {
|
||||||
return strings.HasPrefix(s, HTTPPrefix) ||
|
f, err := filenameFromUrl(string(u))
|
||||||
strings.HasPrefix(s, HTTPSPrefix) ||
|
if err != nil || f == "" {
|
||||||
strings.HasPrefix(s, HuggingFacePrefix) ||
|
f = utils.MD5(string(u))
|
||||||
strings.HasPrefix(s, GithubURI) ||
|
if strings.HasSuffix(string(u), ".yaml") || strings.HasSuffix(string(u), ".yml") {
|
||||||
strings.HasPrefix(s, OllamaPrefix) ||
|
f = f + ".yaml"
|
||||||
strings.HasPrefix(s, OCIPrefix) ||
|
}
|
||||||
strings.HasPrefix(s, GithubURI2)
|
err = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func LooksLikeOCI(s string) bool {
|
return f, err
|
||||||
return strings.HasPrefix(s, OCIPrefix) || strings.HasPrefix(s, OllamaPrefix)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConvertURL(s string) string {
|
func filenameFromUrl(urlstr string) (string, error) {
|
||||||
|
// strip anything after @
|
||||||
|
if strings.Contains(urlstr, "@") {
|
||||||
|
urlstr = strings.Split(urlstr, "@")[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(urlstr)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error due to parsing url: %w", err)
|
||||||
|
}
|
||||||
|
x, err := url.QueryUnescape(u.EscapedPath())
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error due to escaping: %w", err)
|
||||||
|
}
|
||||||
|
return filepath.Base(x), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u URI) LooksLikeURL() bool {
|
||||||
|
return strings.HasPrefix(string(u), HTTPPrefix) ||
|
||||||
|
strings.HasPrefix(string(u), HTTPSPrefix) ||
|
||||||
|
strings.HasPrefix(string(u), HuggingFacePrefix) ||
|
||||||
|
strings.HasPrefix(string(u), GithubURI) ||
|
||||||
|
strings.HasPrefix(string(u), OllamaPrefix) ||
|
||||||
|
strings.HasPrefix(string(u), OCIPrefix) ||
|
||||||
|
strings.HasPrefix(string(u), GithubURI2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s URI) LooksLikeOCI() bool {
|
||||||
|
return strings.HasPrefix(string(s), OCIPrefix) || strings.HasPrefix(string(s), OllamaPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s URI) ResolveURL() string {
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(s, GithubURI2):
|
case strings.HasPrefix(string(s), GithubURI2):
|
||||||
repository := strings.Replace(s, GithubURI2, "", 1)
|
repository := strings.Replace(string(s), GithubURI2, "", 1)
|
||||||
|
|
||||||
repoParts := strings.Split(repository, "@")
|
repoParts := strings.Split(repository, "@")
|
||||||
branch := "main"
|
branch := "main"
|
||||||
@ -110,8 +141,8 @@ func ConvertURL(s string) string {
|
|||||||
projectPath := strings.Join(repoPath[2:], "/")
|
projectPath := strings.Join(repoPath[2:], "/")
|
||||||
|
|
||||||
return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
||||||
case strings.HasPrefix(s, GithubURI):
|
case strings.HasPrefix(string(s), GithubURI):
|
||||||
parts := strings.Split(s, ":")
|
parts := strings.Split(string(s), ":")
|
||||||
repoParts := strings.Split(parts[1], "@")
|
repoParts := strings.Split(parts[1], "@")
|
||||||
branch := "main"
|
branch := "main"
|
||||||
|
|
||||||
@ -125,8 +156,8 @@ func ConvertURL(s string) string {
|
|||||||
projectPath := strings.Join(repoPath[2:], "/")
|
projectPath := strings.Join(repoPath[2:], "/")
|
||||||
|
|
||||||
return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
||||||
case strings.HasPrefix(s, HuggingFacePrefix):
|
case strings.HasPrefix(string(s), HuggingFacePrefix):
|
||||||
repository := strings.Replace(s, HuggingFacePrefix, "", 1)
|
repository := strings.Replace(string(s), HuggingFacePrefix, "", 1)
|
||||||
// convert repository to a full URL.
|
// convert repository to a full URL.
|
||||||
// e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf
|
// e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf
|
||||||
owner := strings.Split(repository, "/")[0]
|
owner := strings.Split(repository, "/")[0]
|
||||||
@ -144,7 +175,7 @@ func ConvertURL(s string) string {
|
|||||||
return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath)
|
return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return string(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
func removePartialFile(tmpFilePath string) error {
|
func removePartialFile(tmpFilePath string) error {
|
||||||
@ -161,9 +192,9 @@ func removePartialFile(tmpFilePath string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DownloadFile(url string, filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
|
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
|
||||||
url = ConvertURL(url)
|
url := uri.ResolveURL()
|
||||||
if LooksLikeOCI(url) {
|
if uri.LooksLikeOCI() {
|
||||||
progressStatus := func(desc ocispec.Descriptor) io.Writer {
|
progressStatus := func(desc ocispec.Descriptor) io.Writer {
|
||||||
return &progressWriter{
|
return &progressWriter{
|
||||||
fileName: filePath,
|
fileName: filePath,
|
||||||
@ -298,37 +329,6 @@ func DownloadFile(url string, filePath, sha string, fileN, total int, downloadSt
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// this function check if the string is an URL, if it's an URL downloads the image in memory
|
|
||||||
// encodes it in base64 and returns the base64 string
|
|
||||||
func GetBase64Image(s string) (string, error) {
|
|
||||||
if strings.HasPrefix(s, "http") {
|
|
||||||
// download the image
|
|
||||||
resp, err := http.Get(s)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
// read the image data into memory
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
// encode the image data in base64
|
|
||||||
encoded := base64.StdEncoding.EncodeToString(data)
|
|
||||||
|
|
||||||
// return the base64 string
|
|
||||||
return encoded, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
|
|
||||||
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
|
|
||||||
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("not valid string")
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatBytes(bytes int64) string {
|
func formatBytes(bytes int64) string {
|
||||||
const unit = 1024
|
const unit = 1024
|
||||||
if bytes < unit {
|
if bytes < unit {
|
||||||
@ -356,42 +356,3 @@ func calculateSHA(filePath string) (string, error) {
|
|||||||
|
|
||||||
return fmt.Sprintf("%x", hash.Sum(nil)), nil
|
return fmt.Sprintf("%x", hash.Sum(nil)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type HuggingFaceScanResult struct {
|
|
||||||
RepositoryId string `json:"repositoryId"`
|
|
||||||
Revision string `json:"revision"`
|
|
||||||
HasUnsafeFiles bool `json:"hasUnsafeFile"`
|
|
||||||
ClamAVInfectedFiles []string `json:"clamAVInfectedFiles"`
|
|
||||||
DangerousPickles []string `json:"dangerousPickles"`
|
|
||||||
ScansDone bool `json:"scansDone"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var ErrNonHuggingFaceFile = errors.New("not a huggingface repo")
|
|
||||||
var ErrUnsafeFilesFound = errors.New("unsafe files found")
|
|
||||||
|
|
||||||
func HuggingFaceScan(uri string) (*HuggingFaceScanResult, error) {
|
|
||||||
cleanParts := strings.Split(ConvertURL(uri), "/")
|
|
||||||
if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" {
|
|
||||||
return nil, ErrNonHuggingFaceFile
|
|
||||||
}
|
|
||||||
results, err := http.Get(fmt.Sprintf("https://huggingface.co/api/models/%s/%s/scan", cleanParts[3], cleanParts[4]))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if results.StatusCode != 200 {
|
|
||||||
return nil, fmt.Errorf("unexpected status code during HuggingFaceScan: %d", results.StatusCode)
|
|
||||||
}
|
|
||||||
scanResult := &HuggingFaceScanResult{}
|
|
||||||
bodyBytes, err := io.ReadAll(results.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(bodyBytes, scanResult)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if scanResult.HasUnsafeFiles {
|
|
||||||
return scanResult, ErrUnsafeFilesFound
|
|
||||||
}
|
|
||||||
return scanResult, nil
|
|
||||||
}
|
|
||||||
|
@ -9,24 +9,28 @@ import (
|
|||||||
var _ = Describe("Gallery API tests", func() {
|
var _ = Describe("Gallery API tests", func() {
|
||||||
Context("URI", func() {
|
Context("URI", func() {
|
||||||
It("parses github with a branch", func() {
|
It("parses github with a branch", func() {
|
||||||
|
uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml")
|
||||||
Expect(
|
Expect(
|
||||||
DownloadAndUnmarshal("github:go-skynet/model-gallery/gpt4all-j.yaml", "", func(url string, i []byte) error {
|
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
|
||||||
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||||
return nil
|
return nil
|
||||||
}),
|
}),
|
||||||
).ToNot(HaveOccurred())
|
).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
It("parses github without a branch", func() {
|
It("parses github without a branch", func() {
|
||||||
|
uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml@main")
|
||||||
|
|
||||||
Expect(
|
Expect(
|
||||||
DownloadAndUnmarshal("github:go-skynet/model-gallery/gpt4all-j.yaml@main", "", func(url string, i []byte) error {
|
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
|
||||||
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||||
return nil
|
return nil
|
||||||
}),
|
}),
|
||||||
).ToNot(HaveOccurred())
|
).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
It("parses github with urls", func() {
|
It("parses github with urls", func() {
|
||||||
|
uri := URI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")
|
||||||
Expect(
|
Expect(
|
||||||
DownloadAndUnmarshal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", "", func(url string, i []byte) error {
|
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
|
||||||
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||||
return nil
|
return nil
|
||||||
}),
|
}),
|
||||||
|
@ -3,7 +3,6 @@ package startup
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@ -23,21 +22,21 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath
|
|||||||
// create an error that groups all errors
|
// create an error that groups all errors
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
for _, url := range models {
|
lib, _ := embedded.GetRemoteLibraryShorteners(modelLibraryURL, modelPath)
|
||||||
|
|
||||||
|
for _, url := range models {
|
||||||
// As a best effort, try to resolve the model from the remote library
|
// As a best effort, try to resolve the model from the remote library
|
||||||
// if it's not resolved we try with the other method below
|
// if it's not resolved we try with the other method below
|
||||||
if modelLibraryURL != "" {
|
if modelLibraryURL != "" {
|
||||||
lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL, modelPath)
|
|
||||||
if err == nil {
|
|
||||||
if lib[url] != "" {
|
if lib[url] != "" {
|
||||||
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url])
|
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url])
|
||||||
url = lib[url]
|
url = lib[url]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
url = embedded.ModelShortURL(url)
|
url = embedded.ModelShortURL(url)
|
||||||
|
uri := downloader.URI(url)
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case embedded.ExistsInModelsLibrary(url):
|
case embedded.ExistsInModelsLibrary(url):
|
||||||
modelYAML, e := embedded.ResolveContent(url)
|
modelYAML, e := embedded.ResolveContent(url)
|
||||||
@ -55,7 +54,7 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath
|
|||||||
log.Error().Err(e).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition")
|
log.Error().Err(e).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition")
|
||||||
err = errors.Join(err, e)
|
err = errors.Join(err, e)
|
||||||
}
|
}
|
||||||
case downloader.LooksLikeOCI(url):
|
case uri.LooksLikeOCI():
|
||||||
log.Debug().Msgf("[startup] resolved OCI model to download: %s", url)
|
log.Debug().Msgf("[startup] resolved OCI model to download: %s", url)
|
||||||
|
|
||||||
// convert OCI image name to a file name.
|
// convert OCI image name to a file name.
|
||||||
@ -67,7 +66,7 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath
|
|||||||
// check if file exists
|
// check if file exists
|
||||||
if _, e := os.Stat(filepath.Join(modelPath, ociName)); errors.Is(e, os.ErrNotExist) {
|
if _, e := os.Stat(filepath.Join(modelPath, ociName)); errors.Is(e, os.ErrNotExist) {
|
||||||
modelDefinitionFilePath := filepath.Join(modelPath, ociName)
|
modelDefinitionFilePath := filepath.Join(modelPath, ociName)
|
||||||
e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) {
|
e := uri.DownloadFile(modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) {
|
||||||
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
||||||
})
|
})
|
||||||
if e != nil {
|
if e != nil {
|
||||||
@ -77,19 +76,15 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName)
|
log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName)
|
||||||
case downloader.LooksLikeURL(url):
|
case uri.LooksLikeURL():
|
||||||
log.Debug().Msgf("[startup] downloading %s", url)
|
log.Debug().Msgf("[startup] downloading %s", url)
|
||||||
|
|
||||||
// Extract filename from URL
|
// Extract filename from URL
|
||||||
fileName, e := filenameFromUrl(url)
|
fileName, e := uri.FilenameFromUrl()
|
||||||
if e != nil || fileName == "" {
|
if e != nil {
|
||||||
fileName = utils.MD5(url)
|
|
||||||
if strings.HasSuffix(url, ".yaml") || strings.HasSuffix(url, ".yml") {
|
|
||||||
fileName = fileName + ".yaml"
|
|
||||||
}
|
|
||||||
log.Warn().Err(e).Str("url", url).Msg("error extracting filename from URL")
|
log.Warn().Err(e).Str("url", url).Msg("error extracting filename from URL")
|
||||||
//err = errors.Join(err, e)
|
err = errors.Join(err, e)
|
||||||
//continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
modelPath := filepath.Join(modelPath, fileName)
|
modelPath := filepath.Join(modelPath, fileName)
|
||||||
@ -102,7 +97,7 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath
|
|||||||
|
|
||||||
// check if file exists
|
// check if file exists
|
||||||
if _, e := os.Stat(modelPath); errors.Is(e, os.ErrNotExist) {
|
if _, e := os.Stat(modelPath); errors.Is(e, os.ErrNotExist) {
|
||||||
e := downloader.DownloadFile(url, modelPath, "", 0, 0, func(fileName, current, total string, percent float64) {
|
e := uri.DownloadFile(modelPath, "", 0, 0, func(fileName, current, total string, percent float64) {
|
||||||
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
||||||
})
|
})
|
||||||
if e != nil {
|
if e != nil {
|
||||||
@ -167,20 +162,3 @@ func installModel(galleries []config.Gallery, modelName, modelPath string, downl
|
|||||||
|
|
||||||
return nil, true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func filenameFromUrl(urlstr string) (string, error) {
|
|
||||||
// strip anything after @
|
|
||||||
if strings.Contains(urlstr, "@") {
|
|
||||||
urlstr = strings.Split(urlstr, "@")[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
u, err := url.Parse(urlstr)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error due to parsing url: %w", err)
|
|
||||||
}
|
|
||||||
x, err := url.QueryUnescape(u.EscapedPath())
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error due to escaping: %w", err)
|
|
||||||
}
|
|
||||||
return filepath.Base(x), nil
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user