From e0d5079f9cd5229c35c4780031bfb1335eaf0226 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 9 Jun 2025 10:25:45 +0530
Subject: [PATCH] start implementing disk offloading in group.

---
 go.diff                                 | 218 ++++++++++++++++++++++++
 src/diffusers/hooks/group_offloading.py |  99 ++++++++++-
 2 files changed, 313 insertions(+), 4 deletions(-)
 create mode 100644 go.diff

diff --git a/go.diff b/go.diff
new file mode 100644
index 0000000000..7640b6b78c
--- /dev/null
+++ b/go.diff
@@ -0,0 +1,218 @@
+diff --git a/diffusers/hooks/offload.py b/diffusers/hooks/offload.py
+--- a/diffusers/hooks/offload.py
++++ b/diffusers/hooks/offload.py
+@@ -1,6 +1,10 @@
+ import os
+-import torch
++import torch
++from safetensors.torch import save_file, load_file
+
++import os
+ from typing import Optional, Union
+ from torch import nn
+ from .module_group import ModuleGroup
+@@ -25,6 +29,32 @@ from .hooks import HookRegistry
+ from .hooks import GroupOffloadingHook, LazyPrefetchGroupOffloadingHook
++# -------------------------------------------------------------------------------
++# Helpers for disk/NVMe offload using safetensors
++# -------------------------------------------------------------------------------
++def _offload_tensor_to_disk_st(tensor: torch.Tensor, path: str) -> None:
++    """
++    Serialize a tensor out to disk in safetensors format.
++    We pin the CPU copy so that non_blocking loads can overlap copy/compute.
++    """
++    os.makedirs(os.path.dirname(path), exist_ok=True)
++    cpu_t = tensor.detach().cpu().pin_memory()
++    save_file({"0": cpu_t}, path)
++    # free the original GPU tensor immediately
++    del tensor
++
++def _load_tensor_from_disk_st(
++    path: str, device: torch.device, non_blocking: bool
++) -> torch.Tensor:
++    """
++    Load a tensor back in with safetensors.
++    - If non_blocking on CUDA: load to CPU pinned memory, then .to(cuda, non_blocking=True).
++    - Otherwise: direct load_file(device=...).
++    """
++    # fast path: direct to target device
++    if not (non_blocking and device.type == "cuda"):
++        data = load_file(path, device=device)
++        return data["0"]
++    # pinned-CPU fallback for true non-blocking
++    data = load_file(path, device="cpu")
++    cpu_t = data["0"]
++    return cpu_t.to(device, non_blocking=True)
++
++
+ def apply_group_offloading(
+     module: torch.nn.Module,
+     onload_device: torch.device,
+-    offload_device: torch.device = torch.device("cpu"),
+-    offload_type: str = "block_level",
++    offload_device: torch.device = torch.device("cpu"),
++    *,
++    offload_to_disk: bool = False,
++    offload_path: Optional[str] = None,
++    offload_type: str = "block_level",
+     num_blocks_per_group: Optional[int] = None,
+     non_blocking: bool = False,
+     use_stream: bool = False,
+@@ -37,6 +67,10 @@ def apply_group_offloading(
+     Example:
+     ```python
+     >>> apply_group_offloading(... )
++    # to store params on NVMe:
++    >>> apply_group_offloading(
++    ...     model,
++    ...     onload_device=torch.device("cuda"),
++    ...     offload_to_disk=True,
++    ...     offload_path="/mnt/nvme1/offload",
++    ...     offload_type="block_level",
++    ...     num_blocks_per_group=1,
++    ... )
+     ```
+     """
+
+@@ -69,6 +103,10 @@ def apply_group_offloading(
+         if num_blocks_per_group is None:
+             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
++        if offload_to_disk and offload_path is None:
++            raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
+
+         _apply_group_offloading_block_level(
+             module=module,
++            offload_to_disk=offload_to_disk,
++            offload_path=offload_path,
+             num_blocks_per_group=num_blocks_per_group,
+             offload_device=offload_device,
+             onload_device=onload_device,
+@@ -79,6 +117,11 @@ def apply_group_offloading(
+     elif offload_type == "leaf_level":
++        if offload_to_disk and offload_path is None:
++            raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
+         _apply_group_offloading_leaf_level(
+             module=module,
++            offload_to_disk=offload_to_disk,
++            offload_path=offload_path,
+             offload_device=offload_device,
+             onload_device=onload_device,
+             non_blocking=non_blocking,
+@@ -107,10 +150,16 @@ def _apply_group_offloading_block_level(
+     """
+-    module: torch.nn.Module,
+-    num_blocks_per_group: int,
+-    offload_device: torch.device,
+-    onload_device: torch.device,
++    module: torch.nn.Module,
++    num_blocks_per_group: int,
++    offload_device: torch.device,
++    offload_to_disk: bool,
++    offload_path: Optional[str],
++    onload_device: torch.device,
+     non_blocking: bool,
+     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
+     record_stream: Optional[bool] = False,
+     low_cpu_mem_usage: bool = False,
+ ) -> None:
+@@ -138,7 +187,9 @@ def _apply_group_offloading_block_level(
+     for i in range(0, len(submodule), num_blocks_per_group):
+         current_modules = submodule[i : i + num_blocks_per_group]
+         group = ModuleGroup(
+-            modules=current_modules,
++            modules=current_modules,
++            offload_to_disk=offload_to_disk,
++            offload_path=offload_path,
+             offload_device=offload_device,
+             onload_device=onload_device,
+             offload_leader=current_modules[-1],
+@@ -187,10 +238,14 @@ def _apply_group_offloading_block_level(
+     unmatched_group = ModuleGroup(
+         modules=unmatched_modules,
+-        offload_device=offload_device,
++        offload_to_disk=offload_to_disk,
++        offload_path=offload_path,
++        offload_device=offload_device,
+         onload_device=onload_device,
+         offload_leader=module,
+         onload_leader=module,
++        # other args omitted for brevity...
+     )
+
+     if stream is None:
+@@ -216,10 +271,16 @@ def _apply_group_offloading_leaf_level(
+     """
+-    module: torch.nn.Module,
+-    offload_device: torch.device,
+-    onload_device: torch.device,
+-    non_blocking: bool,
++    module: torch.nn.Module,
++    offload_device: torch.device,
++    offload_to_disk: bool,
++    offload_path: Optional[str],
++    onload_device: torch.device,
++    non_blocking: bool,
+     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
+     record_stream: Optional[bool] = False,
+     low_cpu_mem_usage: bool = False,
+ ) -> None:
+@@ -229,7 +290,9 @@ def _apply_group_offloading_leaf_level(
+     for name, submodule in module.named_modules():
+         if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+             continue
+-        group = ModuleGroup(
++        group = ModuleGroup(
++            offload_to_disk=offload_to_disk,
++            offload_path=offload_path,
+             modules=[submodule],
+             offload_device=offload_device,
+             onload_device=onload_device,
+@@ -317,10 +380,14 @@ def _apply_group_offloading_leaf_level(
+         parent_module = module_dict[name]
+         assert getattr(parent_module, "_diffusers_hook", None) is None
+-        group = ModuleGroup(
++        group = ModuleGroup(
++            offload_to_disk=offload_to_disk,
++            offload_path=offload_path,
+             modules=[],
+             offload_device=offload_device,
+             onload_device=onload_device,
++            # additional args omitted for brevity...
+         )
+         _apply_group_offloading_hook(parent_module, group, None)
+
+@@ -360,6 +427,38 @@ def _apply_lazy_group_offloading_hook(
+     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
+
+
++# -------------------------------------------------------------------------------
++# Patch GroupOffloadingHook to use safetensors disk offload
++# -------------------------------------------------------------------------------
++class GroupOffloadingHook:
++    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup]):
++        self.group = group
++        self.next_group = next_group
++        # map param/buffer name -> file path
++        self.param_to_path: Dict[str,str] = {}
++        self.buffer_to_path: Dict[str,str] = {}
++
++    def offload_parameters(self, module: nn.Module):
++        for name, param in module.named_parameters(recurse=False):
++            if self.group.offload_to_disk:
++                path = os.path.join(self.group.offload_path, f"{module.__class__.__name__}__{name}.safetensors")
++                _offload_tensor_to_disk_st(param.data, path)
++                self.param_to_path[name] = path
++            else:
++                param.data = param.data.to(self.group.offload_device, non_blocking=self.group.non_blocking)
++
++    def onload_parameters(self, module: nn.Module):
++        for name, param in module.named_parameters(recurse=False):
++            if self.group.offload_to_disk:
++                path = self.param_to_path[name]
++                param.data = _load_tensor_from_disk_st(path, self.group.onload_device, self.group.non_blocking)
++            else:
++                param.data = param.data.to(self.group.onload_device, non_blocking=self.group.non_blocking)
++
++    # analogous changes for buffers...
++
diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 565f8f1ff8..a31acb5a2d 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -14,9 +14,10 @@
 from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple, Union
+import os
 import torch
-
+import safetensors.torch
 from ..utils import get_logger, is_accelerate_available
 from .hooks import HookRegistry, ModelHook
@@ -59,6 +60,8 @@ class ModuleGroup:
         record_stream: Optional[bool] = False,
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
+        offload_to_disk: bool = False,
+        offload_path: Optional[str] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -72,7 +75,29 @@ class ModuleGroup:
         self.record_stream = record_stream
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
-        self.cpu_param_dict = self._init_cpu_param_dict()
+
+        self.offload_to_disk = offload_to_disk
+        self.offload_path = offload_path
+        self._is_offloaded_to_disk = False
+
+        if self.offload_to_disk:
+            if self.offload_path is None:
+                raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
+            self.safetensors_file_path = os.path.join(self.offload_path, f"group_{id(self)}.safetensors")
+
+            all_tensors = []
+            for module in self.modules:
+                all_tensors.extend(list(module.parameters()))
+                all_tensors.extend(list(module.buffers()))
+            all_tensors.extend(self.parameters)
+            all_tensors.extend(self.buffers)
+            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates
+
+            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
+            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
+            self.cpu_param_dict = {}
+        else:
+            self.cpu_param_dict = self._init_cpu_param_dict()

         if self.stream is None and self.record_stream:
             raise ValueError("`record_stream` cannot be True when `stream` is None.")
@@ -124,6 +149,29 @@ class ModuleGroup:
         context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
         current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

+        if self.offload_to_disk:
+            if self.stream is not None:
+                # Wait for previous Host->Device transfer to complete
+                self.stream.synchronize()
+
+            with context:
+                if self.stream is not None:
+                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
+                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+                    for key, tensor_obj in self.key_to_tensor.items():
+                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
+                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                        if self.record_stream:
+                            tensor_obj.data.record_stream(current_stream)
+                else:
+                    # Load directly to the target device (synchronous)
+                    loaded_tensors = safetensors.torch.load_file(
+                        self.safetensors_file_path, device=self.onload_device
+                    )
+                    for key, tensor_obj in self.key_to_tensor.items():
+                        tensor_obj.data = loaded_tensors[key]
+            return
+
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
@@ -169,6 +217,18 @@ class ModuleGroup:
     @torch.compiler.disable()
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk:
+            if not self._is_offloaded_to_disk:
+                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+                tensors_to_save = {
+                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
+                }
+                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+                self._is_offloaded_to_disk = True
+
+            for tensor_obj in self.tensor_to_key.keys():
+                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+            return

         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
@@ -208,10 +268,13 @@ class GroupOffloadingHook:
     def __init__(
         self,
         group: ModuleGroup,
-        next_group: Optional[ModuleGroup] = None,
+        next_group: Optional[ModuleGroup] = None
     ) -> None:
         self.group = group
         self.next_group = next_group
+        # map param/buffer name -> file path
+        self.param_to_path: Dict[str,str] = {}
+        self.buffer_to_path: Dict[str,str] = {}

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
@@ -358,6 +421,8 @@ def apply_group_offloading(
     onload_device: torch.device,
     offload_device: torch.device = torch.device("cpu"),
     offload_type: str = "block_level",
+    offload_to_disk: bool = False,
+    offload_path: Optional[str] = None,
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
@@ -401,6 +466,11 @@ def apply_group_offloading(
         offload_type (`str`, defaults to "block_level"):
             The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
             "block_level".
+        offload_to_disk (`bool`, defaults to `False`):
+            If `True`, offload model parameters to disk instead of CPU RAM. This is useful when CPU memory is limited.
+            Requires `offload_path` to be set.
+        offload_path (`str`, *optional*):
+            The path to the directory where offloaded parameters will be stored when `offload_to_disk` is `True`.
         num_blocks_per_group (`int`, *optional*):
             The number of blocks per group when using offload_type="block_level". This is required when using
             offload_type="block_level".
@@ -447,6 +517,9 @@ def apply_group_offloading(
     else:
         raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")

+    if offload_to_disk and offload_path is None:
+        raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
+
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)

     if offload_type == "block_level":
@@ -458,6 +531,8 @@ def apply_group_offloading(
             num_blocks_per_group=num_blocks_per_group,
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             non_blocking=non_blocking,
             stream=stream,
             record_stream=record_stream,
@@ -468,6 +543,8 @@ def apply_group_offloading(
             module=module,
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             non_blocking=non_blocking,
             stream=stream,
             record_stream=record_stream,
@@ -481,6 +558,8 @@ def _apply_group_offloading_block_level(
     module: torch.nn.Module,
     num_blocks_per_group: int,
     offload_device: torch.device,
+    offload_to_disk: bool,
+    offload_path: Optional[str],
     onload_device: torch.device,
     non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
@@ -535,6 +614,8 @@ def _apply_group_offloading_block_level(
                 modules=current_modules,
                 offload_device=offload_device,
                 onload_device=onload_device,
+                offload_to_disk=offload_to_disk,
+                offload_path=offload_path,
                 offload_leader=current_modules[-1],
                 onload_leader=current_modules[0],
                 non_blocking=non_blocking,
@@ -567,6 +648,8 @@ def _apply_group_offloading_block_level(
         modules=unmatched_modules,
         offload_device=offload_device,
         onload_device=onload_device,
+        offload_to_disk=offload_to_disk,
+        offload_path=offload_path,
         offload_leader=module,
         onload_leader=module,
         parameters=parameters,
@@ -586,6 +669,8 @@ def _apply_group_offloading_leaf_level(
     module: torch.nn.Module,
     offload_device: torch.device,
     onload_device: torch.device,
+    offload_to_disk: bool,
+    offload_path: Optional[str],
     non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
@@ -629,6 +714,8 @@ def _apply_group_offloading_leaf_level(
             modules=[submodule],
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             offload_leader=submodule,
             onload_leader=submodule,
             non_blocking=non_blocking,
@@ -675,6 +762,8 @@ def _apply_group_offloading_leaf_level(
             onload_device=onload_device,
             offload_leader=parent_module,
             onload_leader=parent_module,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             parameters=parameters,
             buffers=buffers,
             non_blocking=non_blocking,
@@ -693,6 +782,8 @@ def _apply_group_offloading_leaf_level(
         modules=[],
         offload_device=offload_device,
         onload_device=onload_device,
+        offload_to_disk=offload_to_disk,
+        offload_path=offload_path,
         offload_leader=module,
         onload_leader=module,
         parameters=None,
@@ -808,4 +899,4 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
     for submodule in module.modules():
         if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
             return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
-    raise ValueError("Group offloading is not enabled for the provided module.")
+    raise ValueError("Group offloading is not enabled for the provided module.")
\ No newline at end of file
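For reference, a minimal usage sketch of the `offload_to_disk` / `offload_path` knobs this patch adds. It is not part of the patch: the `diffusers.hooks` import path, the `TinyModel` stand-in, the CUDA device, and the NVMe directory are illustrative assumptions.

```python
# Sketch only: assumes apply_group_offloading is importable from diffusers.hooks,
# a CUDA device is available, and /mnt/nvme1/offload is a writable directory.
import torch
from torch import nn

from diffusers.hooks import apply_group_offloading


class TinyModel(nn.Module):
    # Stand-in for a real diffusers transformer/UNet; block_level grouping
    # walks ModuleList/Sequential children such as `blocks`.
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyModel()

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    offload_to_disk=True,                # new flag introduced by this patch
    offload_path="/mnt/nvme1/offload",   # directory for each group's .safetensors file
)

# Each group's weights are loaded back from disk onto the GPU just before use
# and replaced by empty placeholders again afterwards.
with torch.no_grad():
    out = model(torch.randn(1, 64, device="cuda"))
```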