LocalAI/.github/check_and_update.py
Dave 133987b1fb
feat: HF /scan endpoint (#2566)
* start by checking /scan during the checksum update

Signed-off-by: Dave Lee <dave@gray101.com>

* add back in golang side features: downloader/uri gets struct and scan function, gallery uses it, and secscan/models calls it.

Signed-off-by: Dave Lee <dave@gray101.com>

* add a param to scan specific urls - useful for debugging

Signed-off-by: Dave Lee <dave@gray101.com>

* helpful printouts

Signed-off-by: Dave Lee <dave@gray101.com>

* fix offsets

Signed-off-by: Dave Lee <dave@gray101.com>

* fix error and naming

Signed-off-by: Dave Lee <dave@gray101.com>

* expose error

Signed-off-by: Dave Lee <dave@gray101.com>

* fix json tags

Signed-off-by: Dave Lee <dave@gray101.com>

* slight wording change

Signed-off-by: Dave Lee <dave@gray101.com>

* go mod tidy - getting warnings

Signed-off-by: Dave Lee <dave@gray101.com>

* split out python to make editing easier, add some simple code  to delete contaminated entries from gallery

Signed-off-by: Dave Lee <dave@gray101.com>

* o7 to my favorite part of our old name, go-skynet

Signed-off-by: Dave Lee <dave@gray101.com>

* merge fix

Signed-off-by: Dave Lee <dave@gray101.com>

* merge fix

Signed-off-by: Dave Lee <dave@gray101.com>

* merge fix

Signed-off-by: Dave Lee <dave@gray101.com>

* address review comments

Signed-off-by: Dave Lee <dave@gray101.com>

* forgot secscan could accept multiple URL at once

Signed-off-by: Dave Lee <dave@gray101.com>

* invert naming and actually use it

Signed-off-by: Dave Lee <dave@gray101.com>

* missed cli/models.go

Signed-off-by: Dave Lee <dave@gray101.com>

* Update .github/check_and_update.py

Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Signed-off-by: Dave <dave@gray101.com>

---------

Signed-off-by: Dave Lee <dave@gray101.com>
Signed-off-by: Dave <dave@gray101.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-07-10 13:18:32 +02:00

80 lines
2.7 KiB
Python

import hashlib
from huggingface_hub import hf_hub_download, get_paths_info
import requests
import sys
import os
uri = sys.argv[0]
file_name = uri.split('/')[-1]
# Function to parse the URI and determine download method
def parse_uri(uri):
if uri.startswith('huggingface://'):
repo_id = uri.split('://')[1]
return 'huggingface', repo_id.rsplit('/', 1)[0]
elif 'huggingface.co' in uri:
parts = uri.split('/resolve/')
if len(parts) > 1:
repo_path = parts[0].split('https://huggingface.co/')[-1]
return 'huggingface', repo_path
return 'direct', uri
def calculate_sha256(file_path):
sha256_hash = hashlib.sha256()
with open(file_path, 'rb') as f:
for byte_block in iter(lambda: f.read(4096), b''):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
def manual_safety_check_hf(repo_id):
scanResponse = requests.get('https://huggingface.co/api/models/' + repo_id + "/scan")
scan = scanResponse.json()
if scan['hasUnsafeFile']:
return scan
return None
download_type, repo_id_or_url = parse_uri(uri)
new_checksum = None
# Decide download method based on URI type
if download_type == 'huggingface':
# Check if the repo is flagged as dangerous by HF
hazard = manual_safety_check_hf(repo_id_or_url)
if hazard != None:
print(f'Error: HuggingFace has detected security problems for {repo_id_or_url}: {str(hazard)}', filename=file_name)
sys.exit(5)
# Use HF API to pull sha
for file in get_paths_info(repo_id_or_url, [file_name], repo_type='model'):
try:
new_checksum = file.lfs.sha256
break
except Exception as e:
print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr)
sys.exit(2)
if new_checksum is None:
try:
file_path = hf_hub_download(repo_id=repo_id_or_url, filename=file_name)
except Exception as e:
print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr)
sys.exit(2)
else:
response = requests.get(repo_id_or_url)
if response.status_code == 200:
with open(file_name, 'wb') as f:
f.write(response.content)
file_path = file_name
elif response.status_code == 404:
print(f'File not found: {response.status_code}', file=sys.stderr)
sys.exit(2)
else:
print(f'Error downloading file: {response.status_code}', file=sys.stderr)
sys.exit(1)
if new_checksum is None:
new_checksum = calculate_sha256(file_path)
print(new_checksum)
os.remove(file_path)
else:
print(new_checksum)