1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-29 05:02:09 +03:00
Files
sdnext/modules/devices.py

663 lines
27 KiB
Python

import os
import sys
import time
import contextlib
from functools import wraps
import torch
from modules import rocm
from modules.errors import log, display, install as install_traceback
from installer import install, installed
debug = os.environ.get('SD_DEVICE_DEBUG', None) is not None # verbose device logging when SD_DEVICE_DEBUG is set (any value)
install_traceback() # traceback handler
opts = None # initialized in get_backend to avoid circular import
args = None # initialized in get_backend to avoid circular import
# evaluated once at import time: True when either a CUDA or an Intel XPU device is usable
cuda_ok = torch.cuda.is_available() or (hasattr(torch, 'xpu') and torch.xpu.is_available())
inference_context = torch.no_grad # default; set_dtype replaces it according to opts.inference_mode
cpu = torch.device("cpu")
fp16_ok = None # set once by test_fp16
bf16_ok = None # set once by test_bf16
backend = None # set by get_backend
device = None # set by get_optimal_device
dtype = None # set by set_dtype
dtype_vae = None # set by set_dtype
dtype_unet = None # set by set_dtype
unet_needs_upcast = False # compatibility item; set_dtype copies opts.upcast_sampling here
onnx = None # NOTE(review): not used in this file; presumably assigned by another module
sdpa_original = None # unpatched torch SDPA, recorded by set_sdpa_params on first use
# SDPA function that dynamic attention wrapped; spelling kept as-is in case other modules reference it
sdpa_pre_dyanmic_atten = None
previous_oom = 0 # oom counter
if debug:
    log.info(f'Torch build config: {torch.__config__.show()}')
# set_cuda_sync_mode('block') # none/auto/spin/yield/block
def has_mps() -> bool:
    """Report whether Apple Metal (MPS) is usable; always False outside macOS.

    On macOS the answer comes from ``modules.devices_mac.has_mps``.
    """
    if sys.platform == "darwin":
        from modules import devices_mac # pylint: disable=ungrouped-imports
        return devices_mac.has_mps # pylint: disable=used-before-assignment
    return False
def has_xpu() -> bool:
    """Report whether an Intel XPU device is usable; False when torch has no ``xpu`` module."""
    xpu_module = getattr(torch, 'xpu', None)
    if xpu_module is None:
        return False
    return bool(xpu_module.is_available())
def has_zluda() -> bool:
    """Report whether the CUDA device looks like a ZLUDA-translated one.

    The check is whether the default CUDA device reports compute capability (8, 8).
    Returns False when no CUDA/XPU device is available or the query fails.
    """
    if cuda_ok:
        try:
            capability = torch.cuda.get_device_capability(torch.device("cuda"))
            return capability == (8, 8)
        except Exception:
            pass
    return False
def get_backend(shared_cmd_opts):
    """Store the command-line options and return the name of the compute backend.

    Explicit ``--use-openvino`` / ``--use-directml`` flags win. After them, detection runs in this
    order: Intel XPU ('ipex'), ZLUDA ('zluda'), CUDA ('cuda'), ROCm/HIP ('rocm'), macOS ('mps').
    Anything else falls back to 'cpu'.
    """
    global args # pylint: disable=global-statement
    args = shared_cmd_opts
    if args.use_openvino:
        return 'openvino'
    if args.use_directml:
        return 'directml'
    if has_xpu():
        return 'ipex'
    if has_zluda():
        return 'zluda'
    if torch.cuda.is_available():
        if torch.version.cuda:
            return 'cuda'
        if torch.version.hip:
            return 'rocm'
    if sys.platform == 'darwin':
        return 'mps'
    return 'cpu'
def get_gpu_info():
    """Collect a small dict describing the active accelerator for logging/diagnostics.

    The keys depend on the module-level ``backend``.
    - When CUDA is not available, 'openvino' and 'directml' are described and any other
      backend yields ``{}``. Any failure on this path also yields ``{}``.
    - When CUDA is available, 'ipex', 'cuda'/'zluda' and 'rocm' are described and any other
      backend yields ``{'device': 'unknown'}``. A failure on this path returns
      ``{'error': <exception object>}``.
    Nothing is raised.
    """
    def get_driver():
        """Return the GPU driver version string, or '' when it cannot be determined."""
        import subprocess
        # torch.xpu is missing on some torch builds; guard it the same way has_xpu() does so that
        # CUDA builds without it reach the nvidia-smi query instead of raising AttributeError
        if hasattr(torch, 'xpu') and torch.xpu.is_available():
            try:
                return torch.xpu.get_device_properties(torch.xpu.current_device()).driver_version
            except Exception:
                return ''
        if torch.cuda.is_available() and torch.version.cuda:
            try:
                # fixed command: pass an argument list and skip the shell; a missing nvidia-smi
                # raises FileNotFoundError, which is handled below
                result = subprocess.run(['nvidia-smi', '--query-gpu=driver_version', '--format=csv,noheader'], check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                return result.stdout.decode(encoding="utf8", errors="ignore").strip()
            except Exception:
                return ''
        return ''

    def get_package_version(pkg: str):
        """Return the installed version of *pkg*, or '' when it is not installed."""
        try:
            import pkg_resources
        except ImportError:
            # pkg_resources is deprecated and may be unavailable; fall back to stdlib metadata
            from importlib import metadata
            try:
                return metadata.version(pkg)
            except metadata.PackageNotFoundError:
                return ''
        spec = pkg_resources.working_set.by_key.get(pkg, None) # more reliable than importlib
        return pkg_resources.get_distribution(pkg).version if spec is not None else ''

    if not torch.cuda.is_available():
        try:
            if backend == 'openvino':
                from modules.intel.openvino import get_openvino_device
                return {
                    'device': get_openvino_device(), # pylint: disable=used-before-assignment
                    'openvino': get_package_version("openvino"),
                }
            elif backend == 'directml':
                return {
                    'device': f'{torch.cuda.get_device_name(torch.cuda.current_device())} n={torch.cuda.device_count()}',
                    'directml': get_package_version("torch-directml"),
                }
            else:
                return {}
        except Exception:
            return {}
    else:
        try:
            if backend == 'ipex':
                return {
                    'device': f'{torch.xpu.get_device_name(torch.xpu.current_device())} n={torch.xpu.device_count()}',
                    'ipex': get_package_version('intel-extension-for-pytorch'),
                    'driver': get_driver(),
                }
            elif backend == 'cuda' or backend == 'zluda':
                return {
                    'device': f'{torch.cuda.get_device_name(torch.cuda.current_device())} n={torch.cuda.device_count()} arch={torch.cuda.get_arch_list()[-1]} capability={torch.cuda.get_device_capability(device)}',
                    'cuda': torch.version.cuda,
                    'cudnn': torch.backends.cudnn.version(),
                    'driver': get_driver(),
                }
            elif backend == 'rocm':
                return {
                    'device': f'{torch.cuda.get_device_name(torch.cuda.current_device())} n={torch.cuda.device_count()}',
                    'hip': torch.version.hip,
                }
            else:
                return {
                    'device': 'unknown'
                }
        except Exception as ex:
            if debug:
                display(ex, 'Device exception')
            return { 'error': ex }
def extract_device_id(args, name): # pylint: disable=redefined-outer-name
    """Return the argument that follows the first argument containing *name*.

    Matching is by substring (``name in arg``), as before.

    Args:
        args: indexable sequence of command-line tokens.
        name: text to look for, e.g. '--device-id'.

    Returns:
        The token after the first match, or None when nothing matches or the match is the
        last token. The last-token case previously raised IndexError.
    """
    for index, arg in enumerate(args):
        if name in arg:
            return args[index + 1] if index + 1 < len(args) else None
    return None
def get_cuda_device_string():
    """Build the torch device string for the GPU backend, honouring ``--device-id``.

    The prefix is 'xpu' for IPEX, 'privateuseone' for an available DirectML device, and 'cuda'
    otherwise. A device id, when given, is appended as ``prefix:id``. Without a device id,
    DirectML returns its own default-device string.
    """
    from modules.shared import cmd_opts
    device_id = cmd_opts.device_id
    if backend == 'ipex':
        prefix = "xpu"
    elif backend == 'directml' and torch.dml.is_available():
        if device_id is None:
            return torch.dml.get_device_string(torch.dml.default_device().index)
        prefix = "privateuseone"
    else:
        prefix = "cuda"
    return prefix if device_id is None else f"{prefix}:{device_id}"
def get_optimal_device_name():
    """Return the preferred device string: the GPU device, then "mps" (unless OpenVINO), else "cpu"."""
    if cuda_ok or backend == 'directml':
        return get_cuda_device_string()
    use_mps = has_mps() and backend != 'openvino'
    return "mps" if use_mps else "cpu"
def get_optimal_device():
    """Return the preferred device as a ``torch.device`` (wraps ``get_optimal_device_name``)."""
    name = get_optimal_device_name()
    return torch.device(name)
def get_device_for(task): # pylint: disable=unused-argument
    """Return the device to run *task* on.

    *task* is currently ignored and every task gets ``get_optimal_device()``. A per-task CPU
    override (``cmd_opts.use_cpu``) is disabled; see the comment below.
    """
    # disabled per-task override:
    # if task in cmd_opts.use_cpu:
    #     log.debug(f'Forcing CPU for task: {task}')
    #     return cpu
    optimal = get_optimal_device()
    return optimal
def torch_gc(force:bool=False, fast:bool=False, reason:str=None):
    """Run Python and CUDA garbage collection when needed, then log memory before/after.

    Collection runs in any of these cases:
    - *force* is set;
    - GPU usage (percent of total) reaches ``opts.torch_gc_threshold``. A threshold of 0
      always triggers, and ``--lowvram`` without ZLUDA forces the threshold to 0;
    - the out-of-memory counter from ``memstats`` has grown since the previous call (this
      also logs a warning).

    Args:
        force: collect regardless of memory usage.
        fast: skip ``gc.collect()``, and return right after the CUDA cache flush without
            re-reading stats or logging.
        reason: label for the log line; defaults to 'force', 'threshold' or 'oom'.

    Returns:
        ``(gpu, ram)`` used-memory figures from ``memstats.memory_stats()`` (on DirectML the GPU
        figure is ``torch.cuda.memory_allocated()`` in GiB). These are the pre-collection values
        when nothing was collected or *fast* is set, otherwise the post-collection values.
        NOTE(review): memstats units are presumably GB; confirm.
    """
    def get_stats():
        # snapshot of (gpu_used, gpu_percent, ram_used, ram_percent, oom_count); percentages are 0 when total is unknown
        mem_dict = memstats.memory_stats()
        gpu_dict = mem_dict.get('gpu', {})
        ram_dict = mem_dict.get('ram', {})
        oom = gpu_dict.get('oom', 0)
        ram = ram_dict.get('used', 0)
        if backend == "directml":
            gpu = torch.cuda.memory_allocated() / (1 << 30)
        else:
            gpu = gpu_dict.get('used', 0)
        used_gpu = round(100 * gpu / gpu_dict.get('total', 1)) if gpu_dict.get('total', 1) > 1 else 0
        used_ram = round(100 * ram / ram_dict.get('total', 1)) if ram_dict.get('total', 1) > 1 else 0
        return gpu, used_gpu, ram, used_ram, oom
    global previous_oom # pylint: disable=global-statement
    import gc
    from modules import timer, memstats
    from modules.shared import cmd_opts
    t0 = time.time()
    gpu, used_gpu, ram, _used_ram, oom = get_stats()
    # lowvram (except under ZLUDA) means "always collect"
    threshold = 0 if (cmd_opts.lowvram and not cmd_opts.use_zluda) else opts.torch_gc_threshold
    collected = 0
    if reason is None and force:
        reason='force'
    if threshold == 0 or used_gpu >= threshold:
        force = True
        if reason is None:
            reason = 'threshold'
    if oom > previous_oom:
        # a new OOM happened since the last call: remember the count and always collect
        previous_oom = oom
        log.warning(f'Torch GPU out-of-memory error: {memstats.memory_stats()}')
        force = True
        if reason is None:
            reason = 'oom'
    if debug:
        # NOTE(review): sys._getframe(2) raises ValueError if the call stack is shallower than two frames
        fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
        log.trace(f'GC: run={force} fast={fast} used={used_gpu} threshold={threshold} fn={fn}')
    if force:
        # actual gc
        collected = gc.collect() if not fast else 0 # python gc
        if cuda_ok:
            try:
                with torch.cuda.device(get_cuda_device_string()):
                    torch.cuda.synchronize()
                    torch.cuda.empty_cache() # cuda gc
                    torch.cuda.ipc_collect()
            except Exception:
                pass # best-effort: failures to flush the CUDA cache are ignored
    else:
        return gpu, ram
    t1 = time.time()
    timer.process.add('gc', t1 - t0)
    if fast:
        return gpu, ram
    new_gpu, new_used_gpu, new_ram, new_used_ram, oom = get_stats()
    before = { 'gpu': gpu, 'ram': ram }
    after = { 'gpu': new_gpu, 'ram': new_ram, 'oom': oom }
    utilization = { 'gpu': new_used_gpu, 'ram': new_used_ram }
    results = { 'gpu': round(gpu - new_gpu, 2), 'py': collected } # gpu memory freed and number of python objects collected
    fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
    log.debug(f'GC: current={after} prev={before} load={utilization} gc={results} fn={fn} why={reason} time={t1-t0:.2f}')
    return new_gpu, new_ram
def set_cuda_sync_mode(mode):
    """
    Set the CUDA device synchronization mode: auto, spin, yield or block.
    auto: Chooses spin or yield depending on the number of available CPU cores.
    spin: Runs one CPU core per GPU at 100% to poll for completed operations.
    yield: Gives control to other threads between polling, if any are waiting.
    block: Lets the thread sleep until the GPU driver signals completion.

    ``-1`` or ``'none'`` leaves the mode unchanged, as does running without a CUDA/XPU device.
    The following are logged as warnings instead of being silently ignored:
    - an unknown mode;
    - ``libcudart.so`` not being loadable;
    - a non-zero status from ``cudaSetDeviceFlags``.
    """
    if mode == -1 or mode == 'none' or not cuda_ok:
        return
    flags = {'auto': 0, 'spin': 1, 'yield': 2, 'block': 4} # values passed to cudaSetDeviceFlags
    if mode not in flags:
        log.warning(f'Torch CUDA sync: unknown mode={mode}')
        return
    try:
        import ctypes
        log.info(f'Torch CUDA sync: mode={mode}')
        torch.cuda.set_device(torch.device(get_optimal_device_name()))
        status = ctypes.CDLL('libcudart.so').cudaSetDeviceFlags(flags[mode])
        if status != 0: # cudaError_t: 0 means cudaSuccess
            log.warning(f'Torch CUDA sync: mode={mode} cudaSetDeviceFlags error={status}')
    except Exception as e:
        log.warning(f'Torch CUDA sync: mode={mode} {e}')
def set_cuda_memory_limit():
    """Cap this process at ``opts.cuda_mem_fraction`` of GPU memory; 0 means no cap.

    The steps are:
    1. Force a GC pass.
    2. Apply the fraction to ``cmd_opts.device_id``, or to device 0 when it is unset.
    3. Log the limit and the total in MiB.

    Errors are reported as warnings.
    """
    if not cuda_ok or opts.cuda_mem_fraction == 0:
        return
    try:
        from modules.shared import cmd_opts
        torch_gc(force=True)
        total_bytes = torch.cuda.get_device_properties(device).total_memory
        target_index = 0 if cmd_opts.device_id is None else cmd_opts.device_id
        torch.cuda.set_per_process_memory_fraction(float(opts.cuda_mem_fraction), target_index)
        log.info(f'Torch memory limit: fraction={opts.cuda_mem_fraction:.2f} limit={round(opts.cuda_mem_fraction * total_bytes / 1024 / 1024)} total={round(total_bytes / 1024 / 1024)}')
    except Exception as e:
        log.warning(f'Torch memory limit: fraction={opts.cuda_mem_fraction:.2f} {e}')
def set_cuda_tunable():
    """Configure PyTorch TunableOp from ``opts`` unless the setting is left at 'default'.

    When ``opts.torch_tunable_ops`` is 'true', TunableOp and online tuning are enabled; any other
    non-default value disables both. Tuning iterations come from ``opts.torch_tunable_limit``,
    and results are stored in ``<opts.tunable_dir>/tunable.csv``. The number of lines already in
    that file is reported in the debug log. Errors are reported as warnings.
    """
    if not cuda_ok:
        return
    try:
        if opts.torch_tunable_ops == 'default':
            return
        enabled = opts.torch_tunable_ops == 'true'
        torch.cuda.tunable.enable(enabled)
        torch.cuda.tunable.tuning_enable(enabled)
        torch.cuda.tunable.set_max_tuning_duration(1000) # set to high value as actual is min(duration, iterations)
        torch.cuda.tunable.set_max_tuning_iterations(opts.torch_tunable_limit)
        fn = os.path.join(opts.tunable_dir, 'tunable.csv')
        lines = 0 # entry count for the log; an int (the old `{0}` was a set literal)
        try:
            if os.path.exists(fn):
                with open(fn, 'r', encoding='utf8') as f:
                    lines = sum(1 for _line in f)
        except (OSError, ValueError):
            pass # best-effort: an unreadable cache file just reports 0 entries
        torch.cuda.tunable.set_filename(fn)
        if torch.cuda.tunable.is_enabled():
            log.debug(f'Torch tunable: enabled={torch.cuda.tunable.is_enabled()} tuning={torch.cuda.tunable.tuning_is_enabled()} iterations={torch.cuda.tunable.get_max_tuning_iterations()} duration={torch.cuda.tunable.get_max_tuning_duration()} fn="{fn}" entries={lines}')
    except Exception as e:
        log.warning(f'Torch tunable: {e}')
def test_fp16():
    """Probe, once, whether the active device handles float16; the result is cached in ``fp16_ok``.

    Unless the user explicitly chose FP16, some platforms are marked unsupported without probing:
    - macOS and OpenVINO;
    - ROCm gfx1102 GPUs (RX 7600/7500/7650/7700S). These are also matched by name when the agent
      is overridden to gfx1100, because gfx1102 segfaults with fp16.

    Otherwise a small float16 LayerNorm is run on ``device``. Any exception, a wrong output dtype,
    or an all-NaN output marks FP16 as unsupported.
    """
    global fp16_ok # pylint: disable=global-statement
    if fp16_ok is not None:
        return fp16_ok
    user_selected_fp16 = opts.cuda_dtype == 'FP16' # don't override if the user sets it
    if not user_selected_fp16:
        if sys.platform == "darwin" or backend == 'openvino': # override
            fp16_ok = False
            return fp16_ok
        if backend == 'rocm':
            props = torch.cuda.get_device_properties(device)
            agent = getattr(props, "gcnArchName", "gfx0000")
            agent_name = getattr(props, "name", "AMD Radeon RX 0000")
            gfx1102_models = ("7600", "7500", "7650", "7700S")
            renamed_gfx1102 = agent == "gfx1100" and any(model in agent_name for model in gfx1102_models)
            if agent == "gfx1102" or renamed_gfx1102:
                fp16_ok = False
                return fp16_ok
    try:
        sample = torch.tensor([[1.5, .0, .0, .0]]).to(device=device, dtype=torch.float16)
        norm = torch.nn.LayerNorm(4, eps=0.00001, elementwise_affine=True, dtype=torch.float16, device=device)
        result = norm(sample)
        if result.dtype != torch.float16:
            raise RuntimeError('Torch FP16 test: dtype mismatch')
        if torch.all(torch.isnan(result)).item():
            raise RuntimeError('Torch FP16 test: NaN')
        fp16_ok = True
    except Exception as ex:
        log.warning(f'Torch FP16 test fail: {ex}')
        fp16_ok = False
    return fp16_ok
def test_bf16():
    """Probe, once, whether the active device handles bfloat16; the result is cached in ``bf16_ok``.

    Unless the user explicitly chose BF16, some platforms are marked unsupported without probing:
    - macOS, OpenVINO and DirectML;
    - ROCm/ZLUDA agents with ``gfx_version < 0x1100`` (pre-RDNA 3), except CDNA parts.

    Otherwise a bfloat16 nearest-neighbour interpolate is run on ``device``. Any exception, a wrong
    output dtype, or an all-NaN output marks BF16 as unsupported.
    """
    global bf16_ok # pylint: disable=global-statement
    if bf16_ok is not None:
        return bf16_ok
    user_selected_bf16 = opts.cuda_dtype == 'BF16' # don't override if the user sets it
    if not user_selected_bf16:
        if sys.platform == "darwin" or backend in ('openvino', 'directml'): # override
            bf16_ok = False
            return bf16_ok
        if backend in ('rocm', 'zluda'):
            if backend == 'rocm':
                agent = rocm.Agent(getattr(torch.cuda.get_device_properties(device), "gcnArchName", "gfx0000"))
            else:
                from modules.zluda_installer import default_agent
                agent = default_agent
            # all cards before RDNA 3 except for CDNA cards
            if agent is not None and agent.gfx_version < 0x1100 and agent.arch != rocm.MicroArchitecture.CDNA:
                bf16_ok = False
                return bf16_ok
    try:
        import torch.nn.functional as F
        probe = torch.randn(1, 4, 32, 32).to(device=device, dtype=torch.bfloat16)
        upscaled = F.interpolate(probe, size=(64, 64), mode="nearest")
        if upscaled.dtype != torch.bfloat16:
            raise RuntimeError('Torch BF16 test: dtype mismatch')
        if torch.all(torch.isnan(upscaled)).item():
            raise RuntimeError('Torch BF16 test: NaN')
        bf16_ok = True
    except Exception as ex:
        log.warning(f'Torch BF16 test fail: {ex}')
        bf16_ok = False
    return bf16_ok
def set_cudnn_params():
    """Apply matmul precision and cuDNN flags for CUDA/XPU devices.

    Matmul and cuDNN flags:
    - TF32 and reduced-precision fp16/bf16 reductions are enabled for matmul.
    - cuDNN determinism follows ``opts.cudnn_deterministic``.
    - cuDNN benchmark auto-tuning is on unless determinism is requested.
    - ``opts.cudnn_benchmark`` additionally removes the benchmark algorithm limit.

    Errors are reported as warnings.
    """
    if not cuda_ok:
        return
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
    except Exception as e:
        log.warning(f'Torch matmul: {e}')
    if torch.backends.cudnn.is_available():
        try:
            deterministic = opts.cudnn_deterministic
            torch.backends.cudnn.deterministic = deterministic
            torch.use_deterministic_algorithms(deterministic)
            if deterministic:
                os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
            # benchmark mode times several conv algorithms and may pick a different one on each
            # run, which changes results between runs; PyTorch's reproducibility notes require it
            # to be off when deterministic output is wanted
            torch.backends.cudnn.benchmark = not deterministic
            if opts.cudnn_benchmark and not deterministic:
                log.debug('Torch cuDNN: enable benchmark')
                torch.backends.cudnn.benchmark_limit = 0 # 0 = try every available algorithm
            torch.backends.cudnn.allow_tf32 = True
        except Exception as e:
            log.warning(f'Torch cudnn: {e}')
def override_ipex_math():
    """Enable TF32 math on Intel XPU when the backend is 'ipex'; no-op for other backends."""
    if backend != "ipex":
        return
    try:
        if hasattr(torch.xpu, "set_fp32_math_mode"): # not available with pure torch+xpu, requires ipex
            torch.xpu.set_fp32_math_mode(mode=torch.xpu.FP32MathMode.TF32)
        torch.backends.mkldnn.allow_tf32 = True
    except Exception as e:
        log.warning(f'Torch ipex: {e}')
def set_sdpa_params():
    """Configure torch scaled-dot-product attention (SDPA) and optionally wrap it with extra kernels.

    Only acts when ``opts.cross_attention_optimization == "Scaled-Dot-Product"``. The steps are:
    1. Restore the saved unpatched ``torch.nn.functional.scaled_dot_product_attention``, or record
       it in ``sdpa_original`` on the first call.
    2. Toggle torch's built-in flash / memory-efficient / math SDPA backends from
       ``opts.sdp_options``.
    3. Layer the optional wrappers selected in ``opts.sdp_options``: dynamic attention, CK flash
       attention (``flash_attn``, installed on demand) and Sage attention (``sageattention``,
       installed on demand).

    Each wrapper keeps the function it replaced and falls back to it for inputs it does not
    handle. Every failure is logged rather than raised.
    """
    try:
        if opts.cross_attention_optimization != "Scaled-Dot-Product":
            return
        try:
            global sdpa_original # pylint: disable=global-statement
            # start from the unpatched SDPA so a repeated call does not stack new wrappers on old ones
            if sdpa_original is not None:
                torch.nn.functional.scaled_dot_product_attention = sdpa_original
            else:
                sdpa_original = torch.nn.functional.scaled_dot_product_attention
        except Exception as err:
            log.warning(f'Torch attention: type="sdpa" {err}')
        try:
            # enable/disable torch's built-in SDPA kernels according to the selected options
            torch.backends.cuda.enable_flash_sdp('Flash attention' in opts.sdp_options)
            torch.backends.cuda.enable_mem_efficient_sdp('Memory attention' in opts.sdp_options)
            torch.backends.cuda.enable_math_sdp('Math attention' in opts.sdp_options)
            if hasattr(torch.backends.cuda, "allow_fp16_bf16_reduction_math_sdp"): # only valid for torch >= 2.5
                torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
            log.debug(f'Torch attention: type="sdpa" flash={"Flash attention" in opts.sdp_options} memory={"Memory attention" in opts.sdp_options} math={"Math attention" in opts.sdp_options}')
        except Exception as err:
            log.warning(f'Torch attention: type="sdpa" {err}')
        # Stack hijacks in reverse order. This gives priority to the last added hijack.
        # If the last hijack is not compatible, it will use the one before it and so on.
        if 'Dynamic attention' in opts.sdp_options:
            try:
                global sdpa_pre_dyanmic_atten # pylint: disable=global-statement
                # keep the function being replaced (global name spelling kept as-is)
                sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
                from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention
                torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention
                log.debug('Torch attention: type="dynamic attention"')
            except Exception as err:
                log.error(f'Torch attention: type="dynamic attention" {err}')
        if 'CK Flash attention' in opts.sdp_options:
            try:
                # ROCm installs an agent-specific flash-attn build only when missing; other backends use the 'flash-attn' package
                if backend == "rocm":
                    if not installed('flash-attn'):
                        agent = rocm.Agent(getattr(torch.cuda.get_device_properties(device), "gcnArchName", "gfx0000"))
                        install(rocm.get_flash_attention_command(agent), reinstall=True)
                else:
                    install('flash-attn')
                from flash_attn import flash_attn_func
                sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
                @wraps(sdpa_pre_flash_atten)
                def sdpa_flash_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
                    # flash path only for head dim <= 128, no mask, and non-fp32 inputs; everything else falls back
                    if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
                        is_unsqueezed = False
                        # 3-D inputs get a leading batch dim, removed again from the output
                        if query.dim() == 3:
                            query = query.unsqueeze(0)
                            is_unsqueezed = True
                        if key.dim() == 3:
                            key = key.unsqueeze(0)
                        if value.dim() == 3:
                            value = value.unsqueeze(0)
                        # swap dims 1 and 2 in and out of flash_attn_func; presumably converts between SDPA's
                        # (batch, heads, seq, dim) and flash_attn's (batch, seq, heads, dim) layouts — confirm
                        query = query.transpose(1, 2)
                        key = key.transpose(1, 2)
                        value = value.transpose(1, 2)
                        attn_output = flash_attn_func(q=query, k=key, v=value, dropout_p=dropout_p, causal=is_causal, softmax_scale=scale).transpose(1, 2)
                        if is_unsqueezed:
                            attn_output = attn_output.squeeze(0)
                        return attn_output
                    else:
                        return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
                torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
                log.debug('Torch attention: type="ck flash attention"')
            except Exception as err:
                log.error(f'Torch attention: type="ck flash attention" {err}')
        if 'Sage attention' in opts.sdp_options:
            try:
                install('sageattention')
                from sageattention import sageattn
                sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
                @wraps(sdpa_pre_sage_atten)
                def sdpa_sage_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
                    # sage path only for head dims 64/96/128, no mask, and non-fp32 inputs; everything else falls back
                    if (query.shape[-1] in {128, 96, 64}) and (attn_mask is None) and (query.dtype != torch.float32):
                        # NOTE(review): sageattn may not accept attn_mask/dropout_p/scale under these names
                        # (its softmax scale argument is believed to be `sm_scale`); if so a custom `scale`
                        # is silently ignored here — verify against the installed sageattention version
                        return sageattn(q=query, k=key, v=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
                    else:
                        return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
                torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
                log.debug('Torch attention: type="sage attention"')
            except Exception as err:
                log.error(f'Torch attention: type="sage attention" {err}')
    except Exception as e:
        log.warning(f'Torch SDPA: {e}')
def set_dtype():
    """Resolve the global compute dtypes and the inference context from ``opts``.

    The steps are:
    1. Run the cached FP16/BF16 probes.
    2. Choose ``dtype``/``dtype_vae``/``dtype_unet`` from ``opts.cuda_dtype``. 'Auto' prefers BF16,
       then FP16, then FP32. An explicit BF16/FP16 choice is kept even when its probe failed,
       with a warning.
    3. Apply the no-half and no-half-vae overrides.
    4. Copy ``opts.upcast_sampling`` into ``unet_needs_upcast`` and pick ``inference_context``
       from ``opts.inference_mode``.
    """
    global dtype, dtype_vae, dtype_unet, unet_needs_upcast, inference_context # pylint: disable=global-statement
    test_fp16()
    test_bf16()
    config = opts.cuda_dtype
    chosen = None # stays None for an unrecognised setting: current dtypes are left untouched
    if config == 'Auto': # detect
        if bf16_ok:
            chosen = torch.bfloat16
        elif fp16_ok:
            chosen = torch.float16
        else:
            chosen = torch.float32
    elif config == 'FP32':
        chosen = torch.float32
    elif config in ('BF16', 'FP16'):
        chosen = torch.bfloat16 if config == 'BF16' else torch.float16
        probe_passed = bf16_ok if config == 'BF16' else fp16_ok
        if not probe_passed:
            log.warning(f'Torch device capability failed: device={device} dtype={chosen}')
    if chosen is not None:
        dtype = dtype_vae = dtype_unet = chosen
    if opts.no_half:
        dtype = dtype_vae = dtype_unet = torch.float32
        log.info(f'Torch override: no-half dtype={dtype}')
    if opts.no_half_vae:
        dtype_vae = torch.float32
        log.info(f'Torch override: no-half-vae dtype={dtype_vae}')
    unet_needs_upcast = opts.upcast_sampling
    contexts = {'inference-mode': torch.inference_mode, 'none': contextlib.nullcontext}
    inference_context = contexts.get(opts.inference_mode, torch.no_grad)
def set_cuda_params():
    """Run every device/runtime configuration step in a fixed order, then log a summary.

    The order is:
    1. IPEX math mode;
    2. memory limit;
    3. TunableOp;
    4. cuDNN/matmul flags;
    5. SDPA attention;
    6. dtype selection, which the summary line reports.
    """
    setup_steps = (override_ipex_math, set_cuda_memory_limit, set_cuda_tunable, set_cudnn_params, set_sdpa_params, set_dtype)
    for step in setup_steps:
        step()
    if backend != 'openvino':
        device_name = torch.device(get_optimal_device_name())
    else:
        from modules.intel.openvino import get_device as get_raw_openvino_device
        device_name = get_raw_openvino_device()
    try:
        tunable = [torch.cuda.tunable.is_enabled(), torch.cuda.tunable.tuning_is_enabled()]
    except Exception:
        tunable = [False, False] # torch without TunableOp support

    def verdict(ok):
        return "pass" if ok else "fail"

    log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} tunable={tunable} fp16={verdict(fp16_ok)} bf16={verdict(bf16_ok)} optimization="{opts.cross_attention_optimization}"')
def cond_cast_unet(tensor):
    """Cast *tensor* to ``dtype_unet`` when UNet upcasting is enabled; otherwise return it unchanged."""
    if not unet_needs_upcast:
        return tensor
    return tensor.to(dtype_unet)
def cond_cast_float(tensor):
    """Cast *tensor* to float32 when UNet upcasting is enabled; otherwise return it unchanged."""
    if not unet_needs_upcast:
        return tensor
    return tensor.float()
def randn(seed, shape=None):
    """Seed the torch RNGs with *seed* and optionally draw a standard-normal tensor.

    The XPU RNGs are also seeded on the 'ipex' backend. Returns None when *shape* is None.
    On MPS the sample is drawn on CPU and moved to the device. With
    ``opts.diffusers_generator_device == "CPU"`` it stays on CPU; otherwise it is drawn on
    ``device``.
    """
    torch.manual_seed(seed)
    if backend == 'ipex':
        torch.xpu.manual_seed_all(seed)
    if shape is None:
        return None
    if device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)
    target = cpu if opts.diffusers_generator_device == "CPU" else device
    return torch.randn(shape, device=target)
def randn_without_seed(shape):
    """Draw a standard-normal tensor on ``device`` without reseeding; MPS samples on CPU first."""
    on_mps = device.type == 'mps'
    sample = torch.randn(shape, device=cpu if on_mps else device)
    return sample.to(device) if on_mps else sample
def autocast(disable=False):
    """Return an autocast context for the active backend.

    A no-op context is returned when *disable* is set or the global ``dtype`` is float32.
    """
    if disable or dtype == torch.float32:
        return contextlib.nullcontext()
    if backend == 'directml':
        return torch.dml.amp.autocast(dtype)
    device_type = "cuda" if cuda_ok else "cpu"
    return torch.autocast(device_type)
def without_autocast(disable=False):
    """Return a context that turns autocast off, or a no-op context.

    The no-op context is used when *disable* is set or autocast is not currently enabled.
    """
    if disable:
        return contextlib.nullcontext()
    if backend == 'directml':
        return torch.dml.amp.autocast(enabled=False) if torch.is_autocast_enabled() else contextlib.nullcontext() # pylint: disable=unexpected-keyword-arg
    if not torch.is_autocast_enabled():
        return contextlib.nullcontext()
    device_type = "cuda" if cuda_ok else "cpu"
    return torch.autocast(device_type, enabled=False)
class NansException(Exception):
    """Raised by ``test_for_nans`` when a checked tensor consists entirely of NaN values."""
def test_for_nans(x, where):
    """Raise NansException when every element of *x* is NaN, unless ``opts.disable_nan_check`` is set.

    *where* ("unet", "vae", or anything else) selects the explanatory hint in the message. The
    hint is omitted when the suggested no-half option is already active.
    """
    if opts.disable_nan_check or not torch.all(torch.isnan(x)).item():
        return
    hint = ""
    if where == "unet":
        base = "A tensor with all NaNs was produced in Unet."
        if not opts.no_half:
            hint = " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
    elif where == "vae":
        base = "A tensor with all NaNs was produced in VAE."
        if not (opts.no_half or opts.no_half_vae):
            hint = " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
    else:
        base = "A tensor with all NaNs was produced."
    raise NansException(base + hint + " Use --disable-nan-check commandline argument to disable this check.")
def normalize_device(dev):
    """Return *dev* as a ``torch.device``, giving indexable device types an explicit index 0.

    'cpu', 'mps' and 'meta' are returned unchanged, as are devices that already carry an index.
    """
    resolved = torch.device(dev)
    needs_index = resolved.type not in {"cpu", "mps", "meta"} and resolved.index is None
    if needs_index:
        return torch.device(str(dev), index=0)
    return resolved
def same_device(d1, d2):
    """Return True when *d1* and *d2* name the same device, treating 'cuda' and 'cuda:0' as equal."""
    return d1.type == d2.type and normalize_device(d1) == normalize_device(d2)