
[core] Pyramid Attention Broadcast (#9562)

* start pyramid attention broadcast

* add coauthor

Co-Authored-By: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>

* update

* make style

* update

* make style

* add docs

* add tests

* update

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Pyramid Attention Broadcast rewrite + introduce hooks (#9826)

* rewrite implementation with hooks

* make style

* update

* merge pyramid-attention-rewrite-2

* make style

* remove changes from latte transformer

* revert docs changes

* better debug message

* add todos for future

* update tests

* make style

* cleanup

* fix

* improve log message; fix latte test

* refactor

* update

* update

* update

* revert changes to tests

* update docs

* update tests

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* fix flux test

* reorder

* refactor

* make fix-copies

* update docs

* fixes

* more fixes

* make style

* update tests

* update code example

* make fix-copies

* refactor based on reviews

* use maybe_free_model_hooks

* CacheMixin

* make style

* update

* add current_timestep property; update docs

* make fix-copies

* update

* improve tests

* try circular import fix

* apply suggestions from review

* address review comments

* Apply suggestions from code review

* refactor hook implementation

* add test suite for hooks

* PAB Refactor (#10667)

* update

* update

* update

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>

* update

* fix remove hook behaviour

---------

Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: DN6 <dhruv.nair@gmail.com>
Author: Aryan
Date: 2025-01-28 05:09:04 +05:30
Committed by: GitHub
Parent: fb42066489
Commit: 658e24e86c
32 changed files with 1257 additions and 68 deletions


@@ -598,6 +598,8 @@
title: Attention Processor
- local: api/activations
title: Custom activation functions
- local: api/cache
title: Caching methods
- local: api/normalization
title: Custom normalization layers
- local: api/utilities


@@ -0,0 +1,49 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# Caching methods
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. Attention states change little between successive steps: the difference is most prominent in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross-attention blocks. Cross-attention computations can therefore be skipped most often, followed by the temporal and then the spatial attention blocks. Combined with other techniques such as sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leading to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
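For intuition, the comment above corresponds to the decision rule implemented in `PyramidAttentionBroadcastHook.new_forward`; below is a minimal standalone sketch of that rule with illustrative numbers only.
```python
# Standalone sketch of the PAB skip rule (illustrative values only).
block_skip_range = 2              # spatial_attention_block_skip_range
timestep_skip_range = (100, 800)  # spatial_attention_timestep_skip_range


def should_compute_attention(iteration: int, current_timestep: int, cache_is_empty: bool) -> bool:
    is_within_timestep_range = timestep_skip_range[0] < current_timestep < timestep_skip_range[1]
    return (
        cache_is_empty
        or iteration == 0
        or not is_within_timestep_range
        or iteration % block_skip_range == 0
    )


# Inside the active timestep window, attention is recomputed on every second call
# and the cached states are broadcast (re-used) in between.
for i in range(6):
    print(i, should_compute_attention(i, current_timestep=500, cache_is_empty=(i == 0)))
```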
### CacheMixin
[[autodoc]] CacheMixin
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast


@@ -28,6 +28,7 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
@@ -75,6 +76,13 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
_import_structure["hooks"].extend(
[
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_pyramid_attention_broadcast",
]
)
_import_structure["models"].extend(
[
"AllegroTransformer3DModel",
@@ -90,6 +98,7 @@ else:
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
"AutoencoderTiny",
"CacheMixin",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"ConsisIDTransformer3DModel",
@@ -588,6 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
@@ -602,6 +612,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
CacheMixin,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
ConsisIDTransformer3DModel,


@@ -2,4 +2,6 @@ from ..utils import is_torch_available
if is_torch_available():
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast


@@ -30,6 +30,9 @@ class ModelHook:
_is_stateful = False
def __init__(self):
self.fn_ref: "HookFunctionReference" = None
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is initialized.
@@ -48,8 +51,6 @@ class ModelHook:
module (`torch.nn.Module`):
The module attached to this hook.
"""
module.forward = module._old_forward
del module._old_forward
return module
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
@@ -99,6 +100,29 @@ class ModelHook:
return module
class HookFunctionReference:
def __init__(self) -> None:
"""A container class that maintains mutable references to forward pass functions in a hook chain.
Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
entire forward pass structure.
Attributes:
pre_forward: A callable that processes inputs before the main forward pass.
post_forward: A callable that processes outputs after the main forward pass.
forward: The current forward function in the hook chain.
original_forward: The original forward function, stored when a hook provides a custom new_forward.
The class enables hook removal by allowing updates to the forward chain through reference modification rather
than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to
be updated, preserving the execution order of the remaining hooks.
"""
self.pre_forward = None
self.post_forward = None
self.forward = None
self.original_forward = None
class HookRegistry:
def __init__(self, module_ref: torch.nn.Module) -> None:
super().__init__()
@@ -107,51 +131,71 @@ class HookRegistry:
self._module_ref = module_ref
self._hook_order = []
self._fn_refs = []
def register_hook(self, hook: ModelHook, name: str) -> None:
if name in self.hooks.keys():
logger.warning(f"Hook with name {name} already exists, replacing it.")
if hasattr(self._module_ref, "_old_forward"):
old_forward = self._module_ref._old_forward
else:
old_forward = self._module_ref.forward
self._module_ref._old_forward = self._module_ref.forward
raise ValueError(
f"Hook with name {name} already exists in the registry. Please use a different name or "
f"first remove the existing hook and then add a new one."
)
self._module_ref = hook.initialize_hook(self._module_ref)
def create_new_forward(function_reference: HookFunctionReference):
def new_forward(module, *args, **kwargs):
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
output = function_reference.forward(*args, **kwargs)
return function_reference.post_forward(module, output)
return new_forward
forward = self._module_ref.forward
fn_ref = HookFunctionReference()
fn_ref.pre_forward = hook.pre_forward
fn_ref.post_forward = hook.post_forward
fn_ref.forward = forward
if hasattr(hook, "new_forward"):
rewritten_forward = hook.new_forward
def new_forward(module, *args, **kwargs):
args, kwargs = hook.pre_forward(module, *args, **kwargs)
output = rewritten_forward(module, *args, **kwargs)
return hook.post_forward(module, output)
else:
def new_forward(module, *args, **kwargs):
args, kwargs = hook.pre_forward(module, *args, **kwargs)
output = old_forward(*args, **kwargs)
return hook.post_forward(module, output)
fn_ref.original_forward = forward
fn_ref.forward = functools.update_wrapper(
functools.partial(hook.new_forward, self._module_ref), hook.new_forward
)
rewritten_forward = create_new_forward(fn_ref)
self._module_ref.forward = functools.update_wrapper(
functools.partial(new_forward, self._module_ref), old_forward
functools.partial(rewritten_forward, self._module_ref), rewritten_forward
)
hook.fn_ref = fn_ref
self.hooks[name] = hook
self._hook_order.append(name)
self._fn_refs.append(fn_ref)
def get_hook(self, name: str) -> Optional[ModelHook]:
if name not in self.hooks.keys():
return None
return self.hooks[name]
return self.hooks.get(name, None)
def remove_hook(self, name: str, recurse: bool = True) -> None:
if name in self.hooks.keys():
num_hooks = len(self._hook_order)
hook = self.hooks[name]
index = self._hook_order.index(name)
fn_ref = self._fn_refs[index]
old_forward = fn_ref.forward
if fn_ref.original_forward is not None:
old_forward = fn_ref.original_forward
if index == num_hooks - 1:
self._module_ref.forward = old_forward
else:
self._fn_refs[index + 1].forward = old_forward
self._module_ref = hook.deinitalize_hook(self._module_ref)
del self.hooks[name]
self._hook_order.remove(name)
self._hook_order.pop(index)
self._fn_refs.pop(index)
if recurse:
for module_name, module in self._module_ref.named_modules():
@@ -161,7 +205,7 @@ class HookRegistry:
module._diffusers_hook.remove_hook(name, recurse=False)
def reset_stateful_hooks(self, recurse: bool = True) -> None:
for hook_name in self._hook_order:
for hook_name in reversed(self._hook_order):
hook = self.hooks[hook_name]
if hook._is_stateful:
hook.reset_state(self._module_ref)
@@ -180,9 +224,13 @@ class HookRegistry:
return module._diffusers_hook
def __repr__(self) -> str:
hook_repr = ""
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):
hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
if self.hooks[hook_name].__class__.__repr__ is not object.__repr__:
hook_repr = self.hooks[hook_name].__repr__()
else:
hook_repr = self.hooks[hook_name].__class__.__name__
registry_repr += f" ({i}) {hook_name} - {hook_repr}"
if i < len(self._hook_order) - 1:
hook_repr += "\n"
return f"HookRegistry(\n{hook_repr}\n)"
registry_repr += "\n"
return f"HookRegistry(\n{registry_repr}\n)"


@@ -0,0 +1,314 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union
import torch
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@dataclass
class PyramidAttentionBroadcastConfig:
r"""
Configuration for Pyramid Attention Broadcast.
Args:
spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific spatial attention broadcast is skipped before computing the attention states
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
old attention states will be re-used) before computing the new attention states again.
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific temporal attention broadcast is skipped before computing the attention
states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
(i.e., old attention states will be re-used) before computing the new attention states again.
cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific cross-attention broadcast is skipped before computing the attention states
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
old attention states will be re-used) before computing the new attention states again.
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the spatial attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the temporal attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the cross-attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
"""
spatial_attention_block_skip_range: Optional[int] = None
temporal_attention_block_skip_range: Optional[int] = None
cross_attention_block_skip_range: Optional[int] = None
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
current_timestep_callback: Callable[[], int] = None
# TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase
# so not added for now)
def __repr__(self) -> str:
return (
f"PyramidAttentionBroadcastConfig("
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n"
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
f" cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n"
f" current_timestep_callback={self.current_timestep_callback}\n"
")"
)
class PyramidAttentionBroadcastState:
r"""
State for Pyramid Attention Broadcast.
Attributes:
iteration (`int`):
The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
called before starting a new inference forward pass for PAB to work correctly.
cache (`Any`):
The cached output from the previous forward pass. This is used to re-use the attention states when the
attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module.
"""
def __init__(self) -> None:
self.iteration = 0
self.cache = None
def reset(self):
self.iteration = 0
self.cache = None
def __repr__(self):
cache_repr = ""
if self.cache is None:
cache_repr = "None"
else:
cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})"
return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})"
class PyramidAttentionBroadcastHook(ModelHook):
r"""A hook that applies Pyramid Attention Broadcast to a given module."""
_is_stateful = True
def __init__(
self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
) -> None:
super().__init__()
self.timestep_skip_range = timestep_skip_range
self.block_skip_range = block_skip_range
self.current_timestep_callback = current_timestep_callback
def initialize_hook(self, module):
self.state = PyramidAttentionBroadcastState()
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
is_within_timestep_range = (
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
)
should_compute_attention = (
self.state.cache is None
or self.state.iteration == 0
or not is_within_timestep_range
or self.state.iteration % self.block_skip_range == 0
)
if should_compute_attention:
output = self.fn_ref.original_forward(*args, **kwargs)
else:
output = self.state.cache
self.state.cache = output
self.state.iteration += 1
return output
def reset_state(self, module: torch.nn.Module) -> None:
self.state.reset()
return module
def apply_pyramid_attention_broadcast(
module: torch.nn.Module,
config: PyramidAttentionBroadcastConfig,
):
r"""
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given `torch.nn.Module` (typically a pipeline's transformer).
PAB is an attention approximation method that leverages the similarity in attention states between timesteps to
reduce the computational cost of attention computation. The key takeaway from the paper is that the attention
similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal and
spatial layers. This allows for the skipping of attention computation in the cross-attention layers more frequently
than in the temporal and spatial layers. Applying PAB will, therefore, speed up the inference process.
Args:
module (`torch.nn.Module`):
The module to apply Pyramid Attention Broadcast to.
config (`PyramidAttentionBroadcastConfig`):
The configuration to use for Pyramid Attention Broadcast.
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
>>> from diffusers.utils import export_to_video
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(100, 800),
... current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> apply_pyramid_attention_broadcast(pipe.transformer, config)
```
"""
if config.current_timestep_callback is None:
raise ValueError(
"The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast."
)
if (
config.spatial_attention_block_skip_range is None
and config.temporal_attention_block_skip_range is None
and config.cross_attention_block_skip_range is None
):
logger.warning(
"Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
"or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
"To avoid this warning, please set one of the above parameters."
)
config.spatial_attention_block_skip_range = 2
for name, submodule in module.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES):
# PAB has been implemented specifically for Diffusers' Attention classes. However, this does not mean that PAB
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
continue
_apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)
def _apply_pyramid_attention_broadcast_on_attention_class(
name: str, module: Attention, config: PyramidAttentionBroadcastConfig
) -> bool:
is_spatial_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
and config.spatial_attention_block_skip_range is not None
and not getattr(module, "is_cross_attention", False)
)
is_temporal_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
and config.temporal_attention_block_skip_range is not None
and not getattr(module, "is_cross_attention", False)
)
is_cross_attention = (
any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers)
and config.cross_attention_block_skip_range is not None
and getattr(module, "is_cross_attention", False)
)
block_skip_range, timestep_skip_range, block_type = None, None, None
if is_spatial_self_attention:
block_skip_range = config.spatial_attention_block_skip_range
timestep_skip_range = config.spatial_attention_timestep_skip_range
block_type = "spatial"
elif is_temporal_self_attention:
block_skip_range = config.temporal_attention_block_skip_range
timestep_skip_range = config.temporal_attention_timestep_skip_range
block_type = "temporal"
elif is_cross_attention:
block_skip_range = config.cross_attention_block_skip_range
timestep_skip_range = config.cross_attention_timestep_skip_range
block_type = "cross"
if block_skip_range is None or timestep_skip_range is None:
logger.info(
f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does '
f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, "
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
f"block identifiers in the configuration."
)
return False
logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
_apply_pyramid_attention_broadcast_hook(
module, timestep_skip_range, block_skip_range, config.current_timestep_callback
)
return True
def _apply_pyramid_attention_broadcast_hook(
module: Union[Attention, MochiAttention],
timestep_skip_range: Tuple[int, int],
block_skip_range: int,
current_timestep_callback: Callable[[], int],
):
r"""
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.
Args:
module (`torch.nn.Module`):
The module to apply Pyramid Attention Broadcast to.
timestep_skip_range (`Tuple[int, int]`):
The range of timesteps to skip in the attention layer. The attention computations will be conditionally
skipped if the current timestep is within the specified range.
block_skip_range (`int`):
The number of times a specific attention broadcast is skipped before computing the attention states to
re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
attention states will be re-used) before computing the new attention states again.
current_timestep_callback (`Callable[[], int]`):
A callback function that returns the current inference timestep.
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
registry.register_hook(hook, "pyramid_attention_broadcast")


@@ -39,6 +39,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["cache_utils"] = ["CacheMixin"]
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
_import_structure["controlnets.controlnet_hunyuan"] = [
@@ -109,6 +110,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ConsistencyDecoderVAE,
VQModel,
)
from .cache_utils import CacheMixin
from .controlnets import (
ControlNetModel,
ControlNetUnionModel,


@@ -0,0 +1,89 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils.logging import get_logger
logger = get_logger(__name__) # pylint: disable=invalid-name
class CacheMixin:
r"""
A class for enabling and disabling caching techniques on diffusion models.
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
"""
_cache_config = None
@property
def is_cache_enabled(self) -> bool:
return self._cache_config is not None
def enable_cache(self, config) -> None:
r"""
Enable caching techniques on the model.
Args:
config (`Union[PyramidAttentionBroadcastConfig]`):
The configuration for applying the caching technique. Currently supported caching techniques are:
- [`~hooks.PyramidAttentionBroadcastConfig`]
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(100, 800),
... current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> pipe.transformer.enable_cache(config)
```
"""
from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
if isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook("pyramid_attention_broadcast", recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
self._cache_config = None
def _reset_stateful_cache(self, recurse: bool = True) -> None:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
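Putting the mixin together, a minimal sketch of the full cache lifecycle, continuing the CogVideoX example from the `enable_cache` docstring above:
```python
import torch

from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)

pipe.transformer.enable_cache(config)
assert pipe.transformer.is_cache_enabled

# ... run inference with `pipe(...)` here ...

pipe.transformer.disable_cache()  # removes the "pyramid_attention_broadcast" hooks again
assert not pipe.transformer.is_cache_enabled
```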


@@ -24,6 +24,7 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_lay
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -156,7 +157,7 @@ class CogVideoXBlock(nn.Module):
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).


@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
@@ -19,13 +20,14 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..attention import BasicTransformerBlock
from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
class LatteTransformer3DModel(ModelMixin, ConfigMixin):
class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True
"""


@@ -24,6 +24,7 @@ from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import AllegroAttnProcessor2_0, Attention
from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -172,7 +173,7 @@ class AllegroTransformerBlock(nn.Module):
return hidden_states
class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True
"""


@@ -35,6 +35,7 @@ from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, Ad
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -227,7 +228,7 @@ class FluxTransformerBlock(nn.Module):
class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
"""
The Transformer model introduced in Flux.


@@ -25,6 +25,7 @@ from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
@@ -502,7 +503,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).


@@ -25,6 +25,7 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_lay
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -305,7 +306,7 @@ class MochiRoPE(nn.Module):
@maybe_allow_in_graph
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).


@@ -683,6 +683,10 @@ class AllegroPipeline(DiffusionPipeline):
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@@ -815,6 +819,7 @@ class AllegroPipeline(DiffusionPipeline):
negative_prompt_attention_mask,
)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
# 2. Default height and width to transformer
@@ -892,6 +897,7 @@ class AllegroPipeline(DiffusionPipeline):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -933,6 +939,8 @@ class AllegroPipeline(DiffusionPipeline):
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
video = self.decode_latents(latents)


@@ -494,6 +494,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@@ -627,6 +631,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Default call parameters
@@ -705,6 +710,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -763,6 +769,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]


@@ -540,6 +540,10 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@@ -680,6 +684,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Default call parameters
@@ -766,6 +771,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -818,6 +824,8 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)


@@ -591,6 +591,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@@ -728,6 +732,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
negative_prompt_embeds=negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._attention_kwargs = attention_kwargs
self._interrupt = False
@@ -815,6 +820,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -877,6 +883,8 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]


@@ -564,6 +564,10 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@@ -700,6 +704,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Default call parameters
@@ -786,6 +791,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -844,6 +850,8 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)


@@ -28,8 +28,7 @@ from transformers import (
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...models import AutoencoderKL, FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
@@ -620,6 +619,10 @@ class FluxPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@@ -775,6 +778,7 @@ class FluxPipeline(
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
@@ -899,6 +903,7 @@ class FluxPipeline(
if self.interrupt:
continue
self._current_timestep = t
if image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -957,9 +962,10 @@ class FluxPipeline(
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor


@@ -456,6 +456,10 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@@ -577,6 +581,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
device = self._execution_device
@@ -644,6 +649,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -678,6 +684,8 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
video = self.vae.decode(latents, return_dict=False)[0]


@@ -602,6 +602,10 @@ class LattePipeline(DiffusionPipeline):
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@@ -633,7 +637,7 @@ class LattePipeline(DiffusionPipeline):
clean_caption: bool = True,
mask_feature: bool = True,
enable_temporal_attentions: bool = True,
decode_chunk_size: Optional[int] = None,
decode_chunk_size: int = 14,
) -> Union[LattePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -729,6 +733,7 @@ class LattePipeline(DiffusionPipeline):
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
# 2. Default height and width to transformer
@@ -790,6 +795,7 @@ class LattePipeline(DiffusionPipeline):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -850,6 +856,8 @@ class LattePipeline(DiffusionPipeline):
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latents":
deprecation_message = (
"Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead."
@@ -858,7 +866,7 @@ class LattePipeline(DiffusionPipeline):
output_type = "latent"
if not output_type == "latent":
video = self.decode_latents(latents, video_length, decode_chunk_size=14)
video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents


@@ -21,8 +21,7 @@ from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import Mochi1LoraLoaderMixin
from ...models.autoencoders import AutoencoderKLMochi
from ...models.transformers import MochiTransformer3DModel
from ...models import AutoencoderKLMochi, MochiTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
is_torch_xla_available,
@@ -467,6 +466,10 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@@ -591,6 +594,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
@@ -660,6 +664,9 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
if self.interrupt:
continue
# Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need
# to make sure we're using the correct non-reversed timestep values.
self._current_timestep = 1000 - t
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
@@ -705,6 +712,8 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
video = latents
else:


@@ -1133,11 +1133,20 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
def maybe_free_model_hooks(self):
r"""
Function that offloads all components, removes all model hooks that were added when using
`enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function
is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
functions correctly when applying enable_model_cpu_offload.
Method that performs the following:
- Offloads all components.
- Removes all model hooks that were added when using `enable_model_cpu_offload`, and then applies them again.
In case the model has not been offloaded, this function is a no-op.
- Resets stateful diffusers hooks of denoiser components if they were added with
[`~hooks.HookRegistry.register_hook`].
Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions
correctly when applying `enable_model_cpu_offload`.
"""
for component in self.components.values():
if hasattr(component, "_reset_stateful_cache"):
component._reset_stateful_cache()
if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
# `enable_model_cpu_offload` has not been called, so silently do nothing
return


@@ -2,6 +2,40 @@
from ..utils import DummyObject, requires_backends
class HookRegistry(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
class AllegroTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -197,6 +231,21 @@ class AutoencoderTiny(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class CacheMixin(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class CogVideoXTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

tests/hooks/test_hooks.py (new file)

@@ -0,0 +1,382 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import torch
from diffusers.hooks import HookRegistry, ModelHook
from diffusers.training_utils import free_memory
from diffusers.utils.logging import get_logger
from diffusers.utils.testing_utils import CaptureLogger, torch_device
logger = get_logger(__name__) # pylint: disable=invalid-name
class DummyBlock(torch.nn.Module):
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
super().__init__()
self.proj_in = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.proj_out = torch.nn.Linear(hidden_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj_in(x)
x = self.activation(x)
x = self.proj_out(x)
return x
class DummyModel(torch.nn.Module):
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
super().__init__()
self.linear_1 = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.blocks = torch.nn.ModuleList(
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
)
self.linear_2 = torch.nn.Linear(hidden_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x)
x = self.activation(x)
for block in self.blocks:
x = block(x)
x = self.linear_2(x)
return x
class AddHook(ModelHook):
def __init__(self, value: int):
super().__init__()
self.value = value
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
logger.debug("AddHook pre_forward")
args = ((x + self.value) if torch.is_tensor(x) else x for x in args)
return args, kwargs
def post_forward(self, module, output):
logger.debug("AddHook post_forward")
return output
class MultiplyHook(ModelHook):
def __init__(self, value: int):
super().__init__()
self.value = value
def pre_forward(self, module, *args, **kwargs):
logger.debug("MultiplyHook pre_forward")
args = ((x * self.value) if torch.is_tensor(x) else x for x in args)
return args, kwargs
def post_forward(self, module, output):
logger.debug("MultiplyHook post_forward")
return output
def __repr__(self):
return f"MultiplyHook(value={self.value})"
class StatefulAddHook(ModelHook):
_is_stateful = True
def __init__(self, value: int):
super().__init__()
self.value = value
self.increment = 0
def pre_forward(self, module, *args, **kwargs):
logger.debug("StatefulAddHook pre_forward")
add_value = self.value + self.increment
self.increment += 1
args = ((x + add_value) if torch.is_tensor(x) else x for x in args)
return args, kwargs
def reset_state(self, module):
self.increment = 0
class SkipLayerHook(ModelHook):
def __init__(self, skip_layer: bool):
super().__init__()
self.skip_layer = skip_layer
def pre_forward(self, module, *args, **kwargs):
logger.debug("SkipLayerHook pre_forward")
return args, kwargs
def new_forward(self, module, *args, **kwargs):
logger.debug("SkipLayerHook new_forward")
if self.skip_layer:
return args[0]
return self.fn_ref.original_forward(*args, **kwargs)
def post_forward(self, module, output):
logger.debug("SkipLayerHook post_forward")
return output
class HookTests(unittest.TestCase):
in_features = 4
hidden_features = 8
out_features = 4
num_layers = 2
def setUp(self):
params = self.get_module_parameters()
self.model = DummyModel(**params)
self.model.to(torch_device)
def tearDown(self):
super().tearDown()
del self.model
gc.collect()
free_memory()
def get_module_parameters(self):
return {
"in_features": self.in_features,
"hidden_features": self.hidden_features,
"out_features": self.out_features,
"num_layers": self.num_layers,
}
def get_generator(self):
return torch.manual_seed(0)
def test_hook_registry(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry.register_hook(AddHook(1), "add_hook")
registry.register_hook(MultiplyHook(2), "multiply_hook")
registry_repr = repr(registry)
expected_repr = (
"HookRegistry(\n" " (0) add_hook - AddHook\n" " (1) multiply_hook - MultiplyHook(value=2)\n" ")"
)
self.assertEqual(len(registry.hooks), 2)
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
self.assertEqual(registry_repr, expected_repr)
registry.remove_hook("add_hook")
self.assertEqual(len(registry.hooks), 1)
self.assertEqual(registry._hook_order, ["multiply_hook"])
def test_stateful_hook(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry.register_hook(StatefulAddHook(1), "stateful_add_hook")
self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
num_repeats = 3
for i in range(num_repeats):
result = self.model(input)
if i == 0:
output1 = result
self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)
registry.reset_stateful_hooks()
output2 = self.model(input)
self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
self.assertTrue(torch.allclose(output1, output2))
def test_inference(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry.register_hook(AddHook(1), "add_hook")
registry.register_hook(MultiplyHook(2), "multiply_hook")
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
output1 = self.model(input).mean().detach().cpu().item()
registry.remove_hook("multiply_hook")
new_input = input * 2
output2 = self.model(new_input).mean().detach().cpu().item()
registry.remove_hook("add_hook")
new_input = input * 2 + 1
output3 = self.model(new_input).mean().detach().cpu().item()
self.assertAlmostEqual(output1, output2, places=5)
self.assertAlmostEqual(output1, output3, places=5)
def test_skip_layer_hook(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
input = torch.zeros(1, 4, device=torch_device)
output = self.model(input).mean().detach().cpu().item()
self.assertEqual(output, 0.0)
registry.remove_hook("skip_layer_hook")
registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
output = self.model(input).mean().detach().cpu().item()
self.assertNotEqual(output, 0.0)
def test_skip_layer_internal_block(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
input = torch.zeros(1, 4, device=torch_device)
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
with self.assertRaises(RuntimeError) as cm:
self.model(input).mean().detach().cpu().item()
self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))
registry.remove_hook("skip_layer_hook")
output = self.model(input).mean().detach().cpu().item()
self.assertNotEqual(output, 0.0)
registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
output = self.model(input).mean().detach().cpu().item()
self.assertNotEqual(output, 0.0)
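    # The expected logs below encode how hooks nest: the most recently registered hook runs its
    # pre_forward first and its post_forward last (StatefulAddHook only logs its pre_forward).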
def test_invocation_order_stateful_first(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry.register_hook(StatefulAddHook(1), "add_hook")
registry.register_hook(AddHook(2), "add_hook_2")
registry.register_hook(MultiplyHook(3), "multiply_hook")
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
logger = get_logger(__name__)
logger.setLevel("DEBUG")
with CaptureLogger(logger) as cap_logger:
self.model(input)
output = cap_logger.out.replace(" ", "").replace("\n", "")
expected_invocation_order_log = (
(
"MultiplyHook pre_forward\n"
"AddHook pre_forward\n"
"StatefulAddHook pre_forward\n"
"AddHook post_forward\n"
"MultiplyHook post_forward\n"
)
.replace(" ", "")
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
registry.remove_hook("add_hook")
with CaptureLogger(logger) as cap_logger:
self.model(input)
output = cap_logger.out.replace(" ", "").replace("\n", "")
expected_invocation_order_log = (
(
"MultiplyHook pre_forward\n"
"AddHook pre_forward\n"
"AddHook post_forward\n"
"MultiplyHook post_forward\n"
)
.replace(" ", "")
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
def test_invocation_order_stateful_middle(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry.register_hook(AddHook(2), "add_hook")
registry.register_hook(StatefulAddHook(1), "add_hook_2")
registry.register_hook(MultiplyHook(3), "multiply_hook")
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
logger = get_logger(__name__)
logger.setLevel("DEBUG")
with CaptureLogger(logger) as cap_logger:
self.model(input)
output = cap_logger.out.replace(" ", "").replace("\n", "")
expected_invocation_order_log = (
(
"MultiplyHook pre_forward\n"
"StatefulAddHook pre_forward\n"
"AddHook pre_forward\n"
"AddHook post_forward\n"
"MultiplyHook post_forward\n"
)
.replace(" ", "")
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
registry.remove_hook("add_hook")
with CaptureLogger(logger) as cap_logger:
self.model(input)
output = cap_logger.out.replace(" ", "").replace("\n", "")
expected_invocation_order_log = (
("MultiplyHook pre_forward\nStatefulAddHook pre_forward\nMultiplyHook post_forward\n")
.replace(" ", "")
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
registry.remove_hook("add_hook_2")
with CaptureLogger(logger) as cap_logger:
self.model(input)
output = cap_logger.out.replace(" ", "").replace("\n", "")
expected_invocation_order_log = (
("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
def test_invocation_order_stateful_last(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry.register_hook(AddHook(1), "add_hook")
registry.register_hook(MultiplyHook(2), "multiply_hook")
registry.register_hook(StatefulAddHook(3), "add_hook_2")
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
logger = get_logger(__name__)
logger.setLevel("DEBUG")
with CaptureLogger(logger) as cap_logger:
self.model(input)
output = cap_logger.out.replace(" ", "").replace("\n", "")
expected_invocation_order_log = (
(
"StatefulAddHook pre_forward\n"
"MultiplyHook pre_forward\n"
"AddHook pre_forward\n"
"AddHook post_forward\n"
"MultiplyHook post_forward\n"
)
.replace(" ", "")
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
registry.remove_hook("add_hook")
with CaptureLogger(logger) as cap_logger:
self.model(input)
output = cap_logger.out.replace(" ", "").replace("\n", "")
expected_invocation_order_log = (
("StatefulAddHook pre_forward\nMultiplyHook pre_forward\nMultiplyHook post_forward\n")
.replace(" ", "")
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
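The hook registry exercised above is the machinery behind the Pyramid Attention Broadcast cache that the pipeline test changes below validate via `PyramidAttentionBroadcastTesterMixin`. For orientation, enabling the cache on a real pipeline follows the same pattern the tests use (`PyramidAttentionBroadcastConfig`, a `current_timestep_callback`, and `enable_cache` on the denoiser). The sketch below is illustrative rather than taken from this commit; the checkpoint ID and prompt are only examples.

```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Reuse cached spatial attention outputs every other step while the current
# timestep lies in the (100, 800) range, mirroring the config used in the tests.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)

video = pipe("A cat playing with a ball of yarn", num_inference_steps=50).frames[0]

# disable_cache() removes the PAB hooks again; outputs then match normal inference,
# which is exactly what the mixin's inference test asserts.
pipe.transformer.disable_cache()
```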

View File

@@ -34,13 +34,13 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
enable_full_determinism()
class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
pipeline_class = AllegroPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -59,14 +59,14 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = AllegroTransformer3DModel(
num_attention_heads=2,
attention_head_dim=12,
in_channels=4,
out_channels=4,
num_layers=1,
num_layers=num_layers,
cross_attention_dim=24,
sample_width=8,
sample_height=8,

View File

@@ -32,6 +32,7 @@ from diffusers.utils.testing_utils import (
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
to_np,
@@ -41,7 +42,7 @@ from ..test_pipelines_common import (
enable_full_determinism()
class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
pipeline_class = CogVideoXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -60,7 +61,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = CogVideoXTransformer3DModel(
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
@@ -72,7 +73,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
out_channels=4,
time_embed_dim=2,
text_embed_dim=32, # Must match with tiny-random-t5
num_layers=1,
num_layers=num_layers,
sample_width=2, # latent width: 2 -> final width: 16
sample_height=2, # latent height: 2 -> final height: 16
sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9

View File

@@ -19,12 +19,15 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
class FluxPipelineFastTests(
unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin
):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
@@ -33,13 +36,13 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapte
test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=4,
num_layers=1,
num_single_layers=1,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,

View File

@@ -30,13 +30,13 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin, to_np
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
enable_full_determinism()
class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
pipeline_class = HunyuanVideoPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
@@ -55,15 +55,15 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = HunyuanVideoTransformer3DModel(
in_channels=4,
out_channels=4,
num_attention_heads=2,
attention_head_dim=10,
num_layers=1,
num_single_layers=1,
num_layers=num_layers,
num_single_layers=num_single_layers,
num_refiner_layers=1,
patch_size=1,
patch_size_t=1,

View File

@@ -27,6 +27,7 @@ from diffusers import (
DDIMScheduler,
LattePipeline,
LatteTransformer3DModel,
PyramidAttentionBroadcastConfig,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
@@ -38,13 +39,13 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
enable_full_determinism()
class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
pipeline_class = LattePipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -54,11 +55,23 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
required_optional_params = PipelineTesterMixin.required_optional_params
test_layerwise_casting = True
def get_dummy_components(self):
pab_config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
temporal_attention_block_skip_range=2,
cross_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 700),
temporal_attention_timestep_skip_range=(100, 800),
cross_attention_timestep_skip_range=(100, 800),
spatial_attention_block_identifiers=["transformer_blocks"],
temporal_attention_block_identifiers=["temporal_transformer_blocks"],
cross_attention_block_identifiers=["transformer_blocks"],
)
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LatteTransformer3DModel(
sample_size=8,
num_layers=1,
num_layers=num_layers,
patch_size=2,
attention_head_dim=8,
num_attention_heads=3,

View File

@@ -24,10 +24,12 @@ from diffusers import (
DDIMScheduler,
DiffusionPipeline,
KolorsPipeline,
PyramidAttentionBroadcastConfig,
StableDiffusionPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
from diffusers.models.attention_processor import AttnProcessor
@@ -2322,6 +2324,141 @@ class SDXLOptionalComponentsTesterMixin:
self.assertLess(max_diff, expected_max_difference)
class PyramidAttentionBroadcastTesterMixin:
pab_config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
spatial_attention_block_identifiers=["transformer_blocks"],
)
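    # Only spatial attention caching is enabled by default; pipeline tests that also need temporal or
    # cross attention caching override this pab_config (see the Latte fast tests above).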
def test_pyramid_attention_broadcast_layers(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
num_layers = 0
num_single_layers = 0
dummy_component_kwargs = {}
dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters
if "num_layers" in dummy_component_parameters:
num_layers = 2
dummy_component_kwargs["num_layers"] = num_layers
if "num_single_layers" in dummy_component_parameters:
num_single_layers = 2
dummy_component_kwargs["num_single_layers"] = num_single_layers
components = self.get_dummy_components(**dummy_component_kwargs)
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
self.pab_config.current_timestep_callback = lambda: pipe.current_timestep
denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
denoiser.enable_cache(self.pab_config)
expected_hooks = 0
if self.pab_config.spatial_attention_block_skip_range is not None:
expected_hooks += num_layers + num_single_layers
if self.pab_config.temporal_attention_block_skip_range is not None:
expected_hooks += num_layers + num_single_layers
if self.pab_config.cross_attention_block_skip_range is not None:
expected_hooks += num_layers + num_single_layers
denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
count = 0
for module in denoiser.modules():
if hasattr(module, "_diffusers_hook"):
hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
if hook is None:
continue
count += 1
self.assertTrue(
isinstance(hook, PyramidAttentionBroadcastHook),
"Hook should be of type PyramidAttentionBroadcastHook.",
)
self.assertTrue(hook.state.cache is None, "Cache should be None at initialization.")
self.assertEqual(count, expected_hooks, "Number of hooks should match the expected number.")
        # Perform a dummy two-step inference run and verify, via the step-end callback, that the hook state is updated
def pab_state_check_callback(pipe, i, t, kwargs):
for module in denoiser.modules():
if hasattr(module, "_diffusers_hook"):
hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
if hook is None:
continue
self.assertTrue(
hook.state.cache is not None,
"Cache should have updated during inference.",
)
self.assertTrue(
hook.state.iteration == i + 1,
"Hook iteration state should have updated during inference.",
)
return {}
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 2
inputs["callback_on_step_end"] = pab_state_check_callback
pipe(**inputs)[0]
# After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states
for module in denoiser.modules():
if hasattr(module, "_diffusers_hook"):
hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
if hook is None:
continue
self.assertTrue(
hook.state.cache is None,
"Cache should be reset to None after inference.",
)
self.assertTrue(
hook.state.iteration == 0,
"Iteration should be reset to 0 after inference.",
)
def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2):
        # We need to use a higher tolerance because the test uses a randomly initialized model. With a converged/trained
        # model, the tolerance can be lower.
device = "cpu" # ensure determinism for the device-dependent torch.Generator
num_layers = 2
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
# Run inference without PAB
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
output = pipe(**inputs)[0]
original_image_slice = output.flatten()
original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:]))
# Run inference with PAB enabled
self.pab_config.current_timestep_callback = lambda: pipe.current_timestep
denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
denoiser.enable_cache(self.pab_config)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
output = pipe(**inputs)[0]
image_slice_pab_enabled = output.flatten()
image_slice_pab_enabled = np.concatenate((image_slice_pab_enabled[:8], image_slice_pab_enabled[-8:]))
# Run inference with PAB disabled
denoiser.disable_cache()
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
output = pipe(**inputs)[0]
image_slice_pab_disabled = output.flatten()
image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:]))
assert np.allclose(
original_image_slice, image_slice_pab_enabled, atol=expected_atol
), "PAB outputs should not differ much in specified timestep range."
assert np.allclose(
original_image_slice, image_slice_pab_disabled, atol=1e-4
), "Outputs from normal inference and after disabling cache should not differ."
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.