diff --git a/lollms/utilities.py b/lollms/utilities.py index 85e8f66..05cc077 100644 --- a/lollms/utilities.py +++ b/lollms/utilities.py @@ -61,6 +61,38 @@ def reinstall_pytorch_with_cpu(): ASCIIColors.error("Pytorch installed successfully!!") +def check_and_install_torch(enable_gpu:bool, version:float=2.1): + if enable_gpu: + ASCIIColors.yellow("This installation has enabled GPU support. Trying to install with GPU support") + ASCIIColors.info("Checking pytorch") + try: + import torch + import torchvision + if torch.cuda.is_available(): + ASCIIColors.success(f"CUDA is supported.\nCurrent version is {torch.__version__}.") + if not check_torch_version(version): + ASCIIColors.yellow("Torch version is old. Installing new version") + reinstall_pytorch_with_cuda() + else: + ASCIIColors.yellow("Torch OK") + else: + ASCIIColors.warning("CUDA is not supported. Trying to reinstall PyTorch with CUDA support.") + reinstall_pytorch_with_cuda() + except Exception as ex: + ASCIIColors.info("Pytorch not installed. Reinstalling ...") + reinstall_pytorch_with_cuda() + else: + try: + import torch + import torchvision + if check_torch_version(version): + ASCIIColors.warning("Torch version is too old. Trying to reinstall PyTorch with CUDA support.") + reinstall_pytorch_with_cpu() + except Exception as ex: + ASCIIColors.info("Pytorch not installed. Reinstalling ...") + reinstall_pytorch_with_cpu() + + class NumpyEncoderDecoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.ndarray):