fixed file download

This commit is contained in:
Saifeddine ALOUI 2024-01-21 19:10:48 +01:00
parent f26e26765c
commit 00d593815e

View File

@ -27,6 +27,9 @@ import inspect
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from lollms.utilities import trace_exception from lollms.utilities import trace_exception
from tqdm import tqdm
__author__ = "parisneo" __author__ = "parisneo"
__github__ = "https://github.com/ParisNeo/lollms_bindings_zoo" __github__ = "https://github.com/ParisNeo/lollms_bindings_zoo"
__copyright__ = "Copyright 2023, " __copyright__ = "Copyright 2023, "
@ -103,6 +106,41 @@ class LLMBinding:
self.configuration_file_path.parent.mkdir(parents=True, exist_ok=True) self.configuration_file_path.parent.mkdir(parents=True, exist_ok=True)
binding_config.config.file_path = self.configuration_file_path binding_config.config.file_path = self.configuration_file_path
def download_file(self, url, installation_path, callback=None):
"""
Downloads a file from a URL, reports the download progress using a callback function, and displays a progress bar.
Args:
url (str): The URL of the file to download.
installation_path (str): The path where the file should be saved.
callback (function, optional): A callback function to be called during the download
with the progress percentage as an argument. Defaults to None.
"""
try:
response = requests.get(url, stream=True)
# Get the file size from the response headers
total_size = int(response.headers.get('content-length', 0))
with open(installation_path, 'wb') as file:
downloaded_size = 0
with tqdm(total=total_size, unit='B', unit_scale=True, ncols=80) as progress_bar:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
file.write(chunk)
downloaded_size += len(chunk)
if callback is not None:
callback(downloaded_size, total_size)
progress_bar.update(len(chunk))
if callback is not None:
callback(total_size, total_size)
print("File downloaded successfully")
except Exception as e:
print("Couldn't download file:", str(e))
def install_model(self, model_type:str, model_path:str, variant_name:str, client_id:int=None): def install_model(self, model_type:str, model_path:str, variant_name:str, client_id:int=None):
print("Install model triggered") print("Install model triggered")
model_path = model_path.replace("\\","/") model_path = model_path.replace("\\","/")