mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Module Group Offloading (#10503)
* update * fix * non_blocking; handle parameters and buffers * update * Group offloading with cuda stream prefetching (#10516) * cuda stream prefetch * remove breakpoints * update * copy model hook implementation from pab * update; ~very workaround based implementation but it seems to work as expected; needs cleanup and rewrite * more workarounds to make it actually work * cleanup * rewrite * update * make sure to sync current stream before overwriting with pinned params not doing so will lead to erroneous computations on the GPU and cause bad results * better check * update * remove hook implementation to not deal with merge conflict * re-add hook changes * why use more memory when less memory do trick * why still use slightly more memory when less memory do trick * optimise * add model tests * add pipeline tests * update docs * add layernorm and groupnorm * address review comments * improve tests; add docs * improve docs * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * apply suggestions from code review * update tests * apply suggestions from review * enable_group_offloading -> enable_group_offload for naming consistency * raise errors if multiple offloading strategies used; add relevant tests * handle .to() when group offload applied * refactor some repeated code * remove unintentional change from merge conflict * handle .cuda() --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
@@ -45,3 +45,7 @@ Utility and helper functions for working with 🤗 Diffusers.
|
||||
## apply_layerwise_casting
|
||||
|
||||
[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting
|
||||
|
||||
## apply_group_offloading
|
||||
|
||||
[[autodoc]] hooks.group_offloading.apply_group_offloading
|
||||
|
||||
@@ -158,6 +158,46 @@ In order to properly offload models after they're called, it is required to run
|
||||
|
||||
</Tip>
|
||||
|
||||
## Group offloading
|
||||
|
||||
Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced.
|
||||
|
||||
To enable group offloading, call the [`~ModelMixin.enable_group_offload`] method on the model if it is a Diffusers model implementation. For any other model implementation, use [`~hooks.group_offloading.apply_group_offloading`]:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
# Load the pipeline
|
||||
onload_device = torch.device("cuda")
|
||||
offload_device = torch.device("cpu")
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
|
||||
# We can utilize the enable_group_offload method for Diffusers model implementations
|
||||
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
|
||||
|
||||
# For any other model implementations, the apply_group_offloading function can be used
|
||||
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
|
||||
apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
|
||||
|
||||
prompt = (
|
||||
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
||||
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
||||
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
|
||||
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
|
||||
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
||||
"atmosphere of this unique musical performance."
|
||||
)
|
||||
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
# This utilized about 14.79 GB. It can be further reduced by using tiling and using leaf_level offloading throughout the pipeline.
|
||||
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
|
||||
export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams.
|
||||
|
||||
## FP8 layerwise weight-casting
|
||||
|
||||
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
|
||||
|
||||
@@ -2,6 +2,7 @@ from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .group_offloading import apply_group_offloading
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
|
||||
678
src/diffusers/hooks/group_offloading.py
Normal file
678
src/diffusers/hooks/group_offloading.py
Normal file
@@ -0,0 +1,678 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger, is_accelerate_available
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload
|
||||
from accelerate.utils import send_to_device
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# fmt: off
|
||||
_GROUP_OFFLOADING = "group_offloading"
|
||||
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
|
||||
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
|
||||
|
||||
_SUPPORTED_PYTORCH_LAYERS = (
|
||||
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
|
||||
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
|
||||
torch.nn.Linear,
|
||||
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
|
||||
# because of double invocation of the same norm layer in CogVideoXLayerNorm
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ModuleGroup:
|
||||
def __init__(
|
||||
self,
|
||||
modules: List[torch.nn.Module],
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
offload_leader: torch.nn.Module,
|
||||
onload_leader: Optional[torch.nn.Module] = None,
|
||||
parameters: Optional[List[torch.nn.Parameter]] = None,
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
|
||||
onload_self: bool = True,
|
||||
) -> None:
|
||||
self.modules = modules
|
||||
self.offload_device = offload_device
|
||||
self.onload_device = onload_device
|
||||
self.offload_leader = offload_leader
|
||||
self.onload_leader = onload_leader
|
||||
self.parameters = parameters
|
||||
self.buffers = buffers
|
||||
self.non_blocking = non_blocking or stream is not None
|
||||
self.stream = stream
|
||||
self.cpu_param_dict = cpu_param_dict
|
||||
self.onload_self = onload_self
|
||||
|
||||
if self.stream is not None and self.cpu_param_dict is None:
|
||||
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
|
||||
|
||||
def onload_(self):
|
||||
r"""Onloads the group of modules to the onload_device."""
|
||||
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
self.stream.synchronize()
|
||||
|
||||
with context:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
def offload_(self):
|
||||
r"""Offloads the group of modules to the offload_device."""
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
|
||||
|
||||
class GroupOffloadingHook(ModelHook):
|
||||
r"""
|
||||
A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for
|
||||
computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader"
|
||||
module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module
|
||||
group is responsible for onloading the current module group.
|
||||
"""
|
||||
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: ModuleGroup,
|
||||
next_group: Optional[ModuleGroup] = None,
|
||||
) -> None:
|
||||
self.group = group
|
||||
self.next_group = next_group
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
if self.group.offload_leader == module:
|
||||
self.group.offload_()
|
||||
return module
|
||||
|
||||
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
# If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward
|
||||
# method is the onload_leader of the group.
|
||||
if self.group.onload_leader is None:
|
||||
self.group.onload_leader = module
|
||||
|
||||
# If the current module is the onload_leader of the group, we onload the group if it is supposed
|
||||
# to onload itself. In the case of using prefetching with streams, we onload the next group if
|
||||
# it is not supposed to onload itself.
|
||||
if self.group.onload_leader == module:
|
||||
if self.group.onload_self:
|
||||
self.group.onload_()
|
||||
if self.next_group is not None and not self.next_group.onload_self:
|
||||
self.next_group.onload_()
|
||||
|
||||
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
|
||||
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module: torch.nn.Module, output):
|
||||
if self.group.offload_leader == module:
|
||||
self.group.offload_()
|
||||
return output
|
||||
|
||||
|
||||
class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
r"""
|
||||
A hook, used in conjuction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
|
||||
This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
|
||||
invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
|
||||
prefetching groups in the correct order.
|
||||
"""
|
||||
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self):
|
||||
self.execution_order: List[Tuple[str, torch.nn.Module]] = []
|
||||
self._layer_execution_tracker_module_names = set()
|
||||
|
||||
def initialize_hook(self, module):
|
||||
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
|
||||
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
|
||||
# layers are executed during the forward pass.
|
||||
for name, submodule in module.named_modules():
|
||||
if name == "" or not hasattr(submodule, "_diffusers_hook"):
|
||||
continue
|
||||
|
||||
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
|
||||
|
||||
if group_offloading_hook is not None:
|
||||
|
||||
def make_execution_order_update_callback(current_name, current_submodule):
|
||||
def callback():
|
||||
logger.debug(f"Adding {current_name} to the execution order")
|
||||
self.execution_order.append((current_name, current_submodule))
|
||||
|
||||
return callback
|
||||
|
||||
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
|
||||
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
|
||||
self._layer_execution_tracker_module_names.add(name)
|
||||
|
||||
return module
|
||||
|
||||
def post_forward(self, module, output):
|
||||
# At this point, for the current modules' submodules, we know the execution order of the layers. We can now
|
||||
# remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each
|
||||
# group offloading hook.
|
||||
num_executed = len(self.execution_order)
|
||||
execution_order_module_names = {name for name, _ in self.execution_order}
|
||||
|
||||
# It may be possible that some layers were not executed during the forward pass. This can happen if the layer
|
||||
# is not used in the forward pass, or if the layer is not executed due to some other reason. In such cases, we
|
||||
# may not be able to apply prefetching in the correct order, which can lead to device-mismatch related errors
|
||||
# if the missing layers end up being executed in the future.
|
||||
if execution_order_module_names != self._layer_execution_tracker_module_names:
|
||||
unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
|
||||
logger.warning(
|
||||
"It seems like some layers were not executed during the forward pass. This may lead to problems when "
|
||||
"applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
|
||||
"make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
|
||||
f"{unexecuted_layers=}"
|
||||
)
|
||||
|
||||
# Remove the layer execution tracker hooks from the submodules
|
||||
base_module_registry = module._diffusers_hook
|
||||
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
|
||||
|
||||
for i in range(num_executed):
|
||||
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
|
||||
|
||||
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
|
||||
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
|
||||
|
||||
# Apply lazy prefetching by setting required attributes
|
||||
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
|
||||
if num_executed > 0:
|
||||
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
|
||||
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
|
||||
base_module_group_offloading_hook.next_group.onload_self = False
|
||||
|
||||
for i in range(num_executed - 1):
|
||||
name1, _ = self.execution_order[i]
|
||||
name2, _ = self.execution_order[i + 1]
|
||||
logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
|
||||
group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
|
||||
group_offloading_hooks[i].next_group.onload_self = False
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class LayerExecutionTrackerHook(ModelHook):
|
||||
r"""
|
||||
A hook that tracks the order in which the layers are executed during the forward pass by calling back to the
|
||||
LazyPrefetchGroupOffloadingHook to update the execution order.
|
||||
"""
|
||||
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self, execution_order_update_callback):
|
||||
self.execution_order_update_callback = execution_order_update_callback
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
self.execution_order_update_callback()
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def apply_group_offloading(
|
||||
module: torch.nn.Module,
|
||||
onload_device: torch.device,
|
||||
offload_device: torch.device = torch.device("cpu"),
|
||||
offload_type: str = "block_level",
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
|
||||
where it is beneficial, we need to first provide some context on how other supported offloading methods work.
|
||||
|
||||
Typically, offloading is done at two levels:
|
||||
- Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
|
||||
works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator device
|
||||
when needed for computation. This method is more memory-efficient than keeping all components on the accelerator,
|
||||
but the memory requirements are still quite high. For this method to work, one needs memory equivalent to size of
|
||||
the model in runtime dtype + size of largest intermediate activation tensors to be able to complete the forward
|
||||
pass.
|
||||
- Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method. It
|
||||
works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
|
||||
onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator
|
||||
memory, but can be slower due to the excessive number of device synchronizations.
|
||||
|
||||
Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers,
|
||||
(either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
|
||||
offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations is
|
||||
reduced.
|
||||
|
||||
Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to
|
||||
overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This
|
||||
is enabled using layer prefetching with streams, i.e., the layer that is to be executed next starts onloading to
|
||||
the accelerator device while the current layer is being executed - this increases the memory requirements slightly.
|
||||
Note that this implementation also supports leaf-level offloading but can be made much faster when using streams.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to which group offloading is applied.
|
||||
onload_device (`torch.device`):
|
||||
The device to which the group of modules are onloaded.
|
||||
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
|
||||
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
|
||||
offload_type (`str`, defaults to "block_level"):
|
||||
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
|
||||
"block_level".
|
||||
num_blocks_per_group (`int`, *optional*):
|
||||
The number of blocks per group when using offload_type="block_level". This is required when using
|
||||
offload_type="block_level".
|
||||
non_blocking (`bool`, defaults to `False`):
|
||||
If True, offloading and onloading is done with non-blocking data transfer.
|
||||
use_stream (`bool`, defaults to `False`):
|
||||
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
|
||||
overlapping computation and data transfer.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from diffusers import CogVideoXTransformer3DModel
|
||||
>>> from diffusers.hooks import apply_group_offloading
|
||||
|
||||
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
|
||||
... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
|
||||
>>> apply_group_offloading(
|
||||
... transformer,
|
||||
... onload_device=torch.device("cuda"),
|
||||
... offload_device=torch.device("cpu"),
|
||||
... offload_type="block_level",
|
||||
... num_blocks_per_group=2,
|
||||
... use_stream=True,
|
||||
... )
|
||||
```
|
||||
"""
|
||||
|
||||
stream = None
|
||||
if use_stream:
|
||||
if torch.cuda.is_available():
|
||||
stream = torch.cuda.Stream()
|
||||
else:
|
||||
raise ValueError("Using streams for data transfer requires a CUDA device.")
|
||||
|
||||
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
|
||||
|
||||
if offload_type == "block_level":
|
||||
if num_blocks_per_group is None:
|
||||
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
|
||||
|
||||
_apply_group_offloading_block_level(
|
||||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
|
||||
)
|
||||
elif offload_type == "leaf_level":
|
||||
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
|
||||
else:
|
||||
raise ValueError(f"Unsupported offload_type: {offload_type}")
|
||||
|
||||
|
||||
def _apply_group_offloading_block_level(
|
||||
module: torch.nn.Module,
|
||||
num_blocks_per_group: int,
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
|
||||
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to which group offloading is applied.
|
||||
offload_device (`torch.device`):
|
||||
The device to which the group of modules are offloaded. This should typically be the CPU.
|
||||
onload_device (`torch.device`):
|
||||
The device to which the group of modules are onloaded.
|
||||
non_blocking (`bool`):
|
||||
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
|
||||
and data transfer.
|
||||
stream (`torch.cuda.Stream`, *optional*):
|
||||
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
unmatched_modules = []
|
||||
matched_module_groups = []
|
||||
for name, submodule in module.named_children():
|
||||
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
|
||||
unmatched_modules.append((name, submodule))
|
||||
modules_with_group_offloading.add(name)
|
||||
continue
|
||||
|
||||
for i in range(0, len(submodule), num_blocks_per_group):
|
||||
current_modules = submodule[i : i + num_blocks_per_group]
|
||||
group = ModuleGroup(
|
||||
modules=current_modules,
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_leader=current_modules[-1],
|
||||
onload_leader=current_modules[0],
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=stream is None,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
for j in range(i, i + len(current_modules)):
|
||||
modules_with_group_offloading.add(f"{name}.{j}")
|
||||
|
||||
# Apply group offloading hooks to the module groups
|
||||
for i, group in enumerate(matched_module_groups):
|
||||
next_group = (
|
||||
matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
|
||||
)
|
||||
|
||||
for group_module in group.modules:
|
||||
_apply_group_offloading_hook(group_module, group, next_group)
|
||||
|
||||
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
|
||||
# when the forward pass of this module is called. This is because the top-level module is not
|
||||
# part of any group (as doing so would lead to no VRAM savings).
|
||||
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
||||
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
||||
parameters = [param for _, param in parameters]
|
||||
buffers = [buffer for _, buffer in buffers]
|
||||
|
||||
# Create a group for the unmatched submodules of the top-level module so that they are on the correct
|
||||
# device when the forward pass is called.
|
||||
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
|
||||
unmatched_group = ModuleGroup(
|
||||
modules=unmatched_modules,
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_leader=module,
|
||||
onload_leader=module,
|
||||
parameters=parameters,
|
||||
buffers=buffers,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
|
||||
_apply_group_offloading_hook(module, unmatched_group, next_group)
|
||||
|
||||
|
||||
def _apply_group_offloading_leaf_level(
|
||||
module: torch.nn.Module,
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
|
||||
requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
|
||||
synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
|
||||
reduce memory usage without any performance degradation.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to which group offloading is applied.
|
||||
offload_device (`torch.device`):
|
||||
The device to which the group of modules are offloaded. This should typically be the CPU.
|
||||
onload_device (`torch.device`):
|
||||
The device to which the group of modules are onloaded.
|
||||
non_blocking (`bool`):
|
||||
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
|
||||
and data transfer.
|
||||
stream (`torch.cuda.Stream`, *optional*):
|
||||
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for leaf modules and apply group offloading hooks
|
||||
modules_with_group_offloading = set()
|
||||
for name, submodule in module.named_modules():
|
||||
if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
|
||||
continue
|
||||
group = ModuleGroup(
|
||||
modules=[submodule],
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_leader=submodule,
|
||||
onload_leader=submodule,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None)
|
||||
modules_with_group_offloading.add(name)
|
||||
|
||||
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
|
||||
# of the module is called
|
||||
module_dict = dict(module.named_modules())
|
||||
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
||||
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
||||
|
||||
# Find closest module parent for each parameter and buffer, and attach group hooks
|
||||
parent_to_parameters = {}
|
||||
for name, param in parameters:
|
||||
parent_name = _find_parent_module_in_module_dict(name, module_dict)
|
||||
if parent_name in parent_to_parameters:
|
||||
parent_to_parameters[parent_name].append(param)
|
||||
else:
|
||||
parent_to_parameters[parent_name] = [param]
|
||||
|
||||
parent_to_buffers = {}
|
||||
for name, buffer in buffers:
|
||||
parent_name = _find_parent_module_in_module_dict(name, module_dict)
|
||||
if parent_name in parent_to_buffers:
|
||||
parent_to_buffers[parent_name].append(buffer)
|
||||
else:
|
||||
parent_to_buffers[parent_name] = [buffer]
|
||||
|
||||
parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
|
||||
for name in parent_names:
|
||||
parameters = parent_to_parameters.get(name, [])
|
||||
buffers = parent_to_buffers.get(name, [])
|
||||
parent_module = module_dict[name]
|
||||
assert getattr(parent_module, "_diffusers_hook", None) is None
|
||||
group = ModuleGroup(
|
||||
modules=[],
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_leader=parent_module,
|
||||
onload_leader=parent_module,
|
||||
parameters=parameters,
|
||||
buffers=buffers,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None)
|
||||
|
||||
if stream is not None:
|
||||
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
|
||||
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
|
||||
# execution order and apply prefetching in the correct order.
|
||||
unmatched_group = ModuleGroup(
|
||||
modules=[],
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_leader=module,
|
||||
onload_leader=module,
|
||||
parameters=None,
|
||||
buffers=None,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
|
||||
|
||||
def _apply_group_offloading_hook(
|
||||
module: torch.nn.Module,
|
||||
group: ModuleGroup,
|
||||
next_group: Optional[ModuleGroup] = None,
|
||||
) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
|
||||
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
|
||||
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
|
||||
if registry.get_hook(_GROUP_OFFLOADING) is None:
|
||||
hook = GroupOffloadingHook(group, next_group)
|
||||
registry.register_hook(hook, _GROUP_OFFLOADING)
|
||||
|
||||
|
||||
def _apply_lazy_group_offloading_hook(
|
||||
module: torch.nn.Module,
|
||||
group: ModuleGroup,
|
||||
next_group: Optional[ModuleGroup] = None,
|
||||
) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
|
||||
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
|
||||
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
|
||||
if registry.get_hook(_GROUP_OFFLOADING) is None:
|
||||
hook = GroupOffloadingHook(group, next_group)
|
||||
registry.register_hook(hook, _GROUP_OFFLOADING)
|
||||
|
||||
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
|
||||
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
|
||||
|
||||
|
||||
def _gather_parameters_with_no_group_offloading_parent(
|
||||
module: torch.nn.Module, modules_with_group_offloading: Set[str]
|
||||
) -> List[torch.nn.Parameter]:
|
||||
parameters = []
|
||||
for name, parameter in module.named_parameters():
|
||||
has_parent_with_group_offloading = False
|
||||
atoms = name.split(".")
|
||||
while len(atoms) > 0:
|
||||
parent_name = ".".join(atoms)
|
||||
if parent_name in modules_with_group_offloading:
|
||||
has_parent_with_group_offloading = True
|
||||
break
|
||||
atoms.pop()
|
||||
if not has_parent_with_group_offloading:
|
||||
parameters.append((name, parameter))
|
||||
return parameters
|
||||
|
||||
|
||||
def _gather_buffers_with_no_group_offloading_parent(
|
||||
module: torch.nn.Module, modules_with_group_offloading: Set[str]
|
||||
) -> List[torch.Tensor]:
|
||||
buffers = []
|
||||
for name, buffer in module.named_buffers():
|
||||
has_parent_with_group_offloading = False
|
||||
atoms = name.split(".")
|
||||
while len(atoms) > 0:
|
||||
parent_name = ".".join(atoms)
|
||||
if parent_name in modules_with_group_offloading:
|
||||
has_parent_with_group_offloading = True
|
||||
break
|
||||
atoms.pop()
|
||||
if not has_parent_with_group_offloading:
|
||||
buffers.append((name, buffer))
|
||||
return buffers
|
||||
|
||||
|
||||
def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
|
||||
atoms = name.split(".")
|
||||
while len(atoms) > 0:
|
||||
parent_name = ".".join(atoms)
|
||||
if parent_name in module_dict:
|
||||
return parent_name
|
||||
atoms.pop()
|
||||
return ""
|
||||
|
||||
|
||||
def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None:
|
||||
if not is_accelerate_available():
|
||||
return
|
||||
for name, submodule in module.named_modules():
|
||||
if not hasattr(submodule, "_hf_hook"):
|
||||
continue
|
||||
if isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)):
|
||||
raise ValueError(
|
||||
f"Cannot apply group offloading to a module that is already applying an alternative "
|
||||
f"offloading strategy from Accelerate. If you want to apply group offloading, please "
|
||||
f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})"
|
||||
)
|
||||
|
||||
|
||||
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
|
||||
for submodule in module.modules():
|
||||
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
|
||||
for submodule in module.modules():
|
||||
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
|
||||
return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
|
||||
raise ValueError("Group offloading is not enabled for the provided module.")
|
||||
@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
_supports_group_offloading = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -68,6 +68,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
```
|
||||
"""
|
||||
|
||||
_supports_group_offloading = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -72,6 +72,7 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_skip_layerwise_casting_patterns = ["quantize"]
|
||||
_supports_group_offloading = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -34,7 +34,7 @@ from torch import Tensor, nn
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import __version__
|
||||
from ..hooks import apply_layerwise_casting
|
||||
from ..hooks import apply_group_offloading, apply_layerwise_casting
|
||||
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
|
||||
from ..quantizers.quantization_config import QuantizationMethod
|
||||
from ..utils import (
|
||||
@@ -87,7 +87,17 @@ if is_accelerate_available():
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
|
||||
from ..hooks.group_offloading import _get_group_onload_device
|
||||
|
||||
try:
|
||||
# Try to get the onload device from the group offloading hook
|
||||
return _get_group_onload_device(parameter)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
# If the onload device is not available due to no group offloading hooks, try to get the device
|
||||
# from the first parameter or buffer
|
||||
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
||||
return next(parameters_and_buffers).device
|
||||
except StopIteration:
|
||||
@@ -166,6 +176,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
_no_split_modules = None
|
||||
_keep_in_fp32_modules = None
|
||||
_skip_layerwise_casting_patterns = None
|
||||
_supports_group_offloading = True
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -437,6 +448,55 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
|
||||
)
|
||||
|
||||
def enable_group_offload(
|
||||
self,
|
||||
onload_device: torch.device,
|
||||
offload_device: torch.device = torch.device("cpu"),
|
||||
offload_type: str = "block_level",
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Activates group offloading for the current model.
|
||||
|
||||
See [`~hooks.group_offloading.apply_group_offloading`] for more information.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from diffusers import CogVideoXTransformer3DModel
|
||||
|
||||
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
|
||||
... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
|
||||
>>> transformer.enable_group_offload(
|
||||
... onload_device=torch.device("cuda"),
|
||||
... offload_device=torch.device("cpu"),
|
||||
... offload_type="leaf_level",
|
||||
... use_stream=True,
|
||||
... )
|
||||
```
|
||||
"""
|
||||
if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
|
||||
msg = (
|
||||
"Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
|
||||
"forward pass is executed with tiling enabled. Please make sure to either:\n"
|
||||
"1. Run a forward pass with small input shapes.\n"
|
||||
"2. Or, run a forward pass with tiling disabled (can still use small dummy inputs)."
|
||||
)
|
||||
logger.warning(msg)
|
||||
if not self._supports_group_offloading:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support group offloading. Please make sure to set the boolean attribute "
|
||||
f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
|
||||
f"open an issue at https://github.com/huggingface/diffusers/issues."
|
||||
)
|
||||
apply_group_offloading(
|
||||
self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
|
||||
)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
@@ -1170,6 +1230,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# Adapted from `transformers`.
|
||||
@wraps(torch.nn.Module.cuda)
|
||||
def cuda(self, *args, **kwargs):
|
||||
from ..hooks.group_offloading import _is_group_offload_enabled
|
||||
|
||||
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
if getattr(self, "is_loaded_in_8bit", False):
|
||||
@@ -1182,13 +1244,34 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
|
||||
)
|
||||
|
||||
# Checks if group offloading is enabled
|
||||
if _is_group_offload_enabled(self):
|
||||
logger.warning(
|
||||
f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.cuda()` is not supported."
|
||||
)
|
||||
return self
|
||||
|
||||
return super().cuda(*args, **kwargs)
|
||||
|
||||
# Adapted from `transformers`.
|
||||
@wraps(torch.nn.Module.to)
|
||||
def to(self, *args, **kwargs):
|
||||
from ..hooks.group_offloading import _is_group_offload_enabled
|
||||
|
||||
device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
|
||||
dtype_present_in_args = "dtype" in kwargs
|
||||
|
||||
# Try converting arguments to torch.device in case they are passed as strings
|
||||
for arg in args:
|
||||
if not isinstance(arg, str):
|
||||
continue
|
||||
try:
|
||||
torch.device(arg)
|
||||
device_arg_or_kwarg_present = True
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
if not dtype_present_in_args:
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.dtype):
|
||||
@@ -1213,6 +1296,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
"Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
|
||||
)
|
||||
|
||||
if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
|
||||
logger.warning(
|
||||
f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
|
||||
)
|
||||
return self
|
||||
|
||||
return super().to(*args, **kwargs)
|
||||
|
||||
# Taken from `transformers`.
|
||||
|
||||
@@ -66,6 +66,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_supports_gradient_checkpointing = True
|
||||
_supports_group_offloading = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -245,6 +245,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
|
||||
_supports_group_offloading = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -394,6 +394,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
device = device or device_arg
|
||||
device_type = torch.device(device).type if device is not None else None
|
||||
pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
|
||||
|
||||
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
|
||||
@@ -424,7 +425,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
|
||||
)
|
||||
|
||||
if device and torch.device(device).type == "cuda":
|
||||
if device_type == "cuda":
|
||||
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
|
||||
raise ValueError(
|
||||
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
|
||||
@@ -437,7 +438,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# Display a warning in this case (the operation succeeds but the benefits are lost)
|
||||
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
|
||||
if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
|
||||
if pipeline_is_offloaded and device_type == "cuda":
|
||||
logger.warning(
|
||||
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
|
||||
)
|
||||
@@ -449,6 +450,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
|
||||
for module in modules:
|
||||
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
|
||||
is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module)
|
||||
|
||||
if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
|
||||
logger.warning(
|
||||
@@ -460,11 +462,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
|
||||
)
|
||||
|
||||
# Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
|
||||
# components can be from outside diffusers too, but still have group offloading enabled.
|
||||
if (
|
||||
self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module)
|
||||
and device is not None
|
||||
):
|
||||
logger.warning(
|
||||
f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported."
|
||||
)
|
||||
|
||||
# This can happen for `transformer` models. CPU placement was added in
|
||||
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
|
||||
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
|
||||
module.to(device=device)
|
||||
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
|
||||
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
|
||||
module.to(device, dtype)
|
||||
|
||||
if (
|
||||
@@ -1023,6 +1035,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
|
||||
Accelerate's module hooks.
|
||||
"""
|
||||
from ..hooks.group_offloading import _get_group_onload_device
|
||||
|
||||
# When apply group offloading at the leaf_level, we're in the same situation as accelerate's sequential
|
||||
# offloading. We need to return the onload device of the group offloading hooks so that the intermediates
|
||||
# required for computation (latents, prompt embeddings, etc.) can be created on the correct device.
|
||||
for name, model in self.components.items():
|
||||
if not isinstance(model, torch.nn.Module):
|
||||
continue
|
||||
try:
|
||||
return _get_group_onload_device(model)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
for name, model in self.components.items():
|
||||
if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
|
||||
continue
|
||||
@@ -1061,6 +1086,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
|
||||
default to "cuda".
|
||||
"""
|
||||
self._maybe_raise_error_if_group_offload_active(raise_error=True)
|
||||
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
@@ -1172,6 +1199,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
|
||||
default to "cuda".
|
||||
"""
|
||||
self._maybe_raise_error_if_group_offload_active(raise_error=True)
|
||||
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
@@ -1896,6 +1925,24 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
return new_pipeline
|
||||
|
||||
def _maybe_raise_error_if_group_offload_active(
|
||||
self, raise_error: bool = False, module: Optional[torch.nn.Module] = None
|
||||
) -> bool:
|
||||
from ..hooks.group_offloading import _is_group_offload_enabled
|
||||
|
||||
components = self.components.values() if module is None else [module]
|
||||
components = [component for component in components if isinstance(component, torch.nn.Module)]
|
||||
for component in components:
|
||||
if _is_group_offload_enabled(component):
|
||||
if raise_error:
|
||||
raise ValueError(
|
||||
"You are trying to apply model/sequential CPU offloading to a pipeline that contains components "
|
||||
"with group offloading enabled. This is not supported. Please disable group offloading for "
|
||||
"components of the pipeline to use other offloading methods."
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class StableDiffusionMixin:
|
||||
r"""
|
||||
|
||||
214
tests/hooks/test_group_offloading.py
Normal file
214
tests/hooks/test_group_offloading.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.models import ModelMixin
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.utils import get_logger
|
||||
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
|
||||
|
||||
|
||||
class DummyBlock(torch.nn.Module):
|
||||
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.proj_in = torch.nn.Linear(in_features, hidden_features)
|
||||
self.activation = torch.nn.ReLU()
|
||||
self.proj_out = torch.nn.Linear(hidden_features, out_features)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj_in(x)
|
||||
x = self.activation(x)
|
||||
x = self.proj_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class DummyModel(ModelMixin):
|
||||
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = torch.nn.Linear(in_features, hidden_features)
|
||||
self.activation = torch.nn.ReLU()
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
|
||||
)
|
||||
self.linear_2 = torch.nn.Linear(hidden_features, out_features)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear_1(x)
|
||||
x = self.activation(x)
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class DummyPipeline(DiffusionPipeline):
|
||||
model_cpu_offload_seq = "model"
|
||||
|
||||
def __init__(self, model: torch.nn.Module) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(model=model)
|
||||
|
||||
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for _ in range(2):
|
||||
x = x + 0.1 * self.model(x)
|
||||
return x
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class GroupOffloadTests(unittest.TestCase):
|
||||
in_features = 64
|
||||
hidden_features = 256
|
||||
out_features = 64
|
||||
num_layers = 4
|
||||
|
||||
def setUp(self):
|
||||
with torch.no_grad():
|
||||
self.model = self.get_model()
|
||||
self.input = torch.randn((4, self.in_features)).to(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
del self.model
|
||||
del self.input
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
def get_model(self):
|
||||
torch.manual_seed(0)
|
||||
return DummyModel(
|
||||
in_features=self.in_features,
|
||||
hidden_features=self.hidden_features,
|
||||
out_features=self.out_features,
|
||||
num_layers=self.num_layers,
|
||||
)
|
||||
|
||||
def test_offloading_forward_pass(self):
|
||||
@torch.no_grad()
|
||||
def run_forward(model):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self.assertTrue(
|
||||
all(
|
||||
module._diffusers_hook.get_hook("group_offloading") is not None
|
||||
for module in model.modules()
|
||||
if hasattr(module, "_diffusers_hook")
|
||||
)
|
||||
)
|
||||
model.eval()
|
||||
output = model(self.input)[0].cpu()
|
||||
max_memory_allocated = torch.cuda.max_memory_allocated()
|
||||
return output, max_memory_allocated
|
||||
|
||||
self.model.to(torch_device)
|
||||
output_without_group_offloading, mem_baseline = run_forward(self.model)
|
||||
self.model.to("cpu")
|
||||
|
||||
model = self.get_model()
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
output_with_group_offloading1, mem1 = run_forward(model)
|
||||
|
||||
model = self.get_model()
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
|
||||
output_with_group_offloading2, mem2 = run_forward(model)
|
||||
|
||||
model = self.get_model()
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
|
||||
output_with_group_offloading3, mem3 = run_forward(model)
|
||||
|
||||
model = self.get_model()
|
||||
model.enable_group_offload(torch_device, offload_type="leaf_level")
|
||||
output_with_group_offloading4, mem4 = run_forward(model)
|
||||
|
||||
model = self.get_model()
|
||||
model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True)
|
||||
output_with_group_offloading5, mem5 = run_forward(model)
|
||||
|
||||
# Precision assertions - offloading should not impact the output
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))
|
||||
|
||||
# Memory assertions - offloading should reduce memory usage
|
||||
self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline)
|
||||
|
||||
def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
|
||||
if torch.device(torch_device).type != "cuda":
|
||||
return
|
||||
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
logger = get_logger("diffusers.models.modeling_utils")
|
||||
logger.setLevel("INFO")
|
||||
with self.assertLogs(logger, level="WARNING") as cm:
|
||||
self.model.to(torch_device)
|
||||
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
|
||||
|
||||
def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
|
||||
if torch.device(torch_device).type != "cuda":
|
||||
return
|
||||
pipe = DummyPipeline(self.model)
|
||||
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
logger = get_logger("diffusers.pipelines.pipeline_utils")
|
||||
logger.setLevel("INFO")
|
||||
with self.assertLogs(logger, level="WARNING") as cm:
|
||||
pipe.to(torch_device)
|
||||
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
|
||||
|
||||
def test_error_raised_if_streams_used_and_no_cuda_device(self):
|
||||
original_is_available = torch.cuda.is_available
|
||||
torch.cuda.is_available = lambda: False
|
||||
with self.assertRaises(ValueError):
|
||||
self.model.enable_group_offload(
|
||||
onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True
|
||||
)
|
||||
torch.cuda.is_available = original_is_available
|
||||
|
||||
def test_error_raised_if_supports_group_offloading_false(self):
|
||||
self.model._supports_group_offloading = False
|
||||
with self.assertRaisesRegex(ValueError, "does not support group offloading"):
|
||||
self.model.enable_group_offload(onload_device=torch.device("cuda"))
|
||||
|
||||
def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
|
||||
pipe = DummyPipeline(self.model)
|
||||
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self):
|
||||
pipe = DummyPipeline(self.model)
|
||||
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self):
|
||||
pipe = DummyPipeline(self.model)
|
||||
pipe.enable_model_cpu_offload()
|
||||
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
|
||||
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
|
||||
def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self):
|
||||
pipe = DummyPipeline(self.model)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
|
||||
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
@@ -1458,6 +1458,55 @@ class ModelTesterMixin:
|
||||
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
|
||||
)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_group_offloading(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
torch.manual_seed(0)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_forward(model):
|
||||
self.assertTrue(
|
||||
all(
|
||||
module._diffusers_hook.get_hook("group_offloading") is not None
|
||||
for module in model.modules()
|
||||
if hasattr(module, "_diffusers_hook")
|
||||
)
|
||||
)
|
||||
model.eval()
|
||||
return model(**inputs_dict)[0]
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
if not getattr(model, "_supports_group_offloading", True):
|
||||
return
|
||||
|
||||
model.to(torch_device)
|
||||
output_without_group_offloading = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
|
||||
output_with_group_offloading1 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
|
||||
output_with_group_offloading2 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="leaf_level")
|
||||
output_with_group_offloading3 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True)
|
||||
output_with_group_offloading4 = run_forward(model)
|
||||
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ModelPushToHubTester(unittest.TestCase):
|
||||
|
||||
@@ -58,6 +58,7 @@ class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTes
|
||||
)
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -39,6 +39,7 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -61,6 +61,7 @@ class AnimateDiffPipelineFastTests(
|
||||
]
|
||||
)
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
cross_attention_dim = 8
|
||||
|
||||
@@ -31,6 +31,7 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -60,6 +60,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastT
|
||||
)
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -56,6 +56,7 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
)
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -57,6 +57,7 @@ class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -59,6 +59,7 @@ class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -127,6 +127,7 @@ class ControlNetPipelineFastTests(
|
||||
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -76,6 +76,7 @@ class StableDiffusionXLControlNetPipelineFastTests(
|
||||
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -51,6 +51,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -60,6 +60,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(
|
||||
self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm", use_dual_attention=False
|
||||
|
||||
@@ -140,6 +140,7 @@ class ControlNetXSPipelineFastTests(
|
||||
|
||||
test_attention_slicing = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -79,6 +79,7 @@ class StableDiffusionXLControlNetXSPipelineFastTests(
|
||||
|
||||
test_attention_slicing = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -35,6 +35,7 @@ class FluxPipelineFastTests(
|
||||
# there is no xformers processor for Flux
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -23,6 +23,7 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
# there is no xformers processor for Flux
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -24,6 +24,7 @@ class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
batch_params = frozenset(["prompt"])
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -54,6 +54,7 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca
|
||||
# there is no xformers processor for Flux
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -54,6 +54,7 @@ class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTeste
|
||||
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
pab_config = PyramidAttentionBroadcastConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
|
||||
@@ -47,6 +47,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -33,6 +33,7 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM
|
||||
|
||||
supports_dduf = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -56,6 +56,7 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -56,6 +56,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr
|
||||
]
|
||||
)
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
cross_attention_dim = 8
|
||||
|
||||
@@ -51,6 +51,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -56,6 +56,7 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -53,6 +53,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -124,6 +124,7 @@ class StableDiffusionPipelineFastTests(
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||
cross_attention_dim = 8
|
||||
|
||||
@@ -76,6 +76,7 @@ class StableDiffusion2PipelineFastTests(
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -36,6 +36,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -76,6 +76,7 @@ class StableDiffusionXLPipelineFastTests(
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -29,6 +29,7 @@ from diffusers import (
|
||||
StableDiffusionXLPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
|
||||
@@ -47,6 +48,7 @@ from diffusers.utils.testing_utils import (
|
||||
require_accelerator,
|
||||
require_hf_hub_version_greater,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_transformers_version_greater,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
@@ -990,6 +992,7 @@ class PipelineTesterMixin:
|
||||
|
||||
test_xformers_attention = True
|
||||
test_layerwise_casting = False
|
||||
test_group_offloading = False
|
||||
supports_dduf = True
|
||||
|
||||
def get_generator(self, seed):
|
||||
@@ -2044,6 +2047,79 @@ class PipelineTesterMixin:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
_ = pipe(**inputs)[0]
|
||||
|
||||
@require_torch_gpu
|
||||
def test_group_offloading_inference(self):
|
||||
if not self.test_group_offloading:
|
||||
return
|
||||
|
||||
def create_pipe():
|
||||
torch.manual_seed(0)
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
return pipe
|
||||
|
||||
def enable_group_offload_on_component(pipe, group_offloading_kwargs):
|
||||
# We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If
|
||||
# tiling is enabled and a forward pass is run, when cuda streams are used, the execution order of
|
||||
# the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a
|
||||
# warmup forward pass (even with dummy small inputs) is recommended.
|
||||
for component_name in [
|
||||
"text_encoder",
|
||||
"text_encoder_2",
|
||||
"text_encoder_3",
|
||||
"transformer",
|
||||
"unet",
|
||||
"controlnet",
|
||||
]:
|
||||
if not hasattr(pipe, component_name):
|
||||
continue
|
||||
component = getattr(pipe, component_name)
|
||||
if not getattr(component, "_supports_group_offloading", True):
|
||||
continue
|
||||
if hasattr(component, "enable_group_offload"):
|
||||
# For diffusers ModelMixin implementations
|
||||
component.enable_group_offload(torch.device(torch_device), **group_offloading_kwargs)
|
||||
else:
|
||||
# For other models not part of diffusers
|
||||
apply_group_offloading(
|
||||
component, onload_device=torch.device(torch_device), **group_offloading_kwargs
|
||||
)
|
||||
self.assertTrue(
|
||||
all(
|
||||
module._diffusers_hook.get_hook("group_offloading") is not None
|
||||
for module in component.modules()
|
||||
if hasattr(module, "_diffusers_hook")
|
||||
)
|
||||
)
|
||||
for component_name in ["vae", "vqvae"]:
|
||||
if hasattr(pipe, component_name):
|
||||
getattr(pipe, component_name).to(torch_device)
|
||||
|
||||
def run_forward(pipe):
|
||||
torch.manual_seed(0)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
return pipe(**inputs)[0]
|
||||
|
||||
pipe = create_pipe().to(torch_device)
|
||||
output_without_group_offloading = run_forward(pipe)
|
||||
|
||||
pipe = create_pipe()
|
||||
enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1})
|
||||
output_with_group_offloading1 = run_forward(pipe)
|
||||
|
||||
pipe = create_pipe()
|
||||
enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"})
|
||||
output_with_group_offloading2 = run_forward(pipe)
|
||||
|
||||
if torch.is_tensor(output_without_group_offloading):
|
||||
output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
|
||||
output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy()
|
||||
output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
|
||||
self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class PipelinePushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user