mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
@@ -27,6 +27,56 @@ if is_torch_available():
     import torch
     from torch.fft import fftn, fftshift, ifftn, ifftshift
 
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
+    BACKEND_EMPTY_CACHE = {
+        "cuda": torch.cuda.empty_cache,
+        "xpu": torch.xpu.empty_cache,
+        "cpu": None,
+        "mps": torch.mps.empty_cache,
+        "default": None,
+    }
+    BACKEND_DEVICE_COUNT = {
+        "cuda": torch.cuda.device_count,
+        "xpu": torch.xpu.device_count,
+        "cpu": lambda: 0,
+        "mps": lambda: 0,
+        "default": 0,
+    }
+    BACKEND_MANUAL_SEED = {
+        "cuda": torch.cuda.manual_seed,
+        "xpu": torch.xpu.manual_seed,
+        "cpu": torch.manual_seed,
+        "mps": torch.mps.manual_seed,
+        "default": torch.manual_seed,
+    }
+    BACKEND_RESET_PEAK_MEMORY_STATS = {
+        "cuda": torch.cuda.reset_peak_memory_stats,
+        "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.reset_max_memory_allocated,
+        "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.max_memory_allocated,
+        "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+        "cpu": 0,
+        "mps": 0,
+        "default": 0,
+    }
+    BACKEND_SYNCHRONIZE = {
+        "cuda": torch.cuda.synchronize,
+        "xpu": getattr(torch.xpu, "synchronize", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
     logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
     try:
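The tables added above map each accelerator name ("cuda", "xpu", "cpu", "mps", plus a "default" fallback) either to the torch function that implements the operation, or to None / a constant where the backend has no equivalent. A minimal, self-contained sketch of how such a table is consumed follows; the helper body after the `device not in dispatch_table` check is an assumption, simplified from the `_device_agnostic_dispatch` signature that appears later in this diff:

from typing import Callable, Dict, Union

import torch

BACKEND_MANUAL_SEED = {
    "cuda": torch.cuda.manual_seed,
    "cpu": torch.manual_seed,
    "default": torch.manual_seed,
}

def dispatch_to_backend(device: str, dispatch_table: Dict[str, Union[Callable, None, int]], *args, **kwargs):
    # Assumption: unknown accelerators fall back to the "default" entry.
    if device not in dispatch_table:
        device = "default"
    fn = dispatch_table[device]
    # Non-callable entries (None, 0, ...) are returned as-is, so "no-op on
    # this backend" is handled uniformly by callers.
    if not callable(fn):
        return fn
    return fn(*args, **kwargs)

# Seeds the CPU RNG through the table instead of calling torch.manual_seed directly.
dispatch_to_backend("cpu", BACKEND_MANUAL_SEED, 0)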
@@ -37,60 +87,6 @@ except (ImportError, ModuleNotFoundError):
         return cls
 
 
-# Behaviour flags
-BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
-# Function definitions
-BACKEND_EMPTY_CACHE = {
-    "cuda": torch.cuda.empty_cache,
-    "xpu": torch.xpu.empty_cache,
-    "cpu": None,
-    "mps": torch.mps.empty_cache,
-    "default": None,
-}
-BACKEND_DEVICE_COUNT = {
-    "cuda": torch.cuda.device_count,
-    "xpu": torch.xpu.device_count,
-    "cpu": lambda: 0,
-    "mps": lambda: 0,
-    "default": 0,
-}
-BACKEND_MANUAL_SEED = {
-    "cuda": torch.cuda.manual_seed,
-    "xpu": torch.xpu.manual_seed,
-    "cpu": torch.manual_seed,
-    "mps": torch.mps.manual_seed,
-    "default": torch.manual_seed,
-}
-BACKEND_RESET_PEAK_MEMORY_STATS = {
-    "cuda": torch.cuda.reset_peak_memory_stats,
-    "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
-    "cpu": None,
-    "mps": None,
-    "default": None,
-}
-BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
-    "cuda": torch.cuda.reset_max_memory_allocated,
-    "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
-    "cpu": None,
-    "mps": None,
-    "default": None,
-}
-BACKEND_MAX_MEMORY_ALLOCATED = {
-    "cuda": torch.cuda.max_memory_allocated,
-    "xpu": getattr(torch.xpu, "max_memory_allocated", None),
-    "cpu": 0,
-    "mps": 0,
-    "default": 0,
-}
-BACKEND_SYNCHRONIZE = {
-    "cuda": torch.cuda.synchronize,
-    "xpu": getattr(torch.xpu, "synchronize", None),
-    "cpu": None,
-    "mps": None,
-    "default": None,
-}
-
-
 # This dispatches a defined function according to the accelerator from the function definitions.
 def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
     if device not in dispatch_table:
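Several of the `xpu` entries are written as `getattr(torch.xpu, name, None)` rather than as direct attribute access, so the tables still build on torch versions whose `torch.xpu` module lacks the newer memory-stats functions. A minimal illustration of that defensive pattern, using `torch.cuda` here purely as the example namespace:

import torch

# getattr with a None default tolerates torch builds where the attribute is
# missing; the dispatch layer then treats the entry like any other None.
reset_peak = getattr(torch.cuda, "reset_peak_memory_stats", None)
if reset_peak is not None and torch.cuda.is_available():
    reset_peak()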
@@ -334,4 +330,5 @@ def disable_full_determinism():
     torch.use_deterministic_algorithms(False)
 
 
-torch_device = get_device()
+if is_torch_available():
+    torch_device = get_device()
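The last hunk guards the module-level `torch_device = get_device()` assignment, matching the move of the backend tables into the `if is_torch_available():` block: after this change the module imports cleanly in an environment without torch, and the torch-dependent names are simply left undefined. A minimal sketch of the same guard pattern, with a stand-in availability check (the real `is_torch_available` and `get_device` live in diffusers' utils):

import importlib.util

def is_torch_available() -> bool:
    # Stand-in for diffusers' helper: true when torch can be imported.
    return importlib.util.find_spec("torch") is not None

if is_torch_available():
    import torch

    # Hypothetical stand-in for get_device(): prefer CUDA when present.
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"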