
Module Group Offloading (#10503)

* update

* fix

* non_blocking; handle parameters and buffers

* update

* Group offloading with cuda stream prefetching (#10516)

* cuda stream prefetch

* remove breakpoints

* update

* copy model hook implementation from pab

* update; ~very workaround based implementation but it seems to work as expected; needs cleanup and rewrite

* more workarounds to make it actually work

* cleanup

* rewrite

* update

* make sure to sync current stream before overwriting with pinned params

not doing so will lead to erroneous computations on the GPU and cause bad results

* better check

* update

* remove hook implementation to not deal with merge conflict

* re-add hook changes

* why use more memory when less memory do trick

* why still use slightly more memory when less memory do trick

* optimise

* add model tests

* add pipeline tests

* update docs

* add layernorm and groupnorm

* address review comments

* improve tests; add docs

* improve docs

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* apply suggestions from code review

* update tests

* apply suggestions from review

* enable_group_offloading -> enable_group_offload for naming consistency

* raise errors if multiple offloading strategies used; add relevant tests

* handle .to() when group offload applied

* refactor some repeated code

* remove unintentional change from merge conflict

* handle .cuda()

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Authored by: Aryan
Committed by: GitHub on 2025-02-14 12:59:45 +05:30
Parent: ab428207a7
Commit: 9a147b82f7
44 changed files with 1239 additions and 4 deletions

@@ -45,3 +45,7 @@ Utility and helper functions for working with 🤗 Diffusers.
## apply_layerwise_casting
[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting
## apply_group_offloading
[[autodoc]] hooks.group_offloading.apply_group_offloading

@@ -158,6 +158,46 @@ In order to properly offload models after they're called, it is required to run
</Tip>
## Group offloading
Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced.
To enable group offloading, call the [`~ModelMixin.enable_group_offload`] method on the model if it is a Diffusers model implementation. For any other model implementation, use [`~hooks.group_offloading.apply_group_offloading`]:
```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
# Load the pipeline
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
# We can utilize the enable_group_offload method for Diffusers model implementations
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
# For any other model implementations, the apply_group_offloading function can be used
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
# This example peaked at about 14.79 GB. Memory can be reduced further by enabling tiling and applying leaf_level offloading throughout the pipeline.
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)
```
Group offloading (on CUDA devices that support asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. It does this through layer prefetching with CUDA streams: the next layer to be executed is loaded onto the accelerator device while the current layer is still being executed, which slightly increases memory requirements. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading), which can be made much faster when streams are used.
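To build intuition for what the stream-based prefetching does under the hood, here is a minimal, self-contained sketch. It is illustrative only, not the diffusers implementation; the layer sizes and loop structure are assumptions:
```python
import torch

# A dedicated transfer stream copies the next layer's weights to the GPU
# while the default stream computes with the current layer. For the copies
# to be truly asynchronous, the CPU weights must live in pinned memory,
# which the group offloading hooks arrange via a pinned CPU parameter dict.
transfer_stream = torch.cuda.Stream()
layers = [torch.nn.Linear(1024, 1024) for _ in range(4)]
x = torch.randn(8, 1024, device="cuda")

with torch.cuda.stream(transfer_stream):
    layers[0].to("cuda", non_blocking=True)  # prefetch the first layer

for i, layer in enumerate(layers):
    transfer_stream.synchronize()  # make sure this layer's weights have arrived
    if i + 1 < len(layers):
        with torch.cuda.stream(transfer_stream):
            layers[i + 1].to("cuda", non_blocking=True)  # prefetch the next layer
    x = layer(x)  # compute on the default stream overlaps with the prefetch above
    layer.to("cpu")  # offload again (blocking here, for simplicity)
```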
## FP8 layerwise weight-casting
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
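As a rough illustration, and assuming the `enable_layerwise_casting` method on `ModelMixin` whose wiring appears later in this diff (the exact keyword arguments shown here are an assumption), enabling it could look like:
```python
import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Store weights in fp8 to save memory; upcast them to bfloat16 on-the-fly
# in each layer's forward pass.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```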

@@ -2,6 +2,7 @@ from ..utils import is_torch_available
if is_torch_available():
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

@@ -0,0 +1,678 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from typing import Dict, List, Optional, Set, Tuple
import torch
from ..utils import get_logger, is_accelerate_available
from .hooks import HookRegistry, ModelHook
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload
from accelerate.utils import send_to_device
logger = get_logger(__name__) # pylint: disable=invalid-name
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
torch.nn.Linear,
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
# because of double invocation of the same norm layer in CogVideoXLayerNorm
)
# fmt: on
class ModuleGroup:
def __init__(
self,
modules: List[torch.nn.Module],
offload_device: torch.device,
onload_device: torch.device,
offload_leader: torch.nn.Module,
onload_leader: Optional[torch.nn.Module] = None,
parameters: Optional[List[torch.nn.Parameter]] = None,
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
onload_self: bool = True,
) -> None:
self.modules = modules
self.offload_device = offload_device
self.onload_device = onload_device
self.offload_leader = offload_leader
self.onload_leader = onload_leader
self.parameters = parameters
self.buffers = buffers
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.cpu_param_dict = cpu_param_dict
self.onload_self = onload_self
if self.stream is not None and self.cpu_param_dict is None:
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
with context:
for group_module in self.modules:
group_module.to(self.onload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.stream is not None:
torch.cuda.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
class GroupOffloadingHook(ModelHook):
r"""
A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads them to the accelerator device for
computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader"
module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module
group is responsible for onloading the current module group.
"""
_is_stateful = False
def __init__(
self,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
) -> None:
self.group = group
self.next_group = next_group
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
self.group.offload_()
return module
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
# If no onload_leader was assigned, we assume that the first submodule to call its forward
# method is the onload_leader of the group.
if self.group.onload_leader is None:
self.group.onload_leader = module
# If the current module is the onload_leader of the group, we onload the group if it is supposed
# to onload itself. In the case of using prefetching with streams, we onload the next group if
# it is not supposed to onload itself.
if self.group.onload_leader == module:
if self.group.onload_self:
self.group.onload_()
if self.next_group is not None and not self.next_group.onload_self:
self.next_group.onload_()
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
return args, kwargs
def post_forward(self, module: torch.nn.Module, output):
if self.group.offload_leader == module:
self.group.offload_()
return output
class LazyPrefetchGroupOffloadingHook(ModelHook):
r"""
A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
prefetching groups in the correct order.
"""
_is_stateful = False
def __init__(self):
self.execution_order: List[Tuple[str, torch.nn.Module]] = []
self._layer_execution_tracker_module_names = set()
def initialize_hook(self, module):
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
# layers are executed during the forward pass.
for name, submodule in module.named_modules():
if name == "" or not hasattr(submodule, "_diffusers_hook"):
continue
registry = HookRegistry.check_if_exists_or_initialize(submodule)
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
if group_offloading_hook is not None:
def make_execution_order_update_callback(current_name, current_submodule):
def callback():
logger.debug(f"Adding {current_name} to the execution order")
self.execution_order.append((current_name, current_submodule))
return callback
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
self._layer_execution_tracker_module_names.add(name)
return module
def post_forward(self, module, output):
# At this point, we know the execution order of the current module's submodules. We can now
# remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each
# group offloading hook.
num_executed = len(self.execution_order)
execution_order_module_names = {name for name, _ in self.execution_order}
# Some layers may not have been executed during the forward pass, for example if they are conditionally
# skipped for the given inputs. In such cases, we may not be able to apply prefetching in the correct order,
# which can lead to device-mismatch related errors if the missing layers end up being executed later.
if execution_order_module_names != self._layer_execution_tracker_module_names:
unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
logger.warning(
"It seems like some layers were not executed during the forward pass. This may lead to problems when "
"applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
"make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
f"{unexecuted_layers=}"
)
# Remove the layer execution tracker hooks from the submodules
base_module_registry = module._diffusers_hook
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
for i in range(num_executed):
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
# Apply lazy prefetching by setting required attributes
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
if num_executed > 0:
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
base_module_group_offloading_hook.next_group.onload_self = False
for i in range(num_executed - 1):
name1, _ = self.execution_order[i]
name2, _ = self.execution_order[i + 1]
logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
group_offloading_hooks[i].next_group.onload_self = False
return output
class LayerExecutionTrackerHook(ModelHook):
r"""
A hook that tracks the order in which the layers are executed during the forward pass by calling back to the
LazyPrefetchGroupOffloadingHook to update the execution order.
"""
_is_stateful = False
def __init__(self, execution_order_update_callback):
self.execution_order_update_callback = execution_order_update_callback
def pre_forward(self, module, *args, **kwargs):
self.execution_order_update_callback()
return args, kwargs
def apply_group_offloading(
module: torch.nn.Module,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: str = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
where it is beneficial, we need to first provide some context on how other supported offloading methods work.
Typically, offloading is done at two levels:
- Module-level: In Diffusers, this can be enabled using the `DiffusionPipeline.enable_model_cpu_offload()` method. It
works by offloading each component of a pipeline to the CPU for storage, and onloading it to the accelerator device
when needed for computation. This method is more memory-efficient than keeping all components on the accelerator,
but the memory requirements are still quite high: completing the forward pass requires memory equivalent to the
size of the model in its runtime dtype plus the size of the largest intermediate activation tensors.
- Leaf-level: In Diffusers, this can be enabled using the `DiffusionPipeline.enable_sequential_cpu_offload()` method.
It works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
onloading only the leaves to the accelerator device for computation. This uses the least accelerator memory, but can
be slower due to the excessive number of device synchronizations.
Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers
(either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses less memory than module-level offloading,
and is faster than leaf-level/sequential offloading because the number of device synchronizations is reduced.
Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to
overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This
is enabled using layer prefetching with streams, i.e., the layer that is to be executed next starts onloading to
the accelerator device while the current layer is being executed - this increases the memory requirements slightly.
Note that this implementation also supports leaf-level offloading but can be made much faster when using streams.
Args:
module (`torch.nn.Module`):
The module to which group offloading is applied.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
offload_type (`str`, defaults to "block_level"):
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
"block_level".
num_blocks_per_group (`int`, *optional*):
The number of blocks per group when using offload_type="block_level". This is required when using
offload_type="block_level".
non_blocking (`bool`, defaults to `False`):
If True, offloading and onloading is done with non-blocking data transfer.
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXTransformer3DModel
>>> from diffusers.hooks import apply_group_offloading
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> apply_group_offloading(
... transformer,
... onload_device=torch.device("cuda"),
... offload_device=torch.device("cpu"),
... offload_type="block_level",
... num_blocks_per_group=2,
... use_stream=True,
... )
```
"""
stream = None
if use_stream:
if torch.cuda.is_available():
stream = torch.cuda.Stream()
else:
raise ValueError("Using streams for data transfer requires a CUDA device.")
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
if offload_type == "block_level":
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
_apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")
def _apply_group_offloading_block_level(
module: torch.nn.Module,
num_blocks_per_group: int,
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In contrast to
"leaf_level" offloading, which is more fine-grained, offloading here is applied to the top-level blocks.
Args:
module (`torch.nn.Module`):
The module to which group offloading is applied.
num_blocks_per_group (`int`):
The number of blocks from each `torch.nn.ModuleList` or `torch.nn.Sequential` that are offloaded and
onloaded together as one group.
offload_device (`torch.device`):
The device to which the group of modules are offloaded. This should typically be the CPU.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
unmatched_modules = []
matched_module_groups = []
for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
unmatched_modules.append((name, submodule))
modules_with_group_offloading.add(name)
continue
for i in range(0, len(submodule), num_blocks_per_group):
current_modules = submodule[i : i + num_blocks_per_group]
group = ModuleGroup(
modules=current_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=stream is None,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
next_group = (
matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
)
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, next_group)
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
# part of any group (as doing so would lead to no VRAM savings).
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
parameters = [param for _, param in parameters]
buffers = [buffer for _, buffer in buffers]
# Create a group for the unmatched submodules of the top-level module so that they are on the correct
# device when the forward pass is called.
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
cpu_param_dict=None,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
_apply_group_offloading_hook(module, unmatched_group, next_group)
def _apply_group_offloading_leaf_level(
module: torch.nn.Module,
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
reduce memory usage without any performance degradation.
Args:
module (`torch.nn.Module`):
The module to which group offloading is applied.
offload_device (`torch.device`):
The device to which the group of modules are offloaded. This should typically be the CPU.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
continue
group = ModuleGroup(
modules=[submodule],
offload_device=offload_device,
onload_device=onload_device,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=True,
)
_apply_group_offloading_hook(submodule, group, None)
modules_with_group_offloading.add(name)
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
# of the module is called
module_dict = dict(module.named_modules())
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
# Find closest module parent for each parameter and buffer, and attach group hooks
parent_to_parameters = {}
for name, param in parameters:
parent_name = _find_parent_module_in_module_dict(name, module_dict)
if parent_name in parent_to_parameters:
parent_to_parameters[parent_name].append(param)
else:
parent_to_parameters[parent_name] = [param]
parent_to_buffers = {}
for name, buffer in buffers:
parent_name = _find_parent_module_in_module_dict(name, module_dict)
if parent_name in parent_to_buffers:
parent_to_buffers[parent_name].append(buffer)
else:
parent_to_buffers[parent_name] = [buffer]
parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
for name in parent_names:
parameters = parent_to_parameters.get(name, [])
buffers = parent_to_buffers.get(name, [])
parent_module = module_dict[name]
assert getattr(parent_module, "_diffusers_hook", None) is None
group = ModuleGroup(
modules=[],
offload_device=offload_device,
onload_device=onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
parameters=parameters,
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=True,
)
_apply_group_offloading_hook(parent_module, group, None)
if stream is not None:
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
# execution order and apply prefetching in the correct order.
unmatched_group = ModuleGroup(
modules=[],
offload_device=offload_device,
onload_device=onload_device,
offload_leader=module,
onload_leader=module,
parameters=None,
buffers=None,
non_blocking=False,
stream=None,
cpu_param_dict=None,
onload_self=True,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
def _apply_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group)
registry.register_hook(hook, _GROUP_OFFLOADING)
def _apply_lazy_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group)
registry.register_hook(hook, _GROUP_OFFLOADING)
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
def _gather_parameters_with_no_group_offloading_parent(
module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[torch.nn.Parameter]:
parameters = []
for name, parameter in module.named_parameters():
has_parent_with_group_offloading = False
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
if parent_name in modules_with_group_offloading:
has_parent_with_group_offloading = True
break
atoms.pop()
if not has_parent_with_group_offloading:
parameters.append((name, parameter))
return parameters
def _gather_buffers_with_no_group_offloading_parent(
module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[torch.Tensor]:
buffers = []
for name, buffer in module.named_buffers():
has_parent_with_group_offloading = False
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
if parent_name in modules_with_group_offloading:
has_parent_with_group_offloading = True
break
atoms.pop()
if not has_parent_with_group_offloading:
buffers.append((name, buffer))
return buffers
def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
if parent_name in module_dict:
return parent_name
atoms.pop()
return ""
def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None:
if not is_accelerate_available():
return
for name, submodule in module.named_modules():
if not hasattr(submodule, "_hf_hook"):
continue
if isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)):
raise ValueError(
f"Cannot apply group offloading to a module that is already applying an alternative "
f"offloading strategy from Accelerate. If you want to apply group offloading, please "
f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})"
)
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
for submodule in module.modules():
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
return True
return False
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
for submodule in module.modules():
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
raise ValueError("Group offloading is not enabled for the provided module.")

@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = False
_supports_group_offloading = False
@register_to_config
def __init__(

@@ -68,6 +68,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
```
"""
_supports_group_offloading = False
@register_to_config
def __init__(
self,

@@ -72,6 +72,7 @@ class VQModel(ModelMixin, ConfigMixin):
"""
_skip_layerwise_casting_patterns = ["quantize"]
_supports_group_offloading = False
@register_to_config
def __init__(

@@ -34,7 +34,7 @@ from torch import Tensor, nn
from typing_extensions import Self
from .. import __version__
from ..hooks import apply_layerwise_casting
from ..hooks import apply_group_offloading, apply_layerwise_casting
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
@@ -87,7 +87,17 @@ if is_accelerate_available():
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
from ..hooks.group_offloading import _get_group_onload_device
try:
# Try to get the onload device from the group offloading hook
return _get_group_onload_device(parameter)
except ValueError:
pass
try:
# If the onload device is not available due to no group offloading hooks, try to get the device
# from the first parameter or buffer
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
return next(parameters_and_buffers).device
except StopIteration:
@@ -166,6 +176,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_no_split_modules = None
_keep_in_fp32_modules = None
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True
def __init__(self):
super().__init__()
@@ -437,6 +448,55 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
)
def enable_group_offload(
self,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: str = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
) -> None:
r"""
Activates group offloading for the current model.
See [`~hooks.group_offloading.apply_group_offloading`] for more information.
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXTransformer3DModel
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> transformer.enable_group_offload(
... onload_device=torch.device("cuda"),
... offload_device=torch.device("cpu"),
... offload_type="leaf_level",
... use_stream=True,
... )
```
"""
if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
msg = (
"Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
"forward pass is executed with tiling enabled. Please make sure to either:\n"
"1. Run a forward pass with small input shapes.\n"
"2. Or, run a forward pass with tiling disabled (can still use small dummy inputs)."
)
logger.warning(msg)
if not self._supports_group_offloading:
raise ValueError(
f"{self.__class__.__name__} does not support group offloading. Please make sure to set the boolean attribute "
f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@@ -1170,6 +1230,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Adapted from `transformers`.
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
from ..hooks.group_offloading import _is_group_offload_enabled
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if getattr(self, "is_loaded_in_8bit", False):
@@ -1182,13 +1244,34 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
)
# Checks if group offloading is enabled
if _is_group_offload_enabled(self):
logger.warning(
f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.cuda()` is not supported."
)
return self
return super().cuda(*args, **kwargs)
# Adapted from `transformers`.
@wraps(torch.nn.Module.to)
def to(self, *args, **kwargs):
from ..hooks.group_offloading import _is_group_offload_enabled
device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
dtype_present_in_args = "dtype" in kwargs
# Try converting arguments to torch.device in case they are passed as strings
for arg in args:
if not isinstance(arg, str):
continue
try:
torch.device(arg)
device_arg_or_kwarg_present = True
except RuntimeError:
pass
if not dtype_present_in_args:
for arg in args:
if isinstance(arg, torch.dtype):
@@ -1213,6 +1296,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
)
if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
logger.warning(
f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
)
return self
return super().to(*args, **kwargs)
# Taken from `transformers`.

@@ -66,6 +66,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_supports_gradient_checkpointing = True
_supports_group_offloading = False
@register_to_config
def __init__(

@@ -245,6 +245,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
"""
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
_supports_group_offloading = False
@register_to_config
def __init__(

@@ -394,6 +394,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
)
device = device or device_arg
device_type = torch.device(device).type if device is not None else None
pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
@@ -424,7 +425,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
)
if device and torch.device(device).type == "cuda":
if device_type == "cuda":
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
raise ValueError(
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
@@ -437,7 +438,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
if pipeline_is_offloaded and device_type == "cuda":
logger.warning(
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)
@@ -449,6 +450,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
for module in modules:
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module)
if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
logger.warning(
@@ -460,11 +462,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
)
# Note: we also handle this at the ModelMixin level. The reason for handling it here as well is that modeling
# components can come from outside diffusers and still have group offloading enabled.
if (
self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module)
and device is not None
):
logger.warning(
f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported."
)
# This can happen for `transformer` models. CPU placement was added in
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
module.to(device=device)
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
module.to(device, dtype)
if (
@@ -1023,6 +1035,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
Accelerate's module hooks.
"""
from ..hooks.group_offloading import _get_group_onload_device
# When applying group offloading at the leaf_level, we're in the same situation as Accelerate's sequential
# offloading. We need to return the onload device of the group offloading hooks so that the intermediates
# required for computation (latents, prompt embeddings, etc.) can be created on the correct device.
for name, model in self.components.items():
if not isinstance(model, torch.nn.Module):
continue
try:
return _get_group_onload_device(model)
except ValueError:
pass
for name, model in self.components.items():
if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
continue
@@ -1061,6 +1086,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
@@ -1172,6 +1199,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
from accelerate import cpu_offload
else:
@@ -1896,6 +1925,24 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
return new_pipeline
def _maybe_raise_error_if_group_offload_active(
self, raise_error: bool = False, module: Optional[torch.nn.Module] = None
) -> bool:
from ..hooks.group_offloading import _is_group_offload_enabled
components = self.components.values() if module is None else [module]
components = [component for component in components if isinstance(component, torch.nn.Module)]
for component in components:
if _is_group_offload_enabled(component):
if raise_error:
raise ValueError(
"You are trying to apply model/sequential CPU offloading to a pipeline that contains components "
"with group offloading enabled. This is not supported. Please disable group offloading for "
"components of the pipeline to use other offloading methods."
)
return True
return False
class StableDiffusionMixin:
r"""

@@ -0,0 +1,214 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import torch
from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import get_logger
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
class DummyBlock(torch.nn.Module):
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
super().__init__()
self.proj_in = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.proj_out = torch.nn.Linear(hidden_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj_in(x)
x = self.activation(x)
x = self.proj_out(x)
return x
class DummyModel(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
super().__init__()
self.linear_1 = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.blocks = torch.nn.ModuleList(
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
)
self.linear_2 = torch.nn.Linear(hidden_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x)
x = self.activation(x)
for block in self.blocks:
x = block(x)
x = self.linear_2(x)
return x
class DummyPipeline(DiffusionPipeline):
model_cpu_offload_seq = "model"
def __init__(self, model: torch.nn.Module) -> None:
super().__init__()
self.register_modules(model=model)
def __call__(self, x: torch.Tensor) -> torch.Tensor:
for _ in range(2):
x = x + 0.1 * self.model(x)
return x
@require_torch_gpu
class GroupOffloadTests(unittest.TestCase):
in_features = 64
hidden_features = 256
out_features = 64
num_layers = 4
def setUp(self):
with torch.no_grad():
self.model = self.get_model()
self.input = torch.randn((4, self.in_features)).to(torch_device)
def tearDown(self):
super().tearDown()
del self.model
del self.input
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
def get_model(self):
torch.manual_seed(0)
return DummyModel(
in_features=self.in_features,
hidden_features=self.hidden_features,
out_features=self.out_features,
num_layers=self.num_layers,
)
def test_offloading_forward_pass(self):
@torch.no_grad()
def run_forward(model):
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
self.assertTrue(
all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
)
)
model.eval()
output = model(self.input)[0].cpu()
max_memory_allocated = torch.cuda.max_memory_allocated()
return output, max_memory_allocated
self.model.to(torch_device)
output_without_group_offloading, mem_baseline = run_forward(self.model)
self.model.to("cpu")
model = self.get_model()
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
output_with_group_offloading1, mem1 = run_forward(model)
model = self.get_model()
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
output_with_group_offloading2, mem2 = run_forward(model)
model = self.get_model()
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
output_with_group_offloading3, mem3 = run_forward(model)
model = self.get_model()
model.enable_group_offload(torch_device, offload_type="leaf_level")
output_with_group_offloading4, mem4 = run_forward(model)
model = self.get_model()
model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True)
output_with_group_offloading5, mem5 = run_forward(model)
# Precision assertions - offloading should not impact the output
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))
# Memory assertions - offloading should reduce memory usage
self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline)
def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
if torch.device(torch_device).type != "cuda":
return
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
logger = get_logger("diffusers.models.modeling_utils")
logger.setLevel("INFO")
with self.assertLogs(logger, level="WARNING") as cm:
self.model.to(torch_device)
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
if torch.device(torch_device).type != "cuda":
return
pipe = DummyPipeline(self.model)
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
logger = get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel("INFO")
with self.assertLogs(logger, level="WARNING") as cm:
pipe.to(torch_device)
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
def test_error_raised_if_streams_used_and_no_cuda_device(self):
original_is_available = torch.cuda.is_available
torch.cuda.is_available = lambda: False
with self.assertRaises(ValueError):
self.model.enable_group_offload(
onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True
)
torch.cuda.is_available = original_is_available
def test_error_raised_if_supports_group_offloading_false(self):
self.model._supports_group_offloading = False
with self.assertRaisesRegex(ValueError, "does not support group offloading"):
self.model.enable_group_offload(onload_device=torch.device("cuda"))
def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
pipe = DummyPipeline(self.model)
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
pipe.enable_model_cpu_offload()
def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self):
pipe = DummyPipeline(self.model)
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
pipe.enable_sequential_cpu_offload()
def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self):
pipe = DummyPipeline(self.model)
pipe.enable_model_cpu_offload()
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self):
pipe = DummyPipeline(self.model)
pipe.enable_sequential_cpu_offload()
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)

@@ -1458,6 +1458,55 @@ class ModelTesterMixin:
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
)
@require_torch_gpu
def test_group_offloading(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
@torch.no_grad()
def run_forward(model):
self.assertTrue(
all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
)
)
model.eval()
return model(**inputs_dict)[0]
model = self.model_class(**init_dict)
if not getattr(model, "_supports_group_offloading", True):
return
model.to(torch_device)
output_without_group_offloading = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
output_with_group_offloading1 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
output_with_group_offloading2 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="leaf_level")
output_with_group_offloading3 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True)
output_with_group_offloading4 = run_forward(model)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):

@@ -58,6 +58,7 @@ class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTes
)
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)

@@ -39,6 +39,7 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)

@@ -61,6 +61,7 @@ class AnimateDiffPipelineFastTests(
]
)
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
cross_attention_dim = 8

@@ -31,6 +31,7 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
)
batch_params = frozenset(["prompt", "negative_prompt"])
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)

@@ -60,6 +60,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastT
)
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)

@@ -56,6 +56,7 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas
)
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)

@@ -57,6 +57,7 @@ class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    )
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -59,6 +59,7 @@ class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    )
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -127,6 +127,7 @@ class ControlNetPipelineFastTests(
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)


@@ -76,6 +76,7 @@ class StableDiffusionXLControlNetPipelineFastTests(
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)


@@ -51,6 +51,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
    batch_params = frozenset(["prompt"])
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -60,6 +60,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
    )
    batch_params = frozenset(["prompt", "negative_prompt"])
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(
        self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm", use_dual_attention=False


@@ -140,6 +140,7 @@ class ControlNetXSPipelineFastTests(
    test_attention_slicing = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)


@@ -79,6 +79,7 @@ class StableDiffusionXLControlNetXSPipelineFastTests(
    test_attention_slicing = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -35,6 +35,7 @@ class FluxPipelineFastTests(
    # there is no xformers processor for Flux
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        torch.manual_seed(0)


@@ -23,6 +23,7 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    # there is no xformers processor for Flux
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -24,6 +24,7 @@ class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    batch_params = frozenset(["prompt"])
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -54,6 +54,7 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca
    # there is no xformers processor for Flux
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        torch.manual_seed(0)


@@ -54,6 +54,7 @@ class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTeste
    required_optional_params = PipelineTesterMixin.required_optional_params
    test_layerwise_casting = True
    test_group_offloading = True

    pab_config = PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,


@@ -47,6 +47,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    )
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -33,6 +33,7 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM
    supports_dduf = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -56,6 +56,7 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    )
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -56,6 +56,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr
        ]
    )
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        cross_attention_dim = 8


@@ -51,6 +51,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    required_optional_params = PipelineTesterMixin.required_optional_params
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -56,6 +56,7 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    required_optional_params = PipelineTesterMixin.required_optional_params
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -53,6 +53,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    )
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -124,6 +124,7 @@ class StableDiffusionPipelineFastTests(
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self, time_cond_proj_dim=None):
        cross_attention_dim = 8


@@ -76,6 +76,7 @@ class StableDiffusion2PipelineFastTests(
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -36,6 +36,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    )
    batch_params = frozenset(["prompt", "negative_prompt"])
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)


@@ -76,6 +76,7 @@ class StableDiffusionXLPipelineFastTests(
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)


@@ -29,6 +29,7 @@ from diffusers import (
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin

@@ -47,6 +48,7 @@ from diffusers.utils.testing_utils import (
    require_accelerator,
    require_hf_hub_version_greater,
    require_torch,
    require_torch_gpu,
    require_transformers_version_greater,
    skip_mps,
    torch_device,
@@ -990,6 +992,7 @@ class PipelineTesterMixin:
    test_xformers_attention = True
    test_layerwise_casting = False
    test_group_offloading = False

    supports_dduf = True

    def get_generator(self, seed):
@@ -2044,6 +2047,79 @@ class PipelineTesterMixin:
        inputs = self.get_dummy_inputs(torch_device)
        _ = pipe(**inputs)[0]
    @require_torch_gpu
    def test_group_offloading_inference(self):
        if not self.test_group_offloading:
            return

        def create_pipe():
            torch.manual_seed(0)
            components = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            return pipe

        def enable_group_offload_on_component(pipe, group_offloading_kwargs):
            # We intentionally don't test VAEs here. Some tests enable tiling on the VAE, and if
            # tiling is enabled when a forward pass runs with cuda streams, the execution order of
            # the layers is not traced correctly, which causes errors. To apply group offloading to
            # a VAE, a warmup forward pass (even with small dummy inputs) is recommended; see the
            # sketch after this test.
            for component_name in [
                "text_encoder",
                "text_encoder_2",
                "text_encoder_3",
                "transformer",
                "unet",
                "controlnet",
            ]:
                if not hasattr(pipe, component_name):
                    continue
                component = getattr(pipe, component_name)
                if not getattr(component, "_supports_group_offloading", True):
                    continue
                if hasattr(component, "enable_group_offload"):
                    # For diffusers ModelMixin implementations
                    component.enable_group_offload(torch.device(torch_device), **group_offloading_kwargs)
                else:
                    # For model implementations that are not part of diffusers
                    apply_group_offloading(
                        component, onload_device=torch.device(torch_device), **group_offloading_kwargs
                    )
                self.assertTrue(
                    all(
                        module._diffusers_hook.get_hook("group_offloading") is not None
                        for module in component.modules()
                        if hasattr(module, "_diffusers_hook")
                    )
                )
            for component_name in ["vae", "vqvae"]:
                if hasattr(pipe, component_name):
                    getattr(pipe, component_name).to(torch_device)

        def run_forward(pipe):
            torch.manual_seed(0)
            inputs = self.get_dummy_inputs(torch_device)
            return pipe(**inputs)[0]

        pipe = create_pipe().to(torch_device)
        output_without_group_offloading = run_forward(pipe)

        pipe = create_pipe()
        enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1})
        output_with_group_offloading1 = run_forward(pipe)

        pipe = create_pipe()
        enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"})
        output_with_group_offloading2 = run_forward(pipe)

        if torch.is_tensor(output_without_group_offloading):
            output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
            output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy()
            output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy()

        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))
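The VAE caveat in `enable_group_offload_on_component` deserves a concrete illustration: with tiling and CUDA streams combined, the hook needs one traced forward pass before real workloads. A minimal warmup sketch — the checkpoint, input shape, and dtype below are illustrative assumptions, not part of this commit:

```python
import torch
from diffusers import AutoencoderKL

# Hypothetical setup: any VAE checkpoint works; shape and dtype are placeholders.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
vae.enable_tiling()
vae.enable_group_offload(torch.device("cuda"), offload_type="leaf_level", use_stream=True)

# Warmup: a single forward pass on a small dummy input lets the stream-based hook
# record the true execution order of the layers before real inputs arrive.
with torch.no_grad():
    dummy = torch.randn(1, 3, 128, 128, device="cuda", dtype=torch.float16)
    _ = vae.encode(dummy)
```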
@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):