mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
diff --git a/diffusers/hooks/offload.py b/diffusers/hooks/offload.py
--- a/diffusers/hooks/offload.py
+++ b/diffusers/hooks/offload.py
@@ -1,6 +1,7 @@
 import os
 import torch
+from safetensors.torch import save_file, load_file

-from typing import Optional, Union
+from typing import Dict, Optional, Union
 from torch import nn
 from .module_group import ModuleGroup
@@ -25,10 +29,52 @@ from .hooks import HookRegistry
 from .hooks import GroupOffloadingHook, LazyPrefetchGroupOffloadingHook

+# -------------------------------------------------------------------------------
+# Helpers for disk/NVMe offload using safetensors
+# -------------------------------------------------------------------------------
+def _offload_tensor_to_disk_st(tensor: torch.Tensor, path: str) -> None:
+    """
+    Serialize a tensor out to disk in safetensors format.
+
+    The non-blocking onload path re-stages the data through pinned CPU
+    memory at load time (see `_load_tensor_from_disk_st`).
+    """
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    cpu_t = tensor.detach().cpu()
+    save_file({"0": cpu_t}, path)
+    # NOTE: `del` only drops this function's local reference; the caller
+    # must release its own reference (e.g. reassign `param.data`) before
+    # the GPU memory is actually reclaimed.
+    del tensor
+
+
+def _load_tensor_from_disk_st(
+    path: str, device: torch.device, non_blocking: bool
+) -> torch.Tensor:
+    """
+    Load a tensor back in with safetensors.
+
+    - If non_blocking on CUDA: load to CPU, pin, then .to(device, non_blocking=True).
+    - Otherwise: load directly onto the target device.
+    """
+    # Fast path: load straight onto the target device.
+    if not (non_blocking and device.type == "cuda"):
+        data = load_file(path, device=str(device))
+        return data["0"]
+    # Stage through pinned CPU memory so the host-to-device copy can run
+    # asynchronously.
+    data = load_file(path, device="cpu")
+    cpu_t = data["0"].pin_memory()
+    return cpu_t.to(device, non_blocking=True)
+
+
 def apply_group_offloading(
     module: torch.nn.Module,
     onload_device: torch.device,
     offload_device: torch.device = torch.device("cpu"),
+    *,
+    offload_to_disk: bool = False,
+    offload_path: Optional[str] = None,
     offload_type: str = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
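For orientation, here is how the two helpers above round-trip a tensor. This is a minimal sketch, not part of the patch; it assumes a CUDA device, a writable /tmp path, and that the private helpers are importable from the patched `diffusers.hooks.offload` module:

```python
import torch
from diffusers.hooks.offload import _offload_tensor_to_disk_st, _load_tensor_from_disk_st

t = torch.randn(4, 4, device="cuda")
_offload_tensor_to_disk_st(t, "/tmp/offload/block0__weight.safetensors")

# Blocking reload onto CPU:
cpu_t = _load_tensor_from_disk_st(
    "/tmp/offload/block0__weight.safetensors", torch.device("cpu"), non_blocking=False
)
# Non-blocking reload onto the GPU, staged through pinned memory:
gpu_t = _load_tensor_from_disk_st(
    "/tmp/offload/block0__weight.safetensors", torch.device("cuda"), non_blocking=True
)
torch.cuda.synchronize()  # wait for the async copy before using gpu_t
assert torch.equal(cpu_t, gpu_t.cpu())
```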
@@ -37,6 +67,15 @@ def apply_group_offloading(
     Example:
     ```python
     >>> apply_group_offloading(...)
+    >>> # to store params on NVMe:
+    >>> apply_group_offloading(
+    ...     model,
+    ...     onload_device=torch.device("cuda"),
+    ...     offload_to_disk=True,
+    ...     offload_path="/mnt/nvme1/offload",
+    ...     offload_type="block_level",
+    ...     num_blocks_per_group=1,
+    ... )
     ```
     """

@@ -69,8 +103,12 @@ def apply_group_offloading(
     if num_blocks_per_group is None:
         raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
+    if offload_to_disk and offload_path is None:
+        raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")

     _apply_group_offloading_block_level(
         module=module,
+        offload_to_disk=offload_to_disk,
+        offload_path=offload_path,
         num_blocks_per_group=num_blocks_per_group,
         offload_device=offload_device,
         onload_device=onload_device,
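The new validation fails fast when the disk destination is missing; for example (a sketch, assuming a `model` as in the docstring example):

```python
try:
    apply_group_offloading(
        model,
        onload_device=torch.device("cuda"),
        offload_to_disk=True,  # offload_path deliberately omitted
        offload_type="block_level",
        num_blocks_per_group=1,
    )
except ValueError as err:
    print(err)  # `offload_path` must be set when `offload_to_disk=True`.
```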
@@ -79,6 +117,10 @@ def apply_group_offloading(
     elif offload_type == "leaf_level":
+        if offload_to_disk and offload_path is None:
+            raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
         _apply_group_offloading_leaf_level(
             module=module,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             offload_device=offload_device,
             onload_device=onload_device,
             non_blocking=non_blocking,
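The same disk options apply to leaf-level offloading, which hooks every supported leaf layer individually and needs no `num_blocks_per_group` (a sketch; `model` is assumed):

```python
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_to_disk=True,
    offload_path="/mnt/nvme1/offload",
    offload_type="leaf_level",
)
```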
@@ -107,9 +150,11 @@ def _apply_group_offloading_block_level(
     module: torch.nn.Module,
     num_blocks_per_group: int,
     offload_device: torch.device,
+    offload_to_disk: bool,
+    offload_path: Optional[str],
     onload_device: torch.device,
     non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
@@ -138,7 +187,9 @@ def _apply_group_offloading_block_level(
     for i in range(0, len(submodule), num_blocks_per_group):
         current_modules = submodule[i : i + num_blocks_per_group]
         group = ModuleGroup(
             modules=current_modules,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             offload_device=offload_device,
             onload_device=onload_device,
             offload_leader=current_modules[-1],
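The slicing in this hunk partitions consecutive blocks into fixed-size groups; for example:

```python
# num_blocks_per_group=2 over a 4-block ModuleList yields two groups:
blocks = ["block0", "block1", "block2", "block3"]
groups = [blocks[i : i + 2] for i in range(0, len(blocks), 2)]
assert groups == [["block0", "block1"], ["block2", "block3"]]
```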
@@ -187,10 +238,14 @@ def _apply_group_offloading_block_level(
     unmatched_group = ModuleGroup(
         modules=unmatched_modules,
+        offload_to_disk=offload_to_disk,
+        offload_path=offload_path,
         offload_device=offload_device,
         onload_device=onload_device,
         offload_leader=module,
         onload_leader=module,
+        # other args omitted for brevity...
     )

     if stream is None:
@@ -216,8 +271,10 @@ def _apply_group_offloading_leaf_level(
     module: torch.nn.Module,
     offload_device: torch.device,
+    offload_to_disk: bool,
+    offload_path: Optional[str],
     onload_device: torch.device,
     non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
@@ -229,7 +290,9 @@ def _apply_group_offloading_leaf_level(
     for name, submodule in module.named_modules():
         if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
             continue
         group = ModuleGroup(
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             modules=[submodule],
             offload_device=offload_device,
             onload_device=onload_device,
@@ -317,10 +380,14 @@ def _apply_group_offloading_leaf_level(
     parent_module = module_dict[name]
     assert getattr(parent_module, "_diffusers_hook", None) is None
     group = ModuleGroup(
+        offload_to_disk=offload_to_disk,
+        offload_path=offload_path,
         modules=[],
         offload_device=offload_device,
         onload_device=onload_device,
+        # additional args omitted for brevity...
     )
     _apply_group_offloading_hook(parent_module, group, None)

@@ -360,6 +427,38 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


+# -------------------------------------------------------------------------------
+# Patch GroupOffloadingHook to use safetensors disk offload
+# (NOTE: this redefinition shadows the GroupOffloadingHook imported from
+# .hooks above; a final version of this patch should modify that class in
+# place instead.)
+# -------------------------------------------------------------------------------
+class GroupOffloadingHook:
+    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup]):
+        self.group = group
+        self.next_group = next_group
+        # map param/buffer name -> file path
+        self.param_to_path: Dict[str, str] = {}
+        self.buffer_to_path: Dict[str, str] = {}
+
+    def offload_parameters(self, module: nn.Module):
+        for name, param in module.named_parameters(recurse=False):
+            if self.group.offload_to_disk:
+                # NOTE: class-name based filenames can collide when a group
+                # holds several modules of the same class; a unique per-module
+                # prefix would be safer.
+                path = os.path.join(self.group.offload_path, f"{module.__class__.__name__}__{name}.safetensors")
+                _offload_tensor_to_disk_st(param.data, path)
+                self.param_to_path[name] = path
+                # Drop the on-device copy so its memory is actually reclaimed.
+                param.data = torch.empty(0, dtype=param.dtype)
+            else:
+                param.data = param.data.to(self.group.offload_device, non_blocking=self.group.non_blocking)
+
+    def onload_parameters(self, module: nn.Module):
+        for name, param in module.named_parameters(recurse=False):
+            if self.group.offload_to_disk:
+                path = self.param_to_path[name]
+                param.data = _load_tensor_from_disk_st(path, self.group.onload_device, self.group.non_blocking)
+            else:
+                param.data = param.data.to(self.group.onload_device, non_blocking=self.group.non_blocking)
+
+    # analogous changes for buffers...
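The buffer handling elided above would presumably mirror the parameter path one-for-one. A hedged sketch (method and attribute names follow the class above; none of this is in the patch):

```python
    def offload_buffers(self, module: nn.Module):
        for name, buffer in module.named_buffers(recurse=False):
            if self.group.offload_to_disk:
                # hypothetical "__buf__" infix to keep buffer files distinct
                path = os.path.join(self.group.offload_path, f"{module.__class__.__name__}__buf__{name}.safetensors")
                _offload_tensor_to_disk_st(buffer.data, path)
                self.buffer_to_path[name] = path
                buffer.data = torch.empty(0, dtype=buffer.dtype)
            else:
                buffer.data = buffer.data.to(self.group.offload_device, non_blocking=self.group.non_blocking)

    def onload_buffers(self, module: nn.Module):
        for name, buffer in module.named_buffers(recurse=False):
            if self.group.offload_to_disk:
                buffer.data = _load_tensor_from_disk_st(self.buffer_to_path[name], self.group.onload_device, self.group.non_blocking)
            else:
                buffer.data = buffer.data.to(self.group.onload_device, non_blocking=self.group.non_blocking)
```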