1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-08-26 10:37:30 +05:30
parent ce12925a23
commit 0137a16ed5

View File

@@ -27,6 +27,56 @@ if is_torch_available():
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
BACKEND_EMPTY_CACHE = {
"cuda": torch.cuda.empty_cache,
"xpu": torch.xpu.empty_cache,
"cpu": None,
"mps": torch.mps.empty_cache,
"default": None,
}
BACKEND_DEVICE_COUNT = {
"cuda": torch.cuda.device_count,
"xpu": torch.xpu.device_count,
"cpu": lambda: 0,
"mps": lambda: 0,
"default": 0,
}
BACKEND_MANUAL_SEED = {
"cuda": torch.cuda.manual_seed,
"xpu": torch.xpu.manual_seed,
"cpu": torch.manual_seed,
"mps": torch.mps.manual_seed,
"default": torch.manual_seed,
}
BACKEND_RESET_PEAK_MEMORY_STATS = {
"cuda": torch.cuda.reset_peak_memory_stats,
"xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
"cpu": None,
"mps": None,
"default": None,
}
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.reset_max_memory_allocated,
"xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
"cpu": None,
"mps": None,
"default": None,
}
BACKEND_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.max_memory_allocated,
"xpu": getattr(torch.xpu, "max_memory_allocated", None),
"cpu": 0,
"mps": 0,
"default": 0,
}
BACKEND_SYNCHRONIZE = {
"cuda": torch.cuda.synchronize,
"xpu": getattr(torch.xpu, "synchronize", None),
"cpu": None,
"mps": None,
"default": None,
}
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
try:
@@ -37,60 +87,6 @@ except (ImportError, ModuleNotFoundError):
return cls
# Behaviour flags
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
# Function definitions
BACKEND_EMPTY_CACHE = {
"cuda": torch.cuda.empty_cache,
"xpu": torch.xpu.empty_cache,
"cpu": None,
"mps": torch.mps.empty_cache,
"default": None,
}
BACKEND_DEVICE_COUNT = {
"cuda": torch.cuda.device_count,
"xpu": torch.xpu.device_count,
"cpu": lambda: 0,
"mps": lambda: 0,
"default": 0,
}
BACKEND_MANUAL_SEED = {
"cuda": torch.cuda.manual_seed,
"xpu": torch.xpu.manual_seed,
"cpu": torch.manual_seed,
"mps": torch.mps.manual_seed,
"default": torch.manual_seed,
}
BACKEND_RESET_PEAK_MEMORY_STATS = {
"cuda": torch.cuda.reset_peak_memory_stats,
"xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
"cpu": None,
"mps": None,
"default": None,
}
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.reset_max_memory_allocated,
"xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
"cpu": None,
"mps": None,
"default": None,
}
BACKEND_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.max_memory_allocated,
"xpu": getattr(torch.xpu, "max_memory_allocated", None),
"cpu": 0,
"mps": 0,
"default": 0,
}
BACKEND_SYNCHRONIZE = {
"cuda": torch.cuda.synchronize,
"xpu": getattr(torch.xpu, "synchronize", None),
"cpu": None,
"mps": None,
"default": None,
}
# This dispatches a defined function according to the accelerator from the function definitions.
def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
if device not in dispatch_table:
@@ -334,4 +330,5 @@ def disable_full_determinism():
torch.use_deterministic_algorithms(False)
torch_device = get_device()
if is_torch_available():
torch_device = get_device()