1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-29 05:02:09 +03:00

zluda device id #2

This commit is contained in:
Seunghoon Lee
2024-09-26 13:01:55 +09:00
parent 6af1524be8
commit bdcef261c5

View File

@@ -479,7 +479,7 @@ def install_rocm_zluda():
amd_gpus = rocm.get_agents()
if len(amd_gpus) == 0:
(log.info if sys.platform == "win32" else log.warning)('ROCm: no agent was found')
else:
elif args.device_id is None:
log.info(f'ROCm: agents={[gpu.name for gpu in amd_gpus]}')
hip_default_device = amd_gpus[0]
for idx, gpu in enumerate(amd_gpus):
@@ -494,22 +494,20 @@ def install_rocm_zluda():
except Exception as e:
log.warning(f'ROCm agent enumerator failed: {e}')
if args.device_id is not None:
if os.environ.get('HIP_VISIBLE_DEVICES', None) is not None:
log.warning('Setting HIP_VISIBLE_DEVICES and --device-id at the same time may be mistake.')
device_id = int(args.device_id)
if device_id < len(amd_gpus):
hip_default_device = amd_gpus[device_id]
os.environ['HIP_VISIBLE_DEVICES'] = args.device_id
del args.device_id
if hip_default_device is not None:
os.environ.setdefault('HSA_OVERRIDE_GFX_VERSION', hip_default_device.get_gfx_version())
log.info(f'ROCm: version={rocm.version}')
torch_command = ''
if sys.platform == "win32":
# TODO after ROCm for Windows is released
if args.device_id is not None:
if os.environ.get('HIP_VISIBLE_DEVICES', None) is not None:
log.warning('Setting HIP_VISIBLE_DEVICES and --device-id at the same time may be mistake.')
device_id = int(args.device_id)
if device_id < len(amd_gpus):
hip_default_device = amd_gpus[device_id]
os.environ['HIP_VISIBLE_DEVICES'] = args.device_id
del args.device_id
log.warning("ZLUDA support: experimental")
error = None
from modules import zluda_installer
@@ -559,6 +557,10 @@ def install_rocm_zluda():
if hip_default_device is not None and rocm.version != "6.2" and rocm.version == rocm.version_torch and rocm.get_blaslt_enabled():
log.debug(f'ROCm hipBLASLt: arch={hip_default_device.name} available={hip_default_device.blaslt_supported}')
rocm.set_blaslt_enabled(hip_default_device.blaslt_supported)
if hip_default_device is not None:
os.environ.setdefault('HSA_OVERRIDE_GFX_VERSION', hip_default_device.get_gfx_version())
return torch_command