diff --git a/.github/check_and_update.py b/.github/check_and_update.py new file mode 100644 index 00000000..448844fd --- /dev/null +++ b/.github/check_and_update.py @@ -0,0 +1,79 @@ +import hashlib +from huggingface_hub import hf_hub_download, get_paths_info +import requests +import sys +import os + +uri = sys.argv[0] +file_name = uri.split('/')[-1] + +# Function to parse the URI and determine download method +def parse_uri(uri): + if uri.startswith('huggingface://'): + repo_id = uri.split('://')[1] + return 'huggingface', repo_id.rsplit('/', 1)[0] + elif 'huggingface.co' in uri: + parts = uri.split('/resolve/') + if len(parts) > 1: + repo_path = parts[0].split('https://huggingface.co/')[-1] + return 'huggingface', repo_path + return 'direct', uri + +def calculate_sha256(file_path): + sha256_hash = hashlib.sha256() + with open(file_path, 'rb') as f: + for byte_block in iter(lambda: f.read(4096), b''): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() + +def manual_safety_check_hf(repo_id): + scanResponse = requests.get('https://huggingface.co/api/models/' + repo_id + "/scan") + scan = scanResponse.json() + if scan['hasUnsafeFile']: + return scan + return None + +download_type, repo_id_or_url = parse_uri(uri) + +new_checksum = None + +# Decide download method based on URI type +if download_type == 'huggingface': + # Check if the repo is flagged as dangerous by HF + hazard = manual_safety_check_hf(repo_id_or_url) + if hazard != None: + print(f'Error: HuggingFace has detected security problems for {repo_id_or_url}: {str(hazard)}', filename=file_name) + sys.exit(5) + # Use HF API to pull sha + for file in get_paths_info(repo_id_or_url, [file_name], repo_type='model'): + try: + new_checksum = file.lfs.sha256 + break + except Exception as e: + print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr) + sys.exit(2) + if new_checksum is None: + try: + file_path = hf_hub_download(repo_id=repo_id_or_url, filename=file_name) + except Exception as e: + print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr) + sys.exit(2) +else: + response = requests.get(repo_id_or_url) + if response.status_code == 200: + with open(file_name, 'wb') as f: + f.write(response.content) + file_path = file_name + elif response.status_code == 404: + print(f'File not found: {response.status_code}', file=sys.stderr) + sys.exit(2) + else: + print(f'Error downloading file: {response.status_code}', file=sys.stderr) + sys.exit(1) + +if new_checksum is None: + new_checksum = calculate_sha256(file_path) + print(new_checksum) + os.remove(file_path) +else: + print(new_checksum) diff --git a/.github/checksum_checker.sh b/.github/checksum_checker.sh index 01242af6..174e6d3f 100644 --- a/.github/checksum_checker.sh +++ b/.github/checksum_checker.sh @@ -14,77 +14,14 @@ function check_and_update_checksum() { idx="$5" # Download the file and calculate new checksum using Python - new_checksum=$(python3 -c " -import hashlib -from huggingface_hub import hf_hub_download, get_paths_info -import requests -import sys -import os + new_checksum=$(python3 ./check_and_update.py $uri) + result=$? -uri = '$uri' -file_name = uri.split('/')[-1] - -# Function to parse the URI and determine download method -# Function to parse the URI and determine download method -def parse_uri(uri): - if uri.startswith('huggingface://'): - repo_id = uri.split('://')[1] - return 'huggingface', repo_id.rsplit('/', 1)[0] - elif 'huggingface.co' in uri: - parts = uri.split('/resolve/') - if len(parts) > 1: - repo_path = parts[0].split('https://huggingface.co/')[-1] - return 'huggingface', repo_path - return 'direct', uri - -def calculate_sha256(file_path): - sha256_hash = hashlib.sha256() - with open(file_path, 'rb') as f: - for byte_block in iter(lambda: f.read(4096), b''): - sha256_hash.update(byte_block) - return sha256_hash.hexdigest() - -download_type, repo_id_or_url = parse_uri(uri) - -new_checksum = None - -# Decide download method based on URI type -if download_type == 'huggingface': - # Use HF API to pull sha - for file in get_paths_info(repo_id_or_url, [file_name], repo_type='model'): - try: - new_checksum = file.lfs.sha256 - break - except Exception as e: - print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr) - sys.exit(2) - if new_checksum is None: - try: - file_path = hf_hub_download(repo_id=repo_id_or_url, filename=file_name) - except Exception as e: - print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr) - sys.exit(2) -else: - response = requests.get(repo_id_or_url) - if response.status_code == 200: - with open(file_name, 'wb') as f: - f.write(response.content) - file_path = file_name - elif response.status_code == 404: - print(f'File not found: {response.status_code}', file=sys.stderr) - sys.exit(2) - else: - print(f'Error downloading file: {response.status_code}', file=sys.stderr) - sys.exit(1) - -if new_checksum is None: - new_checksum = calculate_sha256(file_path) - print(new_checksum) - os.remove(file_path) -else: - print(new_checksum) - -") + if [[ result -eq 5]]; then + echo "Contaminated entry detected, deleting entry for $model_name..." + yq eval -i "del([$idx])" "$input_yaml" + return + fi if [[ "$new_checksum" == "" ]]; then echo "Error calculating checksum for $file_name. Skipping..." @@ -94,7 +31,7 @@ else: echo "Checksum for $file_name: $new_checksum" # Compare and update the YAML file if checksums do not match - result=$? + if [[ $result -eq 2 ]]; then echo "File not found, deleting entry for $file_name..." # yq eval -i "del(.[$idx].files[] | select(.filename == \"$file_name\"))" "$input_yaml" diff --git a/core/backend/llm.go b/core/backend/llm.go index a6f7fe56..9268fbbc 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -57,7 +57,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im if _, err := os.Stat(modelFile); os.IsNotExist(err) { utils.ResetDownloadTimers() // if we failed to load the model, we try to download it - err := gallery.InstallModelFromGallery(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) + err := gallery.InstallModelFromGallery(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans) if err != nil { return nil, err } diff --git a/core/cli/models.go b/core/cli/models.go index d62ad318..03047018 100644 --- a/core/cli/models.go +++ b/core/cli/models.go @@ -2,6 +2,7 @@ package cli import ( "encoding/json" + "errors" "fmt" cliContext "github.com/mudler/LocalAI/core/cli/context" @@ -24,7 +25,8 @@ type ModelsList struct { } type ModelsInstall struct { - ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"` + DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"` + ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"` ModelsCMDFlags `embed:""` } @@ -88,9 +90,15 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { return err } + err = gallery.SafetyScanGalleryModel(model) + if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) { + return err + } + log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model") } - err = startup.InstallModels(galleries, "", mi.ModelsPath, progressCallback, modelName) + + err = startup.InstallModels(galleries, "", mi.ModelsPath, !mi.DisablePredownloadScan, progressCallback, modelName) if err != nil { return err } diff --git a/core/cli/run.go b/core/cli/run.go index 4a313391..d7b45f77 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -42,26 +42,27 @@ type RunCMD struct { Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"` ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" default:"512" help:"Default context size for models" group:"performance"` - Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"` - CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"` - CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"` - LibraryPath string `env:"LOCALAI_LIBRARY_PATH,LIBRARY_PATH" help:"Path to the library directory (for e.g. external libraries used by backends)" default:"/usr/share/local-ai/libs" group:"backends"` - CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"` - UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"` - APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"` - DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"` - OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"api"` - Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"` - Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"` - ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"` - SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"` - PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"` - ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"` - EnableWatchdogIdle bool `env:"LOCALAI_WATCHDOG_IDLE,WATCHDOG_IDLE" default:"false" help:"Enable watchdog for stopping backends that are idle longer than the watchdog-idle-timeout" group:"backends"` - WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"` - EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"` - WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"` - Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"` + Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"` + CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"` + CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"` + LibraryPath string `env:"LOCALAI_LIBRARY_PATH,LIBRARY_PATH" help:"Path to the library directory (for e.g. external libraries used by backends)" default:"/usr/share/local-ai/libs" group:"backends"` + CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"` + UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"` + APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"` + DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"` + DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"` + OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"` + Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"` + Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"` + ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"` + SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"` + PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"` + ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"` + EnableWatchdogIdle bool `env:"LOCALAI_WATCHDOG_IDLE,WATCHDOG_IDLE" default:"false" help:"Enable watchdog for stopping backends that are idle longer than the watchdog-idle-timeout" group:"backends"` + WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"` + EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"` + WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"` + Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"` } func (r *RunCMD) Run(ctx *cliContext.Context) error { @@ -92,6 +93,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { config.WithApiKeys(r.APIKeys), config.WithModelsURL(append(r.Models, r.ModelArgs...)...), config.WithOpaqueErrors(r.OpaqueErrors), + config.WithEnforcedPredownloadScans(!r.DisablePredownloadScan), } token := "" diff --git a/core/cli/util.go b/core/cli/util.go index e8ccb942..a7204092 100644 --- a/core/cli/util.go +++ b/core/cli/util.go @@ -1,16 +1,22 @@ package cli import ( + "encoding/json" + "errors" "fmt" "github.com/rs/zerolog/log" cliContext "github.com/mudler/LocalAI/core/cli/context" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/pkg/downloader" gguf "github.com/thxcode/gguf-parser-go" ) type UtilCMD struct { GGUFInfo GGUFInfoCMD `cmd:"" name:"gguf-info" help:"Get information about a GGUF file"` + HFScan HFScanCMD `cmd:"" name:"hf-scan" help:"Checks installed models for known security issues. WARNING: this is a best-effort feature and may not catch everything!"` } type GGUFInfoCMD struct { @@ -18,6 +24,12 @@ type GGUFInfoCMD struct { Header bool `optional:"" default:"false" name:"header" help:"Show header information"` } +type HFScanCMD struct { + ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` + Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"` + ToScan []string `arg:""` +} + func (u *GGUFInfoCMD) Run(ctx *cliContext.Context) error { if u.Args == nil || len(u.Args) == 0 { return fmt.Errorf("no GGUF file provided") @@ -53,3 +65,37 @@ func (u *GGUFInfoCMD) Run(ctx *cliContext.Context) error { return nil } + +func (hfscmd *HFScanCMD) Run(ctx *cliContext.Context) error { + log.Info().Msg("LocalAI Security Scanner - This is BEST EFFORT functionality! Currently limited to huggingface models!") + if len(hfscmd.ToScan) == 0 { + log.Info().Msg("Checking all installed models against galleries") + var galleries []config.Gallery + if err := json.Unmarshal([]byte(hfscmd.Galleries), &galleries); err != nil { + log.Error().Err(err).Msg("unable to load galleries") + } + + err := gallery.SafetyScanGalleryModels(galleries, hfscmd.ModelsPath) + if err == nil { + log.Info().Msg("No security warnings were detected for your installed models. Please note that this is a BEST EFFORT tool, and all issues may not be detected.") + } else { + log.Error().Err(err).Msg("! WARNING ! A known-vulnerable model is installed!") + } + return err + } else { + var errs error = nil + for _, uri := range hfscmd.ToScan { + log.Info().Str("uri", uri).Msg("scanning specific uri") + scanResults, err := downloader.HuggingFaceScan(uri) + if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) { + log.Error().Err(err).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("! WARNING ! A known-vulnerable model is included in this repo!") + errs = errors.Join(errs, err) + } + } + if errs != nil { + return errs + } + log.Info().Msg("No security warnings were detected for your installed models. Please note that this is a BEST EFFORT tool, and all issues may not be detected.") + return nil + } +} diff --git a/core/config/application_config.go b/core/config/application_config.go index 1bac349b..7233d1ac 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -31,6 +31,7 @@ type ApplicationConfig struct { PreloadModelsFromPath string CORSAllowOrigins string ApiKeys []string + EnforcePredownloadScans bool OpaqueErrors bool P2PToken string @@ -301,6 +302,12 @@ func WithApiKeys(apiKeys []string) AppOption { } } +func WithEnforcedPredownloadScans(enforced bool) AppOption { + return func(o *ApplicationConfig) { + o.EnforcePredownloadScans = enforced + } +} + func WithOpaqueErrors(opaque bool) AppOption { return func(o *ApplicationConfig) { o.OpaqueErrors = opaque diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index be167755..231dce6d 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -15,7 +15,7 @@ import ( ) // Installs a model from the gallery -func InstallModelFromGallery(galleries []config.Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { +func InstallModelFromGallery(galleries []config.Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan bool) error { applyModel := func(model *GalleryModel) error { name = strings.ReplaceAll(name, string(os.PathSeparator), "__") @@ -63,7 +63,7 @@ func InstallModelFromGallery(galleries []config.Gallery, name string, basePath s return err } - if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus); err != nil { + if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus, enforceScan); err != nil { return err } @@ -228,3 +228,29 @@ func DeleteModelFromSystem(basePath string, name string, additionalFiles []strin return err } + +// This is ***NEVER*** going to be perfect or finished. +// This is a BEST EFFORT function to surface known-vulnerable models to users. +func SafetyScanGalleryModels(galleries []config.Gallery, basePath string) error { + galleryModels, err := AvailableGalleryModels(galleries, basePath) + if err != nil { + return err + } + for _, gM := range galleryModels { + if gM.Installed { + err = errors.Join(err, SafetyScanGalleryModel(gM)) + } + } + return err +} + +func SafetyScanGalleryModel(galleryModel *GalleryModel) error { + for _, file := range galleryModel.AdditionalFiles { + scanResults, err := downloader.HuggingFaceScan(file.URI) + if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) { + log.Error().Str("model", galleryModel.Name).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("Contains unsafe file(s)!") + return err + } + } + return nil +} diff --git a/core/gallery/models.go b/core/gallery/models.go index 8d020ff5..28a2e3f2 100644 --- a/core/gallery/models.go +++ b/core/gallery/models.go @@ -1,6 +1,7 @@ package gallery import ( + "errors" "fmt" "os" "path/filepath" @@ -94,7 +95,7 @@ func ReadConfigFile(filePath string) (*Config, error) { return &config, nil } -func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { +func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) error { // Create base path if it doesn't exist err := os.MkdirAll(basePath, 0750) if err != nil { @@ -112,9 +113,18 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides if err := utils.VerifyPath(file.Filename, basePath); err != nil { return err } + // Create file path filePath := filepath.Join(basePath, file.Filename) + if enforceScan { + scanResults, err := downloader.HuggingFaceScan(file.URI) + if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) { + log.Error().Str("model", config.Name).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("Contains unsafe file(s)!") + return err + } + } + if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil { return err } diff --git a/core/gallery/models_test.go b/core/gallery/models_test.go index 17a30911..5217253f 100644 --- a/core/gallery/models_test.go +++ b/core/gallery/models_test.go @@ -21,7 +21,7 @@ var _ = Describe("Model test", func() { defer os.RemoveAll(tempdir) c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}) + err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { @@ -69,7 +69,7 @@ var _ = Describe("Model test", func() { Expect(models[0].URL).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml")) Expect(models[0].Installed).To(BeFalse()) - err = InstallModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {}) + err = InstallModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true) Expect(err).ToNot(HaveOccurred()) dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml")) @@ -106,7 +106,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) + err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -122,7 +122,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}) + err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -148,7 +148,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) + err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) Expect(err).To(HaveOccurred()) }) }) diff --git a/core/services/gallery.go b/core/services/gallery.go index 2c0ed435..45bebd4f 100644 --- a/core/services/gallery.go +++ b/core/services/gallery.go @@ -30,7 +30,7 @@ func NewGalleryService(appConfig *config.ApplicationConfig) *GalleryService { } } -func prepareModel(modelPath string, req gallery.GalleryModel, downloadStatus func(string, string, string, float64)) error { +func prepareModel(modelPath string, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan bool) error { config, err := gallery.GetGalleryConfigFromURL(req.URL, modelPath) if err != nil { @@ -39,7 +39,7 @@ func prepareModel(modelPath string, req gallery.GalleryModel, downloadStatus fun config.Files = append(config.Files, req.AdditionalFiles...) - return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) + return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus, enforceScan) } func (g *GalleryService) UpdateStatus(s string, op *gallery.GalleryOpStatus) { @@ -127,16 +127,16 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader } else { // if the request contains a gallery name, we apply the gallery from the gallery list if op.GalleryModelName != "" { - err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryModelName, g.appConfig.ModelPath, op.Req, progressCallback) + err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryModelName, g.appConfig.ModelPath, op.Req, progressCallback, g.appConfig.EnforcePredownloadScans) } else if op.ConfigURL != "" { - err = startup.InstallModels(op.Galleries, op.ConfigURL, g.appConfig.ModelPath, progressCallback, op.ConfigURL) + err = startup.InstallModels(op.Galleries, op.ConfigURL, g.appConfig.ModelPath, g.appConfig.EnforcePredownloadScans, progressCallback, op.ConfigURL) if err != nil { updateError(err) continue } err = cl.Preload(g.appConfig.ModelPath) } else { - err = prepareModel(g.appConfig.ModelPath, op.Req, progressCallback) + err = prepareModel(g.appConfig.ModelPath, op.Req, progressCallback, g.appConfig.EnforcePredownloadScans) } } @@ -175,22 +175,22 @@ type galleryModel struct { ID string `json:"id"` } -func processRequests(modelPath string, galleries []config.Gallery, requests []galleryModel) error { +func processRequests(modelPath string, enforceScan bool, galleries []config.Gallery, requests []galleryModel) error { var err error for _, r := range requests { utils.ResetDownloadTimers() if r.ID == "" { - err = prepareModel(modelPath, r.GalleryModel, utils.DisplayDownloadFunction) + err = prepareModel(modelPath, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan) } else { err = gallery.InstallModelFromGallery( - galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) + galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan) } } return err } -func ApplyGalleryFromFile(modelPath, s string, galleries []config.Gallery) error { +func ApplyGalleryFromFile(modelPath, s string, enforceScan bool, galleries []config.Gallery) error { dat, err := os.ReadFile(s) if err != nil { return err @@ -201,15 +201,15 @@ func ApplyGalleryFromFile(modelPath, s string, galleries []config.Gallery) error return err } - return processRequests(modelPath, galleries, requests) + return processRequests(modelPath, enforceScan, galleries, requests) } -func ApplyGalleryFromString(modelPath, s string, galleries []config.Gallery) error { +func ApplyGalleryFromString(modelPath, s string, enforceScan bool, galleries []config.Gallery) error { var requests []galleryModel err := json.Unmarshal([]byte(s), &requests) if err != nil { return err } - return processRequests(modelPath, galleries, requests) + return processRequests(modelPath, enforceScan, galleries, requests) } diff --git a/core/startup/startup.go b/core/startup/startup.go index 278c8e1c..66111b59 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -60,7 +60,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode } } - if err := pkgStartup.InstallModels(options.Galleries, options.ModelLibraryURL, options.ModelPath, nil, options.ModelsURL...); err != nil { + if err := pkgStartup.InstallModels(options.Galleries, options.ModelLibraryURL, options.ModelPath, options.EnforcePredownloadScans, nil, options.ModelsURL...); err != nil { log.Error().Err(err).Msg("error installing models") } @@ -84,13 +84,13 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode } if options.PreloadJSONModels != "" { - if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, options.Galleries); err != nil { + if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, options.EnforcePredownloadScans, options.Galleries); err != nil { return nil, nil, nil, err } } if options.PreloadModelsFromPath != "" { - if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, options.Galleries); err != nil { + if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, options.EnforcePredownloadScans, options.Galleries); err != nil { return nil, nil, nil, err } } diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index de575d63..1f88bbb1 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -3,6 +3,8 @@ package downloader import ( "crypto/sha256" "encoding/base64" + "encoding/json" + "errors" "fmt" "io" "net/http" @@ -129,6 +131,7 @@ func ConvertURL(s string) string { // e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf owner := strings.Split(repository, "/")[0] repo := strings.Split(repository, "/")[1] + branch := "main" if strings.Contains(repo, "@") { branch = strings.Split(repository, "@")[1] @@ -353,3 +356,42 @@ func calculateSHA(filePath string) (string, error) { return fmt.Sprintf("%x", hash.Sum(nil)), nil } + +type HuggingFaceScanResult struct { + RepositoryId string `json:"repositoryId"` + Revision string `json:"revision"` + HasUnsafeFiles bool `json:"hasUnsafeFile"` + ClamAVInfectedFiles []string `json:"clamAVInfectedFiles"` + DangerousPickles []string `json:"dangerousPickles"` + ScansDone bool `json:"scansDone"` +} + +var ErrNonHuggingFaceFile = errors.New("not a huggingface repo") +var ErrUnsafeFilesFound = errors.New("unsafe files found") + +func HuggingFaceScan(uri string) (*HuggingFaceScanResult, error) { + cleanParts := strings.Split(ConvertURL(uri), "/") + if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" { + return nil, ErrNonHuggingFaceFile + } + results, err := http.Get(fmt.Sprintf("https://huggingface.co/api/models/%s/%s/scan", cleanParts[3], cleanParts[4])) + if err != nil { + return nil, err + } + if results.StatusCode != 200 { + return nil, fmt.Errorf("unexpected status code during HuggingFaceScan: %d", results.StatusCode) + } + scanResult := &HuggingFaceScanResult{} + bodyBytes, err := io.ReadAll(results.Body) + if err != nil { + return nil, err + } + err = json.Unmarshal(bodyBytes, scanResult) + if err != nil { + return nil, err + } + if scanResult.HasUnsafeFiles { + return scanResult, ErrUnsafeFilesFound + } + return scanResult, nil +} diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go index d678f283..74a10e9e 100644 --- a/pkg/startup/model_preload.go +++ b/pkg/startup/model_preload.go @@ -18,7 +18,7 @@ import ( // InstallModels will preload models from the given list of URLs and galleries // It will download the model if it is not already present in the model path // It will also try to resolve if the model is an embedded model YAML configuration -func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath string, downloadStatus func(string, string, string, float64), models ...string) error { +func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath string, enforceScan bool, downloadStatus func(string, string, string, float64), models ...string) error { // create an error that groups all errors var err error @@ -113,7 +113,7 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath } } else { // Check if it's a model gallery, or print a warning - e, found := installModel(galleries, url, modelPath, downloadStatus) + e, found := installModel(galleries, url, modelPath, downloadStatus, enforceScan) if e != nil && found { log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url) err = errors.Join(err, e) @@ -127,7 +127,7 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath return err } -func installModel(galleries []config.Gallery, modelName, modelPath string, downloadStatus func(string, string, string, float64)) (error, bool) { +func installModel(galleries []config.Gallery, modelName, modelPath string, downloadStatus func(string, string, string, float64), enforceScan bool) (error, bool) { models, err := gallery.AvailableGalleryModels(galleries, modelPath) if err != nil { return err, false @@ -143,7 +143,7 @@ func installModel(galleries []config.Gallery, modelName, modelPath string, downl } log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model") - err = gallery.InstallModelFromGallery(galleries, modelName, modelPath, gallery.GalleryModel{}, downloadStatus) + err = gallery.InstallModelFromGallery(galleries, modelName, modelPath, gallery.GalleryModel{}, downloadStatus, enforceScan) if err != nil { return err, true } diff --git a/pkg/startup/model_preload_test.go b/pkg/startup/model_preload_test.go index e3d7d979..939ad1a2 100644 --- a/pkg/startup/model_preload_test.go +++ b/pkg/startup/model_preload_test.go @@ -22,7 +22,7 @@ var _ = Describe("Preload test", func() { libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml" fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719") - InstallModels([]config.Gallery{}, libraryURL, tmpdir, nil, "phi-2") + InstallModels([]config.Gallery{}, libraryURL, tmpdir, true, nil, "phi-2") resultFile := filepath.Join(tmpdir, fileName) @@ -38,7 +38,7 @@ var _ = Describe("Preload test", func() { url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml" fileName := fmt.Sprintf("%s.yaml", utils.MD5(url)) - InstallModels([]config.Gallery{}, "", tmpdir, nil, url) + InstallModels([]config.Gallery{}, "", tmpdir, true, nil, url) resultFile := filepath.Join(tmpdir, fileName) @@ -52,7 +52,7 @@ var _ = Describe("Preload test", func() { Expect(err).ToNot(HaveOccurred()) url := "phi-2" - InstallModels([]config.Gallery{}, "", tmpdir, nil, url) + InstallModels([]config.Gallery{}, "", tmpdir, true, nil, url) entry, err := os.ReadDir(tmpdir) Expect(err).ToNot(HaveOccurred()) @@ -70,7 +70,7 @@ var _ = Describe("Preload test", func() { url := "mistral-openorca" fileName := fmt.Sprintf("%s.yaml", utils.MD5(url)) - InstallModels([]config.Gallery{}, "", tmpdir, nil, url) + InstallModels([]config.Gallery{}, "", tmpdir, true, nil, url) resultFile := filepath.Join(tmpdir, fileName)