mirror of
https://github.com/ParisNeo/lollms.git
synced 2024-12-24 06:46:40 +00:00
fixed file download
This commit is contained in:
parent
f26e26765c
commit
00d593815e
@ -27,6 +27,9 @@ import inspect
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from lollms.utilities import trace_exception
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
__author__ = "parisneo"
|
||||
__github__ = "https://github.com/ParisNeo/lollms_bindings_zoo"
|
||||
__copyright__ = "Copyright 2023, "
|
||||
@ -103,6 +106,41 @@ class LLMBinding:
|
||||
self.configuration_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
binding_config.config.file_path = self.configuration_file_path
|
||||
|
||||
|
||||
def download_file(self, url, installation_path, callback=None):
|
||||
"""
|
||||
Downloads a file from a URL, reports the download progress using a callback function, and displays a progress bar.
|
||||
|
||||
Args:
|
||||
url (str): The URL of the file to download.
|
||||
installation_path (str): The path where the file should be saved.
|
||||
callback (function, optional): A callback function to be called during the download
|
||||
with the progress percentage as an argument. Defaults to None.
|
||||
"""
|
||||
try:
|
||||
response = requests.get(url, stream=True)
|
||||
|
||||
# Get the file size from the response headers
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
|
||||
with open(installation_path, 'wb') as file:
|
||||
downloaded_size = 0
|
||||
with tqdm(total=total_size, unit='B', unit_scale=True, ncols=80) as progress_bar:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
file.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
if callback is not None:
|
||||
callback(downloaded_size, total_size)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
if callback is not None:
|
||||
callback(total_size, total_size)
|
||||
|
||||
print("File downloaded successfully")
|
||||
except Exception as e:
|
||||
print("Couldn't download file:", str(e))
|
||||
|
||||
def install_model(self, model_type:str, model_path:str, variant_name:str, client_id:int=None):
|
||||
print("Install model triggered")
|
||||
model_path = model_path.replace("\\","/")
|
||||
|
Loading…
Reference in New Issue
Block a user