diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index ce6f47f67a..c2adb6ab1d 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -61,8 +61,7 @@ class ModuleGroup:
         record_stream: Optional[bool] = False,
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
-        offload_to_disk: bool = False,
-        offload_path: Optional[str] = None,
+        offload_to_disk_path: Optional[str] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -77,14 +76,11 @@ class ModuleGroup:
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
 
-        self.offload_to_disk = offload_to_disk
-        self.offload_path = offload_path
+        self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False
 
-        if self.offload_to_disk:
-            if self.offload_path is None:
-                raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
-            self.safetensors_file_path = os.path.join(self.offload_path, f"group_{id(self)}.safetensors")
+        if self.offload_to_disk_path:
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
 
         all_tensors = []
         for module in self.modules:
@@ -150,7 +146,7 @@
         context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
         current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
 
-        if self.offload_to_disk:
+        if self.offload_to_disk_path:
             if self.stream is not None:
                 # Wait for previous Host->Device transfer to complete
                 self.stream.synchronize()
@@ -219,7 +215,7 @@
    @torch.compiler.disable()
    def offload_(self):
        r"""Offloads the group of modules to the offload_device."""
-        if self.offload_to_disk:
+        if self.offload_to_disk_path:
            if not self._is_offloaded_to_disk:
                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
                tensors_to_save = {
@@ -419,8 +415,7 @@ def apply_group_offloading(
     onload_device: torch.device,
     offload_device: torch.device = torch.device("cpu"),
     offload_type: str = "block_level",
-    offload_to_disk: bool = False,
-    offload_path: Optional[str] = None,
+    offload_to_disk_path: Optional[str] = None,
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
@@ -464,11 +459,8 @@
         offload_type (`str`, defaults to "block_level"):
             The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
             "block_level".
-        offload_to_disk (`bool`, defaults to `False`):
-            If `True`, offload model parameters to disk instead of CPU RAM. This is useful when CPU memory is limited.
-            Requires `offload_path` to be set.
-        offload_path (`str`, *optional*):
-            The path to the directory where offloaded parameters will be stored when `offload_to_disk` is `True`.
+        offload_to_disk_path (`str`, *optional*):
+            The path to the directory where parameters are offloaded to disk instead of CPU RAM.
         num_blocks_per_group (`int`, *optional*):
             The number of blocks per group when using offload_type="block_level". This is required when using
             offload_type="block_level".
@@ -486,6 +478,8 @@
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
 
+    When `offload_to_disk_path` is set, each offloaded group is saved as its own safetensors file in that directory.
+
     Example:
         ```python
         >>> from diffusers import CogVideoXTransformer3DModel
@@ -514,8 +508,6 @@ def apply_group_offloading(
             stream = torch.Stream()
         else:
             raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
 
-    if offload_to_disk and offload_path is None:
-        raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
 
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
@@ -528,8 +520,7 @@ def apply_group_offloading(
             num_blocks_per_group=num_blocks_per_group,
             offload_device=offload_device,
             onload_device=onload_device,
-            offload_to_disk=offload_to_disk,
-            offload_path=offload_path,
+            offload_to_disk_path=offload_to_disk_path,
             non_blocking=non_blocking,
             stream=stream,
             record_stream=record_stream,
@@ -540,8 +531,7 @@ def apply_group_offloading(
             module=module,
             offload_device=offload_device,
             onload_device=onload_device,
-            offload_to_disk=offload_to_disk,
-            offload_path=offload_path,
+            offload_to_disk_path=offload_to_disk_path,
             non_blocking=non_blocking,
             stream=stream,
             record_stream=record_stream,
@@ -555,8 +545,7 @@ def _apply_group_offloading_block_level(
     module: torch.nn.Module,
     num_blocks_per_group: int,
     offload_device: torch.device,
-    offload_to_disk: bool,
-    offload_path: Optional[str],
+    offload_to_disk_path: Optional[str],
     onload_device: torch.device,
     non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
@@ -572,8 +561,9 @@ def _apply_group_offloading_block_level(
             The module to which group offloading is applied.
         offload_device (`torch.device`):
             The device to which the group of modules are offloaded. This should typically be the CPU.
+        offload_to_disk_path (`str`, *optional*): The path to the directory where parameters are offloaded to disk.
         onload_device (`torch.device`):
             The device to which the group of modules are onloaded.
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
@@ -611,8 +601,7 @@
             modules=current_modules,
             offload_device=offload_device,
             onload_device=onload_device,
-            offload_to_disk=offload_to_disk,
-            offload_path=offload_path,
+            offload_to_disk_path=offload_to_disk_path,
             offload_leader=current_modules[-1],
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
@@ -645,8 +634,7 @@
         modules=unmatched_modules,
         offload_device=offload_device,
         onload_device=onload_device,
-        offload_to_disk=offload_to_disk,
-        offload_path=offload_path,
+        offload_to_disk_path=offload_to_disk_path,
         offload_leader=module,
         onload_leader=module,
         parameters=parameters,
@@ -666,8 +654,7 @@ def _apply_group_offloading_leaf_level(
     module: torch.nn.Module,
     offload_device: torch.device,
     onload_device: torch.device,
-    offload_to_disk: bool,
-    offload_path: Optional[str],
+    offload_to_disk_path: Optional[str],
     non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
@@ -686,6 +673,7 @@
             The device to which the group of modules are offloaded. This should typically be the CPU.
         onload_device (`torch.device`):
             The device to which the group of modules are onloaded.
+        offload_to_disk_path (`str`, *optional*): The path to the directory where parameters are offloaded to disk.
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
@@ -711,8 +699,7 @@ def _apply_group_offloading_leaf_level(
             modules=[submodule],
             offload_device=offload_device,
             onload_device=onload_device,
-            offload_to_disk=offload_to_disk,
-            offload_path=offload_path,
+            offload_to_disk_path=offload_to_disk_path,
             offload_leader=submodule,
             onload_leader=submodule,
             non_blocking=non_blocking,
@@ -759,8 +746,7 @@ def _apply_group_offloading_leaf_level(
             onload_device=onload_device,
             offload_leader=parent_module,
             onload_leader=parent_module,
-            offload_to_disk=offload_to_disk,
-            offload_path=offload_path,
+            offload_to_disk_path=offload_to_disk_path,
             parameters=parameters,
             buffers=buffers,
             non_blocking=non_blocking,
@@ -779,8 +765,7 @@ def _apply_group_offloading_leaf_level(
         modules=[],
         offload_device=offload_device,
         onload_device=onload_device,
-        offload_to_disk=offload_to_disk,
-        offload_path=offload_path,
+        offload_to_disk_path=offload_to_disk_path,
         offload_leader=module,
         onload_leader=module,
         parameters=None,
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index ce57d17ab0..c71a8b3b5a 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -543,8 +543,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         onload_device: torch.device,
         offload_device: torch.device = torch.device("cpu"),
         offload_type: str = "block_level",
-        offload_to_disk: bool = False,
-        offload_path: Optional[str] = None,
+        offload_to_disk_path: Optional[str] = None,
         num_blocks_per_group: Optional[int] = None,
         non_blocking: bool = False,
         use_stream: bool = False,
@@ -599,8 +598,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             use_stream=use_stream,
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            offload_to_disk=offload_to_disk,
-            offload_path=offload_path,
+            offload_to_disk_path=offload_to_disk_path,
         )
 
     def save_pretrained(
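A minimal usage sketch of the API after this rename: passing `offload_to_disk_path` is now what opts into disk offloading, so the old `offload_path`-must-be-set validation removed above is no longer needed. The sketch mirrors the docstring example; the checkpoint and the offload directory below are illustrative, and any `ModelMixin`-based model works the same way.

```python
import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

# Illustrative checkpoint; substitute any diffusers model.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Providing a path enables disk offloading; omitting it keeps parameters in CPU RAM.
# Each offloaded group is written to `<offload_to_disk_path>/group_<id>.safetensors`
# and loaded back on demand during the forward pass.
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    offload_to_disk_path="offload_dir",  # illustrative directory; must be writable
)
```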