mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
# pylint: disable=no-member,no-self-argument,no-method-argument
|
|
from typing import Optional, Callable
|
|
import torch
|
|
import torch_directml # pylint: disable=import-error
|
|
import modules.dml.amp as amp
|
|
from .utils import rDevice, get_device
|
|
from .device import Device
|
|
from .Generator import Generator
|
|
from .device_properties import DeviceProperties
|
|
|
|
|
|
def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
    """Query memory info for ``device`` through the AMD-specific provider.

    Args:
        device: device selector forwarded unchanged to ``get_device``.
            ``None`` presumably means the current/default device; confirm in
            ``utils.get_device``.

    Returns:
        Whatever ``AMDMemoryProvider.mem_get_info`` returns for the resolved
        device index. It is annotated as a pair of ints, presumably
        ``(free, total)`` to match ``torch.cuda.mem_get_info`` (TODO confirm).
    """
    # Deferred import: the AMD provider module is only loaded when this
    # query path is actually used.
    from .memory_amd import AMDMemoryProvider

    query = AMDMemoryProvider.mem_get_info
    return query(get_device(device).index)
|
|
|
|
|
|
def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
    """Derive a ``(free, total)``-style pair from ``DirectML.memory_provider``.

    The provider's ``get_memory`` result is indexed by the ``"total_committed"``
    and ``"dedicated_usage"`` keys. The first element returned is their
    difference, and the second is ``total_committed`` itself. The function
    name suggests the provider is backed by Windows PDH performance counters;
    that is not verifiable from here.

    Raises:
        AttributeError: if ``DirectML.memory_provider`` is still ``None``.
        KeyError: if either key is missing (assuming the result is a dict).
    """
    get_memory = DirectML.memory_provider.get_memory
    stats = get_memory(get_device(device).index)
    committed = stats["total_committed"]
    return (committed - stats["dedicated_usage"], committed)
|
|
|
|
|
|
def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument
    """Return a fixed pair of ``8 * 2**30`` values, ignoring ``device``.

    No real memory query is made. In this file the function is the default
    bound to ``DirectML.mem_get_info``. The pair presumably stands in for
    ``(free, total)`` bytes, i.e. 8 GiB each.
    """
    eight_gib = 8 * 1024 ** 3  # == 8589934592
    return eight_gib, eight_gib
|
|
|
|
|
|
class DirectML:
    """Class-level namespace of DirectML device and memory helpers.

    The class is never instantiated. Every function below is defined without
    ``self`` and is called on the class itself (for example
    ``DirectML.default_device()`` inside ``current_device``). The names appear
    to mirror the ``torch.cuda`` API (``mem_get_info``, ``memory_stats``,
    ``reset_peak_memory_stats``, ...). The class is presumably used as a
    drop-in for that API; confirm against callers.
    """

    # Re-exported helpers, reachable as DirectML.<name>.
    amp = amp
    device = Device
    Generator = Generator

    # Explicit device override. When unset (falsy), current_device() falls
    # back to default_device().
    context_device: Optional[torch.device] = None

    # Autocast state. It is not read inside this class; presumably the amp
    # module consumes it (TODO confirm).
    is_autocast_enabled = False
    autocast_gpu_dtype = torch.float16

    # Object exposing get_memory(index), which pdh_mem_get_info reads.
    # Presumably assigned elsewhere before that function is used.
    memory_provider = None

    def is_available() -> bool:
        """Report whether torch_directml says DirectML can be used."""
        return torch_directml.is_available()

    def is_directml_device(device: torch.device) -> bool:
        """Return True when ``device`` has the ``privateuseone`` type used for DirectML."""
        device_type = device.type
        return device_type == "privateuseone"

    def has_float64_support(device: Optional[rDevice]=None) -> bool:
        """Ask torch_directml whether the resolved device supports float64."""
        index = get_device(device).index
        return torch_directml.has_float64_support(index)

    def device_count() -> int:
        """Return the number of DirectML devices reported by torch_directml."""
        return torch_directml.device_count()

    def current_device() -> torch.device:
        """Return ``context_device`` if it is set (truthy), else ``default_device()``."""
        selected = DirectML.context_device
        if not selected:
            selected = DirectML.default_device()
        return selected

    def default_device() -> torch.device:
        """Build a torch device for torch_directml's default device index."""
        default_index = torch_directml.default_device()
        return torch_directml.device(default_index)

    def get_device_string(device: Optional[rDevice]=None) -> str:
        """Return the ``privateuseone:<index>`` string for the resolved device."""
        index = get_device(device).index
        return f"privateuseone:{index}"

    def get_device_name(device: Optional[rDevice]=None) -> str:
        """Return torch_directml's name for the resolved device."""
        index = get_device(device).index
        return torch_directml.device_name(index)

    def get_device_properties(device: Optional[rDevice]=None) -> DeviceProperties:
        """Wrap the resolved device in a DeviceProperties object."""
        resolved = get_device(device)
        return DeviceProperties(resolved)

    def memory_stats(device: Optional[rDevice]=None):
        """Return a minimal stats dict with zeroed OOM and alloc-retry counters.

        ``device`` is ignored, and a fresh dict is built on every call.
        """
        stats = {"num_ooms": 0, "num_alloc_retries": 0}
        return stats

    # Defaults to the module-level mem_get_info, which returns a fixed 8 GiB
    # pair. Presumably rebound to amd_mem_get_info or pdh_mem_get_info at
    # runtime; confirm.
    mem_get_info: Callable = mem_get_info

    def memory_allocated(device: Optional[rDevice]=None) -> int:
        """Return the sum of torch_directml.gpu_memory readings scaled by 2**20.

        The 2**20 factor suggests gpu_memory reports MiB and this converts the
        total to bytes (TODO confirm against the torch_directml docs).
        """
        usage = torch_directml.gpu_memory(get_device(device).index)
        return sum(usage) * (1 << 20)

    def max_memory_allocated(device: Optional[rDevice]=None):
        """Report the current allocation as the peak; no separate peak is tracked."""
        # DirectML does not empty GPU memory
        return DirectML.memory_allocated(device)

    def reset_peak_memory_stats(device: Optional[rDevice]=None):
        """Do nothing, since no peak statistics are kept."""
        return None
|