mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 12:26:26 +00:00
3e8e71f8b6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
86 lines
2.8 KiB
Python
86 lines
2.8 KiB
Python
import hashlib
|
|
from huggingface_hub import hf_hub_download, get_paths_info
|
|
import requests
|
|
import sys
|
|
import os
|
|
|
|
# The URI to checksum is the first command-line argument. Fail with a usage
# message instead of an IndexError traceback when it is missing.
if len(sys.argv) < 2:
    print(f'Usage: {sys.argv[0]} <uri>', file=sys.stderr)
    sys.exit(1)
uri = sys.argv[1]
# Last '/'-separated component of the URI.
file_name = uri.rsplit('/', 1)[-1]
|
|
|
|
# Function to parse the URI and determine download method
|
|
def parse_uri(uri):
    """Classify *uri* and extract the target needed to fetch it.

    Args:
        uri: A ``huggingface://<repo id>/<file>`` URI, a ``huggingface.co``
            URL containing ``/resolve/``, or any other URL.

    Returns:
        A ``(download_type, target)`` tuple. For the two Hugging Face forms
        this is ``('huggingface', repo_id)``; for everything else it is
        ``('direct', uri)`` with the URI returned unchanged.
    """
    if uri.startswith('huggingface://'):
        # Drop the scheme, then the trailing file-name component.
        repo_and_file = uri.split('://')[1]
        return 'huggingface', repo_and_file.rsplit('/', 1)[0]
    if 'huggingface.co' in uri:
        parts = uri.split('/resolve/')
        if len(parts) > 1:
            # Strip everything up to and including the host so that http://,
            # https:// and scheme-less URLs all yield a bare repo id rather
            # than one that still carries the host name.
            return 'huggingface', parts[0].split('huggingface.co/', 1)[-1]
    # Hugging Face URLs without a /resolve/ segment fall through to here too.
    return 'direct', uri
|
|
|
|
def calculate_sha256(file_path, chunk_size=65536):
    """Return the hex SHA-256 digest of the file at *file_path*.

    The file is read in fixed-size chunks so that it never has to fit in
    memory at once.

    Args:
        file_path: Path of the file to hash.
        chunk_size: Number of bytes to read per iteration. It only affects
            I/O granularity; the digest is the same for any positive value.

    Returns:
        The digest as a lowercase hexadecimal string.

    Raises:
        ValueError: If *chunk_size* is not positive.
        OSError: If the file cannot be opened or read.
    """
    if chunk_size <= 0:
        # read(0) returns b'' immediately, which would silently hash nothing.
        raise ValueError('chunk_size must be a positive integer')
    digest = hashlib.sha256()
    with open(file_path, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
|
|
|
|
def manual_safety_check_hf(repo_id, timeout=30):
    """Ask the Hugging Face scan endpoint whether *repo_id* has unsafe files.

    Args:
        repo_id: Model repository id; it is appended to
            ``https://huggingface.co/api/models/`` followed by ``/scan``.
        timeout: Seconds to wait for the server before giving up, so an
            unresponsive endpoint cannot hang the script indefinitely.

    Returns:
        The decoded scan payload when it contains a truthy ``hasUnsafeFile``
        entry, otherwise ``None``.

    Raises:
        requests.RequestException: On connection problems or timeout.
        ValueError: If the response body is not valid JSON.
    """
    scan_url = 'https://huggingface.co/api/models/' + repo_id + "/scan"
    scan = requests.get(scan_url, timeout=timeout).json()
    # Only a JSON object can carry the flag; a missing or falsy flag means
    # nothing was reported.
    if isinstance(scan, dict) and scan.get('hasUnsafeFile'):
        return scan
    return None
|
|
|
|
# Script body: print the SHA-256 of the file referenced by `uri` on stdout.
# Exit statuses used below: 5 = Hugging Face flagged the repo as unsafe,
# 2 = Hugging Face Hub error or HTTP 404, 1 = any other HTTP download error.
download_type, repo_id_or_url = parse_uri(uri)

new_checksum = None
file_path = None

# Decide download method based on URI type
if download_type == 'huggingface':
    # Check if the repo is flagged as dangerous by HF
    hazard = manual_safety_check_hf(repo_id_or_url)
    if hazard is not None:
        # Diagnostics go to stderr so stdout carries only the checksum.
        print(f'Error: HuggingFace has detected security problems for {repo_id_or_url}: {str(hazard)}', file=sys.stderr)
        sys.exit(5)
    # Use HF API to pull sha
    for path_info in get_paths_info(repo_id_or_url, [file_name], repo_type='model'):
        try:
            new_checksum = path_info.lfs.sha256
            break
        except Exception as e:
            # NOTE(review): any failure reading `.lfs.sha256` exits here, so
            # the download fallback below appears to run only when
            # get_paths_info yields no entries - confirm this is intended.
            print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr)
            sys.exit(2)
    if new_checksum is None:
        # No checksum from the API: fetch the file and hash it locally.
        try:
            file_path = hf_hub_download(repo_id=repo_id_or_url, filename=file_name)
        except Exception as e:
            print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr)
            sys.exit(2)
else:
    response = requests.get(repo_id_or_url)
    if response.status_code == 200:
        # Save the payload under the URI's last path component in the
        # current working directory.
        with open(file_name, 'wb') as f:
            f.write(response.content)
        file_path = file_name
    elif response.status_code == 404:
        print(f'File not found: {response.status_code}', file=sys.stderr)
        sys.exit(2)
    else:
        print(f'Error downloading file: {response.status_code}', file=sys.stderr)
        sys.exit(1)

if new_checksum is None:
    # A file was downloaded: hash it, report, then delete the local copy.
    new_checksum = calculate_sha256(file_path)
    print(new_checksum)
    os.remove(file_path)
else:
    print(new_checksum)
|