diff --git a/app.py b/app.py index ade55c97..7f2f56a9 100644 --- a/app.py +++ b/app.py @@ -27,7 +27,7 @@ import traceback import webbrowser from pathlib import Path from lollms.utilities import AdvancedGarbageCollector - +from lollms.utilities import reinstall_pytorch_with_cuda def run_update_script(args=None): update_script = Path(__file__).parent/"update_script.py" @@ -1019,28 +1019,14 @@ try: def upgrade_to_gpu(self): ASCIIColors.yellow("Received command to upgrade to GPU") ASCIIColors.info("Installing cuda toolkit") - res = subprocess.check_call(["conda", "install", "-c", "nvidia/label/cuda-11.7.0", "-c", "nvidia", "-c", "conda-forge", "cuda-toolkit", "ninja", "git", "--force-reinstall", "-y"]) - if res!=0: - ASCIIColors.red("Couldn't install cuda toolkit") - return jsonify({'status':False, "error": "Couldn't install cuda toolkit. Make sure you are running from conda environment"}) - ASCIIColors.green("Cuda toolkit installed successfully") ASCIIColors.yellow("Removing pytorch") try: res = subprocess.check_call(["pip","uninstall","torch", "torchvision", "torchaudio", "-y"]) except : pass - ASCIIColors.green("PyTorch unstalled successfully") + ASCIIColors.green("PyTorch uninstalled successfully") + reinstall_pytorch_with_cuda() ASCIIColors.yellow("Installing pytorch with cuda support") - res = subprocess.check_call(["pip","install","--upgrade","torch==2.0.1+cu117", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cu117","--no-cache"]) - if res==0: - ASCIIColors.green("PyTorch installed successfully") - import torch - if torch.cuda.is_available(): - ASCIIColors.success("CUDA is supported.") - else: - ASCIIColors.warning("CUDA is not supported. This may mean that the upgrade didn't succeed. Try rebooting the application") - else: - ASCIIColors.green("An error hapened") self.config.enable_gpu=True return jsonify({'status':res==0})