start implementing disk offloading in group.
go.diff (new file, 218 lines added)
@@ -0,0 +1,218 @@
diff --git a/diffusers/hooks/offload.py b/diffusers/hooks/offload.py
--- a/diffusers/hooks/offload.py
+++ b/diffusers/hooks/offload.py
@@ -1,6 +1,10 @@
 import os
-import torch
+import torch
+from safetensors.torch import save_file, load_file
+
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 from torch import nn
 from .module_group import ModuleGroup
@@ -25,6 +29,32 @@ from .hooks import HookRegistry
 from .hooks import GroupOffloadingHook, LazyPrefetchGroupOffloadingHook


+# -------------------------------------------------------------------------------
+# Helpers for disk/NVMe offload using safetensors
+# -------------------------------------------------------------------------------
+def _offload_tensor_to_disk_st(tensor: torch.Tensor, path: str) -> None:
+    """
+    Serialize a tensor to disk in safetensors format, staging it through a CPU copy.
+    """
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    cpu_t = tensor.detach().cpu()
+    save_file({"0": cpu_t}, path)
+    # drop the local reference; the caller is responsible for releasing the original tensor
+    del tensor
+
+
+def _load_tensor_from_disk_st(
+    path: str, device: torch.device, non_blocking: bool
+) -> torch.Tensor:
+    """
+    Load a tensor back with safetensors.
+    - If non_blocking on CUDA: load into pinned CPU memory, then `.to(device, non_blocking=True)`
+      so the host-to-device copy can overlap with compute.
+    - Otherwise: load directly onto the target device.
+    """
+    # fast path: load directly onto the target device
+    if not (non_blocking and device.type == "cuda"):
+        data = load_file(path, device=str(device))
+        return data["0"]
+    # pinned-CPU staging for a genuinely non-blocking transfer
+    data = load_file(path, device="cpu")
+    cpu_t = data["0"].pin_memory()
+    return cpu_t.to(device, non_blocking=True)
+
+
 def apply_group_offloading(
     module: torch.nn.Module,
     onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
-    offload_type: str = "block_level",
+    offload_device: torch.device = torch.device("cpu"),
+    *,
+    offload_to_disk: bool = False,
+    offload_path: Optional[str] = None,
+    offload_type: str = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
@@ -37,6 +67,10 @@ def apply_group_offloading(
     Example:
     ```python
     >>> apply_group_offloading(...)
+    # to store params on NVMe:
+    >>> apply_group_offloading(
+    ...     model,
+    ...     onload_device=torch.device("cuda"),
+    ...     offload_to_disk=True,
+    ...     offload_path="/mnt/nvme1/offload",
+    ...     offload_type="block_level",
+    ...     num_blocks_per_group=1,
+    ... )
     ```
     """

@@ -69,6 +103,10 @@ def apply_group_offloading(
         if num_blocks_per_group is None:
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
+        if offload_to_disk and offload_path is None:
+            raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")

         _apply_group_offloading_block_level(
             module=module,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             num_blocks_per_group=num_blocks_per_group,
             offload_device=offload_device,
             onload_device=onload_device,
@@ -79,6 +117,11 @@ def apply_group_offloading(
     elif offload_type == "leaf_level":
+        if offload_to_disk and offload_path is None:
+            raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
         _apply_group_offloading_leaf_level(
             module=module,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             offload_device=offload_device,
             onload_device=onload_device,
             non_blocking=non_blocking,
@@ -107,10 +150,16 @@ def _apply_group_offloading_block_level(
     """
-    module: torch.nn.Module,
-    num_blocks_per_group: int,
-    offload_device: torch.device,
-    onload_device: torch.device,
+    module: torch.nn.Module,
+    num_blocks_per_group: int,
+    offload_device: torch.device,
+    offload_to_disk: bool,
+    offload_path: Optional[str],
+    onload_device: torch.device,
     non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
@@ -138,7 +187,9 @@ def _apply_group_offloading_block_level(
     for i in range(0, len(submodule), num_blocks_per_group):
         current_modules = submodule[i : i + num_blocks_per_group]
         group = ModuleGroup(
-            modules=current_modules,
+            modules=current_modules,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             offload_device=offload_device,
             onload_device=onload_device,
             offload_leader=current_modules[-1],
@@ -187,10 +238,14 @@ def _apply_group_offloading_block_level(
     unmatched_group = ModuleGroup(
         modules=unmatched_modules,
-        offload_device=offload_device,
+        offload_to_disk=offload_to_disk,
+        offload_path=offload_path,
+        offload_device=offload_device,
         onload_device=onload_device,
         offload_leader=module,
         onload_leader=module,
+        # other args omitted for brevity...
     )

     if stream is None:
@@ -216,10 +271,16 @@ def _apply_group_offloading_leaf_level(
     """
-    module: torch.nn.Module,
-    offload_device: torch.device,
-    onload_device: torch.device,
-    non_blocking: bool,
+    module: torch.nn.Module,
+    offload_device: torch.device,
+    offload_to_disk: bool,
+    offload_path: Optional[str],
+    onload_device: torch.device,
+    non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
@@ -229,7 +290,9 @@ def _apply_group_offloading_leaf_level(
     for name, submodule in module.named_modules():
         if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
             continue
-        group = ModuleGroup(
+        group = ModuleGroup(
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             modules=[submodule],
             offload_device=offload_device,
             onload_device=onload_device,
@@ -317,10 +380,14 @@ def _apply_group_offloading_leaf_level(
         parent_module = module_dict[name]
         assert getattr(parent_module, "_diffusers_hook", None) is None
-        group = ModuleGroup(
+        group = ModuleGroup(
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             modules=[],
             offload_device=offload_device,
             onload_device=onload_device,
+            # additional args omitted for brevity...
         )
         _apply_group_offloading_hook(parent_module, group, None)

@@ -360,6 +427,38 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


+# -------------------------------------------------------------------------------
+# Patch GroupOffloadingHook to use safetensors disk offload
+# -------------------------------------------------------------------------------
+class GroupOffloadingHook:
+    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup]):
+        self.group = group
+        self.next_group = next_group
+        # map param/buffer name -> file path
+        self.param_to_path: Dict[str, str] = {}
+        self.buffer_to_path: Dict[str, str] = {}
+
+    def offload_parameters(self, module: nn.Module):
+        for name, param in module.named_parameters(recurse=False):
+            if self.group.offload_to_disk:
+                path = os.path.join(self.group.offload_path, f"{module.__class__.__name__}__{name}.safetensors")
+                _offload_tensor_to_disk_st(param.data, path)
+                self.param_to_path[name] = path
+            else:
+                param.data = param.data.to(self.group.offload_device, non_blocking=self.group.non_blocking)
+
+    def onload_parameters(self, module: nn.Module):
+        for name, param in module.named_parameters(recurse=False):
+            if self.group.offload_to_disk:
+                path = self.param_to_path[name]
+                param.data = _load_tensor_from_disk_st(path, self.group.onload_device, self.group.non_blocking)
+            else:
+                param.data = param.data.to(self.group.onload_device, non_blocking=self.group.non_blocking)
+
+    # analogous changes for buffers...
+
@@ -14,9 +14,10 @@

from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple, Union
import os

import torch

import safetensors.torch
from ..utils import get_logger, is_accelerate_available
from .hooks import HookRegistry, ModelHook
@@ -59,6 +60,8 @@ class ModuleGroup:
        record_stream: Optional[bool] = False,
        low_cpu_mem_usage: bool = False,
        onload_self: bool = True,
        offload_to_disk: bool = False,
        offload_path: Optional[str] = None,
    ) -> None:
        self.modules = modules
        self.offload_device = offload_device
@@ -72,7 +75,29 @@ class ModuleGroup:
        self.record_stream = record_stream
        self.onload_self = onload_self
        self.low_cpu_mem_usage = low_cpu_mem_usage

        self.offload_to_disk = offload_to_disk
        self.offload_path = offload_path
        self._is_offloaded_to_disk = False

        if self.offload_to_disk:
            if self.offload_path is None:
                raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
            self.safetensors_file_path = os.path.join(self.offload_path, f"group_{id(self)}.safetensors")

            all_tensors = []
            for module in self.modules:
                all_tensors.extend(list(module.parameters()))
                all_tensors.extend(list(module.buffers()))
            all_tensors.extend(self.parameters)
            all_tensors.extend(self.buffers)
            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates

            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
            self.cpu_param_dict = {}
        else:
            self.cpu_param_dict = self._init_cpu_param_dict()

        if self.stream is None and self.record_stream:
            raise ValueError("`record_stream` cannot be True when `stream` is None.")
@@ -124,6 +149,29 @@ class ModuleGroup:
        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

        if self.offload_to_disk:
            if self.stream is not None:
                # Wait for previous Host->Device transfer to complete
                self.stream.synchronize()

            with context:
                if self.stream is not None:
                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
                    for key, tensor_obj in self.key_to_tensor.items():
                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
                        if self.record_stream:
                            tensor_obj.data.record_stream(current_stream)
                else:
                    # Load directly to the target device (synchronous)
                    loaded_tensors = safetensors.torch.load_file(
                        self.safetensors_file_path, device=self.onload_device
                    )
                    for key, tensor_obj in self.key_to_tensor.items():
                        tensor_obj.data = loaded_tensors[key]
            return

        if self.stream is not None:
            # Wait for previous Host->Device transfer to complete
            self.stream.synchronize()
@@ -169,6 +217,18 @@ class ModuleGroup:
    @torch.compiler.disable()
    def offload_(self):
        r"""Offloads the group of modules to the offload_device."""
        if self.offload_to_disk:
            if not self._is_offloaded_to_disk:
                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
                tensors_to_save = {
                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
                }
                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
                self._is_offloaded_to_disk = True

            for tensor_obj in self.tensor_to_key.keys():
                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
            return

        torch_accelerator_module = (
            getattr(torch, torch.accelerator.current_accelerator().type)
@@ -208,10 +268,13 @@ class GroupOffloadingHook(ModelHook):
    def __init__(
        self,
        group: ModuleGroup,
        next_group: Optional[ModuleGroup] = None,
    ) -> None:
        self.group = group
        self.next_group = next_group
        # map param/buffer name -> file path
        self.param_to_path: Dict[str, str] = {}
        self.buffer_to_path: Dict[str, str] = {}

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        if self.group.offload_leader == module:
@@ -358,6 +421,8 @@ def apply_group_offloading(
    onload_device: torch.device,
    offload_device: torch.device = torch.device("cpu"),
    offload_type: str = "block_level",
    offload_to_disk: bool = False,
    offload_path: Optional[str] = None,
    num_blocks_per_group: Optional[int] = None,
    non_blocking: bool = False,
    use_stream: bool = False,
@@ -401,6 +466,11 @@ def apply_group_offloading(
        offload_type (`str`, defaults to "block_level"):
            The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
            "block_level".
        offload_to_disk (`bool`, defaults to `False`):
            If `True`, offload model parameters to disk instead of CPU RAM. This is useful when CPU memory is limited.
            Requires `offload_path` to be set.
        offload_path (`str`, *optional*):
            The path to the directory where offloaded parameters will be stored when `offload_to_disk` is `True`.
        num_blocks_per_group (`int`, *optional*):
            The number of blocks per group when using offload_type="block_level". This is required when using
            offload_type="block_level".
@@ -447,6 +517,9 @@ def apply_group_offloading(
    else:
        raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")

    if offload_to_disk and offload_path is None:
        raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")

    _raise_error_if_accelerate_model_or_sequential_hook_present(module)

    if offload_type == "block_level":
@@ -458,6 +531,8 @@ def apply_group_offloading(
            num_blocks_per_group=num_blocks_per_group,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk=offload_to_disk,
            offload_path=offload_path,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
@@ -468,6 +543,8 @@ def apply_group_offloading(
            module=module,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk=offload_to_disk,
            offload_path=offload_path,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
@@ -481,6 +558,8 @@ def _apply_group_offloading_block_level(
    module: torch.nn.Module,
    num_blocks_per_group: int,
    offload_device: torch.device,
    offload_to_disk: bool,
    offload_path: Optional[str],
    onload_device: torch.device,
    non_blocking: bool,
    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
@@ -535,6 +614,8 @@ def _apply_group_offloading_block_level(
            modules=current_modules,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk=offload_to_disk,
            offload_path=offload_path,
            offload_leader=current_modules[-1],
            onload_leader=current_modules[0],
            non_blocking=non_blocking,
@@ -567,6 +648,8 @@ def _apply_group_offloading_block_level(
            modules=unmatched_modules,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk=offload_to_disk,
            offload_path=offload_path,
            offload_leader=module,
            onload_leader=module,
            parameters=parameters,
@@ -586,6 +669,8 @@ def _apply_group_offloading_leaf_level(
    module: torch.nn.Module,
    offload_device: torch.device,
    onload_device: torch.device,
    offload_to_disk: bool,
    offload_path: Optional[str],
    non_blocking: bool,
    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
@@ -629,6 +714,8 @@ def _apply_group_offloading_leaf_level(
            modules=[submodule],
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk=offload_to_disk,
            offload_path=offload_path,
            offload_leader=submodule,
            onload_leader=submodule,
            non_blocking=non_blocking,
@@ -675,6 +762,8 @@ def _apply_group_offloading_leaf_level(
            onload_device=onload_device,
            offload_leader=parent_module,
            onload_leader=parent_module,
            offload_to_disk=offload_to_disk,
            offload_path=offload_path,
            parameters=parameters,
            buffers=buffers,
            non_blocking=non_blocking,
@@ -693,6 +782,8 @@ def _apply_group_offloading_leaf_level(
            modules=[],
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk=offload_to_disk,
            offload_path=offload_path,
            offload_leader=module,
            onload_leader=module,
            parameters=None,
@@ -808,4 +899,4 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
    for submodule in module.modules():
        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
            return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
    raise ValueError("Group offloading is not enabled for the provided module.")
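
For reference, the new flags are intended to be used roughly as follows once this lands. This is a sketch based on the docstring example in the patch; the tiny stand-in model and the NVMe directory are placeholders, and the import path may change.

```python
import torch
from diffusers.hooks import apply_group_offloading  # import path assumed from this patch

# any torch.nn.Module works; a tiny stand-in model keeps the example self-contained
model = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 512),
)

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_to_disk=True,                # new flag from this patch: park weights on disk instead of CPU RAM
    offload_path="/mnt/nvme1/offload",   # placeholder directory for the per-group safetensors files
    offload_type="block_level",
    num_blocks_per_group=1,
)
```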
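
The core mechanism in the patched `ModuleGroup` is a safetensors round trip: on offload, every tensor in a group is written into one file keyed `tensor_<i>`; on onload, the file is read back to CPU, pinned, and copied to the accelerator with `non_blocking=True`. Below is a self-contained sketch of that round trip; the file location, key scheme, and helper names are illustrative, not the hook's actual API.

```python
import os

import torch
from safetensors.torch import load_file, save_file


def offload_group_to_disk(tensors, path):
    """Write every tensor in a group into a single safetensors file on disk."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    to_save = {f"tensor_{i}": t.detach().cpu() for i, t in enumerate(tensors)}
    save_file(to_save, path)


def onload_group_from_disk(path, device, non_blocking=True):
    """Reload the group; pin the CPU copies so the host-to-device copy can run asynchronously."""
    loaded = load_file(path, device="cpu")
    restored = []
    for key in sorted(loaded, key=lambda k: int(k.split("_")[1])):
        t = loaded[key]
        if device.type == "cuda":
            t = t.pin_memory()  # pinned memory lets the non_blocking copy overlap with compute
        restored.append(t.to(device, non_blocking=non_blocking))
    return restored


if __name__ == "__main__":
    group = [torch.randn(4, 4), torch.randn(8)]
    path = "/tmp/offload_demo/group_0.safetensors"  # stand-in for offload_path/group_<id>.safetensors
    offload_group_to_disk(group, path)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print([t.shape for t in onload_group_from_disk(path, device)])
```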