1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/zluda_installer.py
vladmandic 8cf27dffdb allow unsafe ssl context for downloads
Signed-off-by: vladmandic <mandic00@live.com>
2026-01-19 17:36:01 +01:00

189 lines
6.5 KiB
Python

import os
import sys
import ssl
import site
import ctypes
import shutil
import zipfile
import urllib.request
from typing import Union
from installer import args, log
from modules import rocm
# DLLs shipped in the ZLUDA directory (keys) and the CUDA-versioned file names
# load() aliases them to (values) before preloading them.
# NOTE(review): presumably these are the names torch's CUDA build expects — confirm.
DLL_MAPPING = {
    'cublas.dll': 'cublas64_11.dll',
    'cusparse.dll': 'cusparse64_11.dll',
    'cufft.dll': 'cufft64_10.dll',
    'cufftw.dll': 'cufftw64_10.dll',
    'nvrtc.dll': 'nvrtc64_112_0.dll',
}
# HIP SDK libraries that load() preloads from the ROCm "bin" directory.
HIPSDK_TARGETS = [
    'rocblas.dll',
    'rocsolver.dll',
    'rocsparse.dll',
    'hipfft.dll',
]

# ZLUDA installation directory: $ZLUDA when set, otherwise ".zluda" (made absolute).
path = os.path.abspath(os.environ.get('ZLUDA', '.zluda'))
# GPU agent set via set_default_agent(); load() checks its `blaslt_supported`.
default_agent: Union[rocm.Agent, None] = None

# Optional-library feature flags; both are recomputed by load().
MIOpen_enabled = False
hipBLASLt_enabled = False
class ZLUDAResult(ctypes.Structure):
    """ctypes layout of the result returned by ZLUDA's ``zluda_get_hip_object``.

    ``return_code`` is a C int status and ``value`` an unsigned 64-bit payload
    (the HIP handle when the call succeeds).
    """

    _fields_ = [('return_code', ctypes.c_int), ('value', ctypes.c_ulonglong)]
class ZLUDALibrary:
    """Base wrapper that simply keeps a reference to a loaded ZLUDA DLL."""

    internal: ctypes.CDLL  # handle produced by ctypes LoadLibrary

    def __init__(self, internal: ctypes.CDLL):
        # Store the raw handle; subclasses call through it.
        self.internal = internal
class Core(ZLUDALibrary):
    """Wrapper over the ZLUDA core DLL (load() builds it from ``nvcuda.dll``).

    Declares ctypes signatures for the ZLUDA-specific exports it uses and
    tolerates DLLs that lack ``zluda_get_nightly_flag``.
    """

    def __init__(self, internal: ctypes.CDLL):
        hip_lookup = internal.zluda_get_hip_object
        hip_lookup.restype = ZLUDAResult
        hip_lookup.argtypes = [ctypes.c_void_p, ctypes.c_int]
        try:
            nightly_query = internal.zluda_get_nightly_flag
            nightly_query.restype = ctypes.c_int
            nightly_query.argtypes = []
        except AttributeError:
            # DLL does not export the symbol: substitute a stub reporting 0.
            internal.zluda_get_nightly_flag = lambda: 0
        super().__init__(internal)

    def to_hip_stream(self, zluda_object: ctypes.c_void_p):
        """Return the HIP handle backing ``zluda_object`` (``return_code`` is not checked)."""
        # NOTE(review): the literal 1 presumably selects the "stream" object kind — confirm.
        outcome = self.internal.zluda_get_hip_object(zluda_object, 1)
        return outcome.value

    def get_nightly_flag(self) -> int:
        """Return the DLL's nightly flag (0 when the DLL does not export it)."""
        return self.internal.zluda_get_nightly_flag()
# Library handles filled in by load(): `core` wraps nvcuda.dll, `ml` wraps nvml.dll.
core = ml = None
def set_default_agent(agent: rocm.Agent):
    """Record ``agent`` as the module-wide default GPU agent.

    load() consults ``default_agent.blaslt_supported`` when deciding whether
    to enable hipBLASLt.
    """
    global default_agent # pylint: disable=global-statement
    default_agent = agent
def is_reinstall_needed() -> bool: # ZLUDA<3.9.4
    """Report whether ``path`` holds a legacy (pre-3.9.4) ZLUDA install.

    Such installs are recognised by the presence of ``cudart.dll``.
    """
    legacy_runtime = os.path.join(path, 'cudart.dll')
    return os.path.exists(legacy_runtime)
def install():
    """Download the ZLUDA release archive and unpack it into ``path``.

    Does nothing when ``path`` already exists. The archive is selected by:
      * ``ZLUDA_HASH`` env var (defaults to a pinned release commit),
      * nightly mode (``--use-nightly``, or the deprecated ``ZLUDA_NIGHTLY=1``,
        which also sets ``args.use_nightly``),
      * the first character of ``rocm.version``.
    Archive members are extracted flat (directory components are dropped).

    Raises:
        RuntimeError: when the download or extraction fails. A partially
            populated ``path`` is removed first, so a later call retries the
            install instead of returning early on the existence check.
    """
    if os.path.exists(path):
        return
    platform = "windows"
    commit = os.environ.get("ZLUDA_HASH", "5e717459179dc272b7d7d23391f0fad66c7459cf")
    if os.environ.get("ZLUDA_NIGHTLY", "0") == "1":
        log.warning("Environment variable 'ZLUDA_NIGHTLY' will be removed. Please use command-line argument '--use-nightly' instead.")
        args.use_nightly = True
    if args.use_nightly:
        platform = "nightly-" + platform
    log.debug(f'Install ZLUDA: rocm={rocm.version} platform={platform} commit={commit}')
    url = f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{commit}/ZLUDA-{platform}-rocm{rocm.version[0]}-amd64.zip'
    archive_path = '_zluda'
    # Certificate verification is intentionally skipped for this download, but the
    # unverified context is passed to this single request only. Overwriting
    # ssl._create_default_https_context would disable verification for every
    # later HTTPS call made anywhere in the process.
    unverified = ssl._create_unverified_context() # pylint: disable=protected-access
    try:
        with urllib.request.urlopen(url, context=unverified) as response, open(archive_path, 'wb') as archive_file:
            shutil.copyfileobj(response, archive_file)
        # A truncated download is rejected here by zipfile (BadZipFile).
        with zipfile.ZipFile(archive_path, 'r') as archive:
            for info in archive.infolist():
                if info.is_dir():
                    continue
                # Flatten: every file lands directly in `path`.
                info.filename = os.path.basename(info.filename)
                archive.extract(info, path)
    except Exception as e:
        # `path` did not exist on entry, so anything there now is our partial output.
        if os.path.isdir(path):
            shutil.rmtree(path, ignore_errors=True)
        raise RuntimeError(f'Install ZLUDA: {e}') from e
    finally:
        if os.path.exists(archive_path):
            os.remove(archive_path)
def uninstall():
    """Remove the ZLUDA installation directory; a missing directory is a no-op."""
    if not os.path.exists(path):
        return
    shutil.rmtree(path)
def set_blaslt_enabled(enabled: bool):
    """Overwrite the module-wide hipBLASLt flag with ``enabled``.

    Note that load() recomputes the flag from scratch when it runs.
    """
    global hipBLASLt_enabled # pylint: disable=global-statement
    hipBLASLt_enabled = enabled
def get_blaslt_enabled() -> bool:
    """Return the current module-wide hipBLASLt flag, unmodified."""
    flag = hipBLASLt_enabled
    return flag
def link_or_copy(src: os.PathLike, dst: os.PathLike):
    """Make ``dst`` provide the contents of ``src`` as cheaply as possible.

    Strategies are tried in order: symbolic link, hard link, full copy.
    Any failure of the two link attempts is deliberately ignored (presumably
    because link creation may be refused, e.g. for lack of privileges); an
    error from the final copy propagates to the caller.
    """
    for make_link in (os.symlink, os.link):
        try:
            make_link(src, dst)
        except Exception: # pylint: disable=broad-except
            continue
        return
    shutil.copyfile(src, dst)
def load():
    """Load ZLUDA and its HIP SDK dependencies into the current process.

    In order, this function:
      1. Loads ZLUDA's ``nvcuda.dll`` / ``nvml.dll`` and publishes them as the
         module-level ``core`` / ``ml`` handles.
      2. Decides whether hipBLASLt and MIOpen are used. Both need a nightly ZLUDA
         build and the matching DLL in the ROCm ``bin`` directory; hipBLASLt also
         needs ``rocm.blaslt_tensile_libpath`` to exist and a ``default_agent``
         whose ``blaslt_supported`` is truthy.
      3. Creates the CUDA-versioned DLL names from ``DLL_MAPPING`` (plus
         ``cublasLt64_11.dll`` / ``cudnn64_9.dll`` when enabled) inside ``path``.
      4. Sets ZLUDA-related environment variables and preloads the DLLs.
      5. Installs ZLUDA-specific ``rocm.postinstall`` and ``rocm.rocm_init`` hooks.

    Requires ``rocm.environment`` to be a ``rocm.ROCmEnvironment``.
    """
    assert isinstance(rocm.environment, rocm.ROCmEnvironment)
    global core, ml, hipBLASLt_enabled, MIOpen_enabled # pylint: disable=global-statement
    rocm_bin = os.path.join(rocm.environment.path, 'bin')

    core = Core(ctypes.windll.LoadLibrary(os.path.join(path, 'nvcuda.dll')))
    ml = ZLUDALibrary(ctypes.windll.LoadLibrary(os.path.join(path, 'nvml.dll')))

    # Optional libraries are only wired up on nightly ZLUDA builds.
    is_nightly = core.get_nightly_flag() == 1
    hipBLASLt_enabled = (
        is_nightly
        and os.path.exists(rocm.blaslt_tensile_libpath)
        and os.path.exists(os.path.join(rocm_bin, 'hipblaslt.dll'))
        and default_agent is not None
        and default_agent.blaslt_supported
    )
    MIOpen_enabled = is_nightly and os.path.exists(os.path.join(rocm_bin, 'MIOpen.dll'))

    # Expose ZLUDA's DLLs under the versioned names; existing files are kept.
    for k, v in DLL_MAPPING.items():
        if not os.path.exists(os.path.join(path, v)):
            link_or_copy(os.path.join(path, k), os.path.join(path, v))
    if hipBLASLt_enabled and not os.path.exists(os.path.join(path, 'cublasLt64_11.dll')):
        link_or_copy(os.path.join(path, 'cublasLt.dll'), os.path.join(path, 'cublasLt64_11.dll'))
    if MIOpen_enabled and not os.path.exists(os.path.join(path, 'cudnn64_9.dll')):
        link_or_copy(os.path.join(path, 'cudnn.dll'), os.path.join(path, 'cudnn64_9.dll'))

    log.info(f"ZLUDA load: path='{path}' nightly={bool(core.get_nightly_flag())}")

    os.environ["ZLUDA_COMGR_LOG_LEVEL"] = "1"
    # Point ZLUDA at the nvrtc alias inside the installed torch package
    # (first site-packages entry reported by `site`).
    os.environ["ZLUDA_NVRTC_LIB"] = os.path.join([v for v in site.getsitepackages() if v.endswith("site-packages")][0], "torch", "lib", "nvrtc64_112_0.dll")

    # Preload HIP SDK libraries first, then the aliased ZLUDA DLLs.
    for v in HIPSDK_TARGETS:
        ctypes.windll.LoadLibrary(os.path.join(rocm_bin, v))
    for v in DLL_MAPPING.values():
        ctypes.windll.LoadLibrary(os.path.join(path, v))

    if hipBLASLt_enabled:
        # setdefault: a user-provided value is respected.
        os.environ.setdefault("DISABLE_ADDMM_CUDA_LT", "0")
        ctypes.windll.LoadLibrary(os.path.join(rocm_bin, 'hipblaslt.dll'))
        ctypes.windll.LoadLibrary(os.path.join(path, 'cublasLt64_11.dll'))
    else:
        os.environ["DISABLE_ADDMM_CUDA_LT"] = "1"
    if MIOpen_enabled:
        ctypes.windll.LoadLibrary(os.path.join(rocm_bin, 'MIOpen.dll'))
        ctypes.windll.LoadLibrary(os.path.join(path, 'cudnn64_9.dll'))

    def postinstall():
        # Registered as rocm.postinstall below; it imports torch, so it may only
        # run once torch is installed.
        import torch
        # NOTE(review): presumably makes torch report a HIP build — confirm.
        torch.version.hip = rocm.version

        # Import cpp_extension with sys.platform blanked (presumably to keep its
        # import-time Windows-specific setup from running); the IS_* flags are
        # set explicitly afterwards.
        saved_platform = sys.platform
        sys.platform = ""
        try:
            from torch.utils import cpp_extension
        finally:
            # Always restore: if the import raised, sys.platform would otherwise
            # remain "" for the rest of the process.
            sys.platform = saved_platform
        cpp_extension.IS_WINDOWS = saved_platform == "win32"
        cpp_extension.IS_MACOS = False
        cpp_extension.IS_LINUX = saved_platform.startswith('linux')

        def _join_rocm_home(*paths) -> str:
            # Resolve ROCm-relative paths directly under cpp_extension.ROCM_HOME.
            return os.path.join(cpp_extension.ROCM_HOME, *paths)
        cpp_extension._join_rocm_home = _join_rocm_home # pylint: disable=protected-access
    rocm.postinstall = postinstall

    from modules.zluda import zluda_init
    rocm.rocm_init = zluda_init