Update utilities.py

This commit is contained in:
Saifeddine ALOUI 2023-11-15 14:08:40 +01:00 committed by GitHub
parent c9ea6efcff
commit 1908ee3532
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -61,6 +61,38 @@ def reinstall_pytorch_with_cpu():
ASCIIColors.error("Pytorch installed successfully!!")
def check_and_install_torch(enable_gpu:bool, version:float=2.1):
if enable_gpu:
ASCIIColors.yellow("This installation has enabled GPU support. Trying to install with GPU support")
ASCIIColors.info("Checking pytorch")
try:
import torch
import torchvision
if torch.cuda.is_available():
ASCIIColors.success(f"CUDA is supported.\nCurrent version is {torch.__version__}.")
if not check_torch_version(version):
ASCIIColors.yellow("Torch version is old. Installing new version")
reinstall_pytorch_with_cuda()
else:
ASCIIColors.yellow("Torch OK")
else:
ASCIIColors.warning("CUDA is not supported. Trying to reinstall PyTorch with CUDA support.")
reinstall_pytorch_with_cuda()
except Exception as ex:
ASCIIColors.info("Pytorch not installed. Reinstalling ...")
reinstall_pytorch_with_cuda()
else:
try:
import torch
import torchvision
if check_torch_version(version):
ASCIIColors.warning("Torch version is too old. Trying to reinstall PyTorch with CUDA support.")
reinstall_pytorch_with_cpu()
except Exception as ex:
ASCIIColors.info("Pytorch not installed. Reinstalling ...")
reinstall_pytorch_with_cpu()
class NumpyEncoderDecoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):