mirror of
https://github.com/ParisNeo/lollms.git
synced 2025-01-05 12:24:21 +00:00
Upgraded
This commit is contained in:
parent
e09f42bb89
commit
82e10f2eca
@ -6,7 +6,7 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from lollms.app import LollmsApplication
|
||||
from lollms.utilities import PackageManager, check_and_install_torch, find_next_available_filename, install_cuda, check_torch_version
|
||||
from lollms.utilities import PackageManager, check_and_install_torch, find_next_available_filename
|
||||
|
||||
import sys
|
||||
import requests
|
||||
|
@ -29,12 +29,13 @@ from typing import List, Dict, Any
|
||||
from ascii_colors import ASCIIColors, trace_exception
|
||||
from lollms.paths import LollmsPaths
|
||||
from lollms.tti import LollmsTTI
|
||||
from lollms.utilities import git_pull, show_yes_no_dialog, run_script_in_env, create_env
|
||||
from lollms.utilities import git_pull, show_yes_no_dialog, EnvironmentManager
|
||||
import subprocess
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
import threading
|
||||
|
||||
em = EnvironmentManager()
|
||||
|
||||
|
||||
def download_file(url, folder_path, local_filename):
|
||||
@ -64,7 +65,7 @@ def install_sd(lollms_app:LollmsApplication):
|
||||
shared_folder = root_dir/"shared"
|
||||
sd_folder = shared_folder / "auto_sd"
|
||||
ASCIIColors.cyan("Installing autosd conda environment with python 3.10")
|
||||
create_env("autosd","3.10")
|
||||
em.create_env("autosd","3.10")
|
||||
ASCIIColors.cyan("Done")
|
||||
if os.path.exists(str(sd_folder)):
|
||||
print("Repository already exists. Pulling latest changes...")
|
||||
|
@ -35,106 +35,253 @@ import sys
|
||||
import git
|
||||
|
||||
import mimetypes
|
||||
import sys
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
import pkg_resources
|
||||
from functools import partial
|
||||
|
||||
import pipmaster as pm
|
||||
if not pm.is_installed("Pillow"):
|
||||
pm.install("Pillow")
|
||||
from PIL import Image
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import platform
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
def create_env(env_name, python_version):
|
||||
class EnvManager(Enum):
|
||||
CONDA = 'conda'
|
||||
VENV = 'venv'
|
||||
PYENV = 'pyenv'
|
||||
PIP = 'pip'
|
||||
|
||||
class EnvironmentManager:
|
||||
def __init__(self, preferred_manager=None):
|
||||
"""
|
||||
Initialize environment manager with optional preferred manager.
|
||||
Args:
|
||||
preferred_manager (str, optional): 'conda', 'venv', 'pyenv', or 'pip'
|
||||
"""
|
||||
self.preferred_manager = EnvManager(preferred_manager) if preferred_manager else None
|
||||
self.manager = self._detect_env_manager()
|
||||
self.manager_path = self._get_env_manager_path()
|
||||
|
||||
def _get_env_manager_path(self):
|
||||
"""Get the path of the environment manager executable"""
|
||||
if platform.system() == 'Windows':
|
||||
ext = '.exe'
|
||||
else:
|
||||
ext = ''
|
||||
|
||||
if self.manager == EnvManager.CONDA:
|
||||
# Check for portable conda
|
||||
portable_conda = os.getenv('PORTABLE_CONDA_PATH')
|
||||
if portable_conda:
|
||||
return os.path.join(portable_conda, 'condabin', f'conda{ext}')
|
||||
|
||||
# Check standard conda locations
|
||||
possible_paths = [
|
||||
os.path.join(sys.prefix, 'condabin', f'conda{ext}'),
|
||||
os.path.join(os.path.expanduser('~'), 'miniconda3', 'condabin', f'conda{ext}'),
|
||||
os.path.join(os.path.expanduser('~'), 'anaconda3', 'condabin', f'conda{ext}')
|
||||
]
|
||||
|
||||
for path in possible_paths:
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
|
||||
# Use which/where to find the executable
|
||||
try:
|
||||
if platform.system() == 'Windows':
|
||||
result = subprocess.run(['where', self.manager.value], capture_output=True, text=True)
|
||||
else:
|
||||
result = subprocess.run(['which', self.manager.value], capture_output=True, text=True)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip().split('\n')[0]
|
||||
except:
|
||||
pass
|
||||
|
||||
def run_pip_in_env(env_name, pip_args, cwd=None):
|
||||
import platform
|
||||
# Set the current working directory if provided, otherwise use the current directory
|
||||
if cwd is None:
|
||||
cwd = os.getcwd()
|
||||
return None
|
||||
|
||||
# Activate the Conda environment
|
||||
python_path = Path(sys.executable).parent.parent/"miniconda3"/"envs"/env_name/"python"
|
||||
ASCIIColors.yellow(f"Executing: {python_path} -m pip {pip_args}")
|
||||
process = subprocess.Popen(f'{python_path} -m pip {pip_args}', shell=True)
|
||||
def _detect_env_manager(self):
|
||||
"""Detect which environment manager to use based on context and preference"""
|
||||
# If preferred manager is specified and available, use it
|
||||
if self.preferred_manager:
|
||||
if self._check_manager_available(self.preferred_manager):
|
||||
return self.preferred_manager
|
||||
|
||||
# Wait for the process to finish
|
||||
process.wait()
|
||||
# Check if we're in a venv
|
||||
if sys.prefix != sys.base_prefix:
|
||||
return EnvManager.VENV
|
||||
|
||||
# Check for conda (both portable and installed)
|
||||
if self._check_manager_available(EnvManager.CONDA):
|
||||
return EnvManager.CONDA
|
||||
|
||||
def run_python_script_in_env(env_name, script_path, cwd=None, wait=True):
|
||||
import platform
|
||||
# Set the current working directory if provided, otherwise use the current directory
|
||||
if cwd is None:
|
||||
cwd = os.getcwd()
|
||||
# Check for pyenv
|
||||
if self._check_manager_available(EnvManager.PYENV):
|
||||
return EnvManager.PYENV
|
||||
|
||||
# Activate the Conda environment
|
||||
python_path = Path(sys.executable).parent.parent/"miniconda3"/"envs"/env_name/"python"
|
||||
ASCIIColors.yellow(f"Executing: {python_path} {script_path}")
|
||||
process = subprocess.Popen(f'{python_path} {script_path}', shell=True)
|
||||
# Default to pip
|
||||
return EnvManager.PIP
|
||||
|
||||
def _check_manager_available(self, manager):
|
||||
"""Check if a specific environment manager is available"""
|
||||
try:
|
||||
if platform.system() == 'Windows':
|
||||
result = subprocess.run(['where', manager.value], capture_output=True)
|
||||
else:
|
||||
result = subprocess.run(['which', manager.value], capture_output=True)
|
||||
return result.returncode == 0
|
||||
except:
|
||||
return False
|
||||
|
||||
def _get_env_python_path(self, env_name):
|
||||
"""Get the Python executable path for the environment"""
|
||||
if platform.system() == 'Windows':
|
||||
return os.path.join(env_name, 'Scripts', 'python.exe')
|
||||
return os.path.join(env_name, 'bin', 'python')
|
||||
|
||||
def create_env(self, env_name, python_version):
|
||||
"""
|
||||
Create a new environment with specified Python version
|
||||
Args:
|
||||
env_name (str): Name of the environment
|
||||
python_version (str): Python version to install (e.g., '3.8')
|
||||
"""
|
||||
try:
|
||||
if self.manager == EnvManager.CONDA:
|
||||
subprocess.run([self.manager_path, 'create', '-n', env_name,
|
||||
f'python={python_version}', '-y'], check=True)
|
||||
elif self.manager == EnvManager.VENV:
|
||||
subprocess.run([f'python{python_version}', '-m', 'venv', env_name], check=True)
|
||||
elif self.manager == EnvManager.PYENV:
|
||||
subprocess.run([self.manager_path, 'install', python_version], check=True)
|
||||
subprocess.run([self.manager_path, 'virtualenv', python_version, env_name], check=True)
|
||||
else:
|
||||
subprocess.run([sys.executable, '-m', 'venv', env_name], check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to create environment: {str(e)}")
|
||||
|
||||
def run_pip_in_env(self, env_name, pip_args, cwd=None):
|
||||
"""
|
||||
Run pip commands in the specified environment
|
||||
Args:
|
||||
env_name (str): Name of the environment
|
||||
pip_args (str): Arguments to pass to pip
|
||||
cwd (str, optional): Working directory
|
||||
"""
|
||||
try:
|
||||
if self.manager == EnvManager.CONDA:
|
||||
cmd = f'"{self.manager_path}" run -n {env_name} pip {pip_args}'
|
||||
else:
|
||||
if platform.system() == 'Windows':
|
||||
pip_path = os.path.join(env_name, 'Scripts', 'pip')
|
||||
else:
|
||||
pip_path = os.path.join(env_name, 'bin', 'pip')
|
||||
cmd = f'"{pip_path}" {pip_args}'
|
||||
|
||||
subprocess.run(cmd, shell=True, cwd=cwd, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to run pip: {str(e)}")
|
||||
|
||||
def run_python_script_in_env(self, env_name, script_path, cwd=None, wait=True):
|
||||
"""
|
||||
Run a Python script in the specified environment
|
||||
Args:
|
||||
env_name (str): Name of the environment
|
||||
script_path (str): Path to the Python script
|
||||
cwd (str, optional): Working directory
|
||||
wait (bool): Whether to wait for the script to complete
|
||||
"""
|
||||
try:
|
||||
if self.manager == EnvManager.CONDA:
|
||||
cmd = f'"{self.manager_path}" run -n {env_name} python "{script_path}"'
|
||||
else:
|
||||
python_path = self._get_env_python_path(env_name)
|
||||
cmd = f'"{python_path}" "{script_path}"'
|
||||
|
||||
# Wait for the process to finish
|
||||
if wait:
|
||||
process.wait()
|
||||
return process
|
||||
|
||||
def run_script_in_env(env_name, script_path, cwd=None):
|
||||
import platform
|
||||
# Set the current working directory if provided, otherwise use the current directory
|
||||
if cwd is None:
|
||||
cwd = os.path.dirname(script_path)
|
||||
|
||||
# Activate the Conda environment
|
||||
if platform.system()=="Windows":
|
||||
python_path = Path(sys.executable).parent.parent/"miniconda3"/"condabin"/"conda"
|
||||
subprocess.Popen(f'{python_path} activate {env_name} && {script_path}', shell=True, cwd=cwd)
|
||||
subprocess.run(cmd, shell=True, cwd=cwd, check=True)
|
||||
else:
|
||||
python_path = Path(sys.executable).parent.parent.parent/"miniconda3"/"bin"/"conda"
|
||||
subprocess.Popen(f'source {python_path} activate {env_name} && {script_path}', shell=True, cwd=cwd)
|
||||
# Activate the Conda environment
|
||||
ASCIIColors.red("Python path:")
|
||||
ASCIIColors.yellow(python_path)
|
||||
#run_command(Commands.RUN, "-n", env_name, str(script_path), cwd=cwd)
|
||||
subprocess.Popen(cmd, shell=True, cwd=cwd)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to run script: {str(e)}")
|
||||
|
||||
def get_conda_path():
|
||||
import platform
|
||||
if platform.system() == "Windows":
|
||||
return Path(sys.executable).parent.parent / "miniconda3" / "condabin" / "conda"
|
||||
def run_script_in_env(self, env_name, script_path, cwd=None):
|
||||
"""
|
||||
Run any script in the specified environment
|
||||
Args:
|
||||
env_name (str): Name of the environment
|
||||
script_path (str): Path to the script
|
||||
cwd (str, optional): Working directory
|
||||
"""
|
||||
try:
|
||||
if self.manager == EnvManager.CONDA:
|
||||
cmd = f'"{self.manager_path}" run -n {env_name} "{script_path}"'
|
||||
else:
|
||||
return Path(sys.executable).parent.parent.parent / "miniconda3" / "bin" / "conda"
|
||||
cmd = os.path.join(env_name, 'bin', script_path)
|
||||
|
||||
def environment_exists(env_name):
|
||||
from lollms.security import sanitize_shell_code
|
||||
env_name = sanitize_shell_code(env_name)
|
||||
env_name = Path(sys.executable).parent.parent / "miniconda3" / "env" / env_name
|
||||
conda_path = get_conda_path()
|
||||
ASCIIColors.yellow(f"Using conda from : {conda_path}")
|
||||
subprocess.run(cmd, shell=True, cwd=cwd, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to run script: {str(e)}")
|
||||
|
||||
result = subprocess.run(f'{conda_path} env list --json', shell=True, capture_output=True, text=True)
|
||||
envs_info = json.loads(result.stdout)
|
||||
env_names = [Path(env).name for env in envs_info['envs']]
|
||||
return env_name in env_names
|
||||
def environment_exists(self, env_name):
|
||||
"""
|
||||
Check if the specified environment exists
|
||||
Args:
|
||||
env_name (str): Name of the environment
|
||||
Returns:
|
||||
bool: True if environment exists, False otherwise
|
||||
"""
|
||||
if self.manager == EnvManager.CONDA:
|
||||
result = subprocess.run([self.manager_path, 'env', 'list'],
|
||||
capture_output=True, text=True)
|
||||
return env_name in result.stdout
|
||||
else:
|
||||
return os.path.exists(env_name) and os.path.isdir(env_name)
|
||||
|
||||
def get_python_version(env_name):
|
||||
from lollms.security import sanitize_shell_code
|
||||
env_name = sanitize_shell_code(env_name)
|
||||
conda_path = get_conda_path()
|
||||
if environment_exists(env_name):
|
||||
result = subprocess.run(f'{conda_path} run -n {env_name} python --version', shell=True, capture_output=True, text=True)
|
||||
def get_python_version(self, env_name):
|
||||
"""
|
||||
Get Python version of the specified environment
|
||||
Args:
|
||||
env_name (str): Name of the environment
|
||||
Returns:
|
||||
str: Python version
|
||||
"""
|
||||
try:
|
||||
if self.manager == EnvManager.CONDA:
|
||||
cmd = f'"{self.manager_path}" run -n {env_name} python --version'
|
||||
else:
|
||||
python_path = self._get_env_python_path(env_name)
|
||||
cmd = f'"{python_path}" --version'
|
||||
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True, check=True)
|
||||
return result.stdout.strip()
|
||||
else:
|
||||
return "Environment does not exist."
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to get Python version: {str(e)}")
|
||||
|
||||
def remove_environment(env_name):
|
||||
from lollms.security import sanitize_shell_code
|
||||
env_name = sanitize_shell_code(env_name)
|
||||
conda_path = get_conda_path()
|
||||
if environment_exists(env_name):
|
||||
process = subprocess.Popen(f'{conda_path} env remove --name {env_name} -y', shell=True)
|
||||
process.wait()
|
||||
return f"Environment '{env_name}' has been removed."
|
||||
def remove_environment(self, env_name):
|
||||
"""
|
||||
Remove the specified environment
|
||||
Args:
|
||||
env_name (str): Name of the environment
|
||||
"""
|
||||
try:
|
||||
if self.manager == EnvManager.CONDA:
|
||||
subprocess.run([self.manager_path, 'env', 'remove', '-n', env_name, '-y'],
|
||||
check=True)
|
||||
elif self.manager == EnvManager.PYENV:
|
||||
subprocess.run([self.manager_path, 'virtualenv-delete', env_name], check=True)
|
||||
else:
|
||||
return "Environment does not exist."
|
||||
if os.path.exists(env_name):
|
||||
shutil.rmtree(env_name)
|
||||
except (subprocess.CalledProcessError, OSError) as e:
|
||||
raise RuntimeError(f"Failed to remove environment: {str(e)}")
|
||||
|
||||
def process_ai_output(output, images, output_folder):
|
||||
if not PackageManager.check_package_installed("cv2"):
|
||||
@ -673,131 +820,145 @@ def remove_text_from_string(string, text_to_find):
|
||||
return string
|
||||
|
||||
|
||||
# Pytorch and cuda tools
|
||||
def check_torch_version(min_version, min_cuda_versio=12):
|
||||
import torch
|
||||
import sys
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
if "+" in torch.__version__ and int(torch.__version__.split("+")[-1][2:4])<min_cuda_versio:
|
||||
def check_torch_version(min_version="2.0.0", min_cuda_version=12):
|
||||
try:
|
||||
import torch
|
||||
current_version = torch.__version__
|
||||
|
||||
if pkg_resources.parse_version(current_version) < pkg_resources.parse_version(min_version):
|
||||
print(f"PyTorch version {current_version} is lower than minimum required version {min_version}")
|
||||
return False
|
||||
|
||||
# Extract torch version from __version__ attribute with regular expression
|
||||
current_version_float = float('.'.join(torch.__version__.split(".")[:2]))
|
||||
# Check if the current version meets or exceeds the minimum required version
|
||||
return current_version_float >= min_version
|
||||
if torch.cuda.is_available():
|
||||
cuda_version = torch.version.cuda
|
||||
if int(cuda_version.split('.')[0]) < min_cuda_version:
|
||||
print(f"CUDA version {cuda_version} is lower than minimum required version {min_cuda_version}")
|
||||
return False
|
||||
print(f"PyTorch {current_version} with CUDA {cuda_version} is properly installed")
|
||||
else:
|
||||
print("CUDA is not available")
|
||||
|
||||
def install_ninja():
|
||||
import conda.cli
|
||||
try:
|
||||
ASCIIColors.info("Installing ninja") # -c nvidia/label/cuda-12.3.2 -c nvidia -c conda-forge
|
||||
result = conda.cli.main("install", "-c", "nvidia/label/cuda-12.3.2", "-c", "nvidia", "-c", "conda-forge", "ninja", "-y","--force-reinstall")
|
||||
except Exception as ex:
|
||||
ASCIIColors.error(ex)
|
||||
return True
|
||||
|
||||
def install_cuda():
|
||||
import conda.cli
|
||||
try:
|
||||
ASCIIColors.info("Installing cuda 12.3.2") # -c nvidia/label/cuda-12.3.2 -c nvidia -c conda-forge
|
||||
result = conda.cli.main("install", "-c", "nvidia/label/cuda-12.3.2", "-c", "nvidia", "-c", "conda-forge", "cuda-toolkit","-y","--force-reinstall")
|
||||
except Exception as ex:
|
||||
ASCIIColors.error(ex)
|
||||
try:
|
||||
ASCIIColors.info("Installing cuda compiler") # -c nvidia/label/cuda-12.3.2 -c nvidia -c conda-forge
|
||||
result = conda.cli.main("install", "-c", "nvidia/label/cuda-12.3.2", "-c", "nvidia", "-c", "conda-forge", "cuda-compiler", "-y","--force-reinstall")
|
||||
except Exception as ex:
|
||||
ASCIIColors.error(ex)
|
||||
|
||||
def install_cmake():
|
||||
import conda.cli
|
||||
try:
|
||||
ASCIIColors.info("Installing cmake") # -c nvidia/label/cuda-12.3.2 -c nvidia -c conda-forge
|
||||
result = conda.cli.main("install", "-c", "nvidia/label/cuda-12.3.2", "-c", "nvidia", "-c", "conda-forge", "cmake","-y","--force-reinstall")
|
||||
except Exception as ex:
|
||||
ASCIIColors.error(ex)
|
||||
except ImportError:
|
||||
print("PyTorch is not installed")
|
||||
return False
|
||||
|
||||
def reinstall_pytorch_with_cuda():
|
||||
"""
|
||||
Reinstall PyTorch with CUDA support using pip
|
||||
Platform-aware: Windows and Linux will use CUDA, Mac will use default
|
||||
"""
|
||||
try:
|
||||
import conda.cli
|
||||
ASCIIColors.info("Installing cuda 12.3.2") # -c nvidia/label/cuda-12.3.2 -c nvidia -c conda-forge
|
||||
result = conda.cli.main("install", "-c", "nvidia/label/cuda-12.3.2", "-c", "nvidia", "-c", "conda-forge", "cuda-toolkit","-y","--force-reinstall")
|
||||
except Exception as ex:
|
||||
ASCIIColors.error(ex)
|
||||
try:
|
||||
ASCIIColors.info("Installing ninja") # -c nvidia/label/cuda-12.3.2 -c nvidia -c conda-forge
|
||||
result = conda.cli.main("install", "-c", "nvidia/label/cuda-12.3.2", "-c", "nvidia", "-c", "conda-forge", "ninja", "-y","--force-reinstall")
|
||||
except Exception as ex:
|
||||
ASCIIColors.error(ex)
|
||||
try:
|
||||
ASCIIColors.info("Installing cuda compiler") # -c nvidia/label/cuda-12.3.2 -c nvidia -c conda-forge
|
||||
result = conda.cli.main("install", "-c", "nvidia/label/cuda-12.3.2", "-c", "nvidia", "-c", "conda-forge", "cuda-compiler", "-y","--force-reinstall")
|
||||
except Exception as ex:
|
||||
ASCIIColors.error(ex)
|
||||
try:
|
||||
ASCIIColors.info("Installing pytorch 2.2.1")
|
||||
result = subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "torch", "torchvision", "torchaudio", "--no-cache-dir", "--index-url", "https://download.pytorch.org/whl/cu121"])
|
||||
except Exception as ex:
|
||||
ASCIIColors.error(ex)
|
||||
if result.returncode != 0:
|
||||
ASCIIColors.warning("Couldn't find Cuda build tools on your PC. Reverting to CPU.")
|
||||
result = subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "torch", "torchvision", "torchaudio", "--no-cache-dir"])
|
||||
if result.returncode != 0:
|
||||
ASCIIColors.error("Couldn't install pytorch !!")
|
||||
else:
|
||||
ASCIIColors.error("Pytorch installed successfully!!")
|
||||
system = platform.system().lower()
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "torch", "torchvision", "torchaudio", "-y"])
|
||||
|
||||
if system in ['windows', 'linux']:
|
||||
subprocess.check_call([
|
||||
sys.executable, "-m", "pip", "install",
|
||||
"torch", "torchvision", "torchaudio",
|
||||
"--index-url", "https://download.pytorch.org/whl/cu121"
|
||||
])
|
||||
elif system == 'darwin':
|
||||
print("Note: Installing default MacOS version as CUDA is not supported on MacOS")
|
||||
subprocess.check_call([
|
||||
sys.executable, "-m", "pip", "install",
|
||||
"torch", "torchvision", "torchaudio"
|
||||
])
|
||||
print("PyTorch reinstalled with CUDA support (where applicable)")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error reinstalling PyTorch: {e}")
|
||||
|
||||
def reinstall_pytorch_with_rocm():
|
||||
result = subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "torch", "torchvision", "torchaudio", "--no-cache-dir", "--index-url", "https://download.pytorch.org/whl/rocm5.6"])
|
||||
if result.returncode != 0:
|
||||
ASCIIColors.warning("Couldn't find Cuda build tools on your PC. Reverting to CPU.")
|
||||
result = subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "torch", "torchvision", "torchaudio", "--no-cache-dir"])
|
||||
if result.returncode != 0:
|
||||
ASCIIColors.error("Couldn't install pytorch !!")
|
||||
else:
|
||||
ASCIIColors.error("Pytorch installed successfully!!")
|
||||
|
||||
|
||||
"""
|
||||
Reinstall PyTorch with ROCm support using pip
|
||||
"""
|
||||
try:
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "torch", "torchvision", "torchaudio", "-y"])
|
||||
subprocess.check_call([
|
||||
sys.executable, "-m", "pip", "install",
|
||||
"torch", "torchvision", "torchaudio",
|
||||
"--index-url", "https://download.pytorch.org/whl/rocm5.6"
|
||||
])
|
||||
print("PyTorch reinstalled with ROCm support")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error reinstalling PyTorch: {e}")
|
||||
|
||||
def reinstall_pytorch_with_cpu():
|
||||
result = subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "torch", "torchvision", "torchaudio", "--no-cache-dir"])
|
||||
if result.returncode != 0:
|
||||
ASCIIColors.warning("Couldn't find Cuda build tools on your PC. Reverting to CPU.")
|
||||
result = subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "torch", "torchvision", "torchaudio", "--no-cache-dir"])
|
||||
if result.returncode != 0:
|
||||
ASCIIColors.error("Couldn't install pytorch !!")
|
||||
else:
|
||||
ASCIIColors.error("Pytorch installed successfully!!")
|
||||
"""
|
||||
Reinstall PyTorch CPU-only version using pip
|
||||
"""
|
||||
try:
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "torch", "torchvision", "torchaudio", "-y"])
|
||||
subprocess.check_call([
|
||||
sys.executable, "-m", "pip", "install",
|
||||
"torch", "torchvision", "torchaudio"
|
||||
])
|
||||
print("PyTorch reinstalled (CPU-only version)")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error reinstalling PyTorch: {e}")
|
||||
|
||||
|
||||
def check_and_install_torch(enable_gpu:bool, version:float=2.2):
|
||||
|
||||
def check_and_install_torch(enable_gpu: bool, version: float = 2.2):
|
||||
"""
|
||||
Check and install PyTorch with specified configuration
|
||||
Args:
|
||||
enable_gpu (bool): Whether to install GPU version
|
||||
version (float): Minimum required PyTorch version
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
current_version = torch.__version__
|
||||
system = platform.system().lower()
|
||||
|
||||
# Check if current installation meets requirements
|
||||
if pkg_resources.parse_version(current_version) >= pkg_resources.parse_version(str(version)):
|
||||
if enable_gpu:
|
||||
ASCIIColors.yellow("This installation has enabled GPU support. Trying to install with GPU support")
|
||||
ASCIIColors.info("Checking pytorch")
|
||||
if system == 'darwin':
|
||||
# For Mac, MPS is the GPU solution
|
||||
print(f"Current PyTorch installation ({current_version}) is compatible with Mac GPU (MPS)")
|
||||
return True
|
||||
elif torch.cuda.is_available():
|
||||
print(f"Current PyTorch installation ({current_version}) has CUDA support")
|
||||
return True
|
||||
else:
|
||||
print("GPU version requested but CUDA not available. Reinstalling...")
|
||||
else:
|
||||
print(f"Current CPU PyTorch installation ({current_version}) meets version requirement")
|
||||
return True
|
||||
except ImportError:
|
||||
print("PyTorch not found. Installing...")
|
||||
|
||||
# Perform installation based on requirements
|
||||
if enable_gpu:
|
||||
if system == 'darwin':
|
||||
reinstall_pytorch_with_cpu() # Mac uses default installation for MPS
|
||||
else:
|
||||
reinstall_pytorch_with_cuda()
|
||||
else:
|
||||
reinstall_pytorch_with_cpu()
|
||||
|
||||
# Verify installation
|
||||
try:
|
||||
import torch
|
||||
import torchvision
|
||||
if torch.cuda.is_available():
|
||||
ASCIIColors.success(f"CUDA is supported.\nCurrent version is {torch.__version__}.")
|
||||
if not check_torch_version(version):
|
||||
ASCIIColors.yellow("Torch version is old. Installing new version")
|
||||
reinstall_pytorch_with_cuda()
|
||||
print(f"PyTorch {torch.__version__} installed successfully")
|
||||
if enable_gpu:
|
||||
if system == 'darwin':
|
||||
print("MPS (Mac GPU) support available if hardware supports it")
|
||||
elif torch.cuda.is_available():
|
||||
print(f"CUDA version: {torch.version.cuda}")
|
||||
else:
|
||||
ASCIIColors.yellow("Torch OK")
|
||||
else:
|
||||
ASCIIColors.warning("CUDA is not supported. Trying to reinstall PyTorch with CUDA support.")
|
||||
reinstall_pytorch_with_cuda()
|
||||
except Exception as ex:
|
||||
ASCIIColors.info("Pytorch not installed. Reinstalling ...")
|
||||
reinstall_pytorch_with_cuda()
|
||||
else:
|
||||
try:
|
||||
import torch
|
||||
import torchvision
|
||||
if check_torch_version(version):
|
||||
ASCIIColors.warning("Torch version is too old. Trying to reinstall PyTorch with CUDA support.")
|
||||
reinstall_pytorch_with_cpu()
|
||||
except Exception as ex:
|
||||
ASCIIColors.info("Pytorch not installed. Reinstalling ...")
|
||||
reinstall_pytorch_with_cpu()
|
||||
print("Warning: GPU version requested but CUDA not available")
|
||||
except ImportError:
|
||||
print("Installation failed")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class NumpyEncoderDecoder(json.JSONEncoder):
|
||||
|
Loading…
Reference in New Issue
Block a user