mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[core] Pyramid Attention Broadcast (#9562)
* start pyramid attention broadcast * add coauthor Co-Authored-By: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> * update * make style * update * make style * add docs * add tests * update * Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Pyramid Attention Broadcast rewrite + introduce hooks (#9826) * rewrite implementation with hooks * make style * update * merge pyramid-attention-rewrite-2 * make style * remove changes from latte transformer * revert docs changes * better debug message * add todos for future * update tests * make style * cleanup * fix * improve log message; fix latte test * refactor * update * update * update * revert changes to tests * update docs * update tests * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * update * fix flux test * reorder * refactor * make fix-copies * update docs * fixes * more fixes * make style * update tests * update code example * make fix-copies * refactor based on reviews * use maybe_free_model_hooks * CacheMixin * make style * update * add current_timestep property; update docs * make fix-copies * update * improve tests * try circular import fix * apply suggestions from review * address review comments * Apply suggestions from code review * refactor hook implementation * add test suite for hooks * PAB Refactor (#10667) * update * update * update --------- Co-authored-by: DN6 <dhruv.nair@gmail.com> * update * fix remove hook behaviour --------- Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: DN6 <dhruv.nair@gmail.com>
This commit is contained in:
@@ -598,6 +598,8 @@
|
||||
title: Attention Processor
|
||||
- local: api/activations
|
||||
title: Custom activation functions
|
||||
- local: api/cache
|
||||
title: Caching methods
|
||||
- local: api/normalization
|
||||
title: Custom normalization layers
|
||||
- local: api/utilities
|
||||
|
||||
49
docs/source/en/api/cache.md
Normal file
49
docs/source/en/api/cache.md
Normal file
@@ -0,0 +1,49 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Caching methods
|
||||
|
||||
## Pyramid Attention Broadcast
|
||||
|
||||
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
|
||||
|
||||
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
|
||||
|
||||
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
|
||||
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
|
||||
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
|
||||
# poorer quality of generated videos.
|
||||
config = PyramidAttentionBroadcastConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(100, 800),
|
||||
current_timestep_callback=lambda: pipe.current_timestep,
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
### CacheMixin
|
||||
|
||||
[[autodoc]] CacheMixin
|
||||
|
||||
### PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] apply_pyramid_attention_broadcast
|
||||
@@ -28,6 +28,7 @@ from .utils import (
|
||||
|
||||
_import_structure = {
|
||||
"configuration_utils": ["ConfigMixin"],
|
||||
"hooks": [],
|
||||
"loaders": ["FromOriginalModelMixin"],
|
||||
"models": [],
|
||||
"pipelines": [],
|
||||
@@ -75,6 +76,13 @@ except OptionalDependencyNotAvailable:
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
|
||||
|
||||
else:
|
||||
_import_structure["hooks"].extend(
|
||||
[
|
||||
"HookRegistry",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
]
|
||||
)
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"AllegroTransformer3DModel",
|
||||
@@ -90,6 +98,7 @@ else:
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderTiny",
|
||||
"CacheMixin",
|
||||
"CogVideoXTransformer3DModel",
|
||||
"CogView3PlusTransformer2DModel",
|
||||
"ConsisIDTransformer3DModel",
|
||||
@@ -588,6 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .models import (
|
||||
AllegroTransformer3DModel,
|
||||
AsymmetricAutoencoderKL,
|
||||
@@ -602,6 +612,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
CacheMixin,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
ConsisIDTransformer3DModel,
|
||||
|
||||
@@ -2,4 +2,6 @@ from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
|
||||
@@ -30,6 +30,9 @@ class ModelHook:
|
||||
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self):
|
||||
self.fn_ref: "HookFunctionReference" = None
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
r"""
|
||||
Hook that is executed when a model is initialized.
|
||||
@@ -48,8 +51,6 @@ class ModelHook:
|
||||
module (`torch.nn.Module`):
|
||||
The module attached to this hook.
|
||||
"""
|
||||
module.forward = module._old_forward
|
||||
del module._old_forward
|
||||
return module
|
||||
|
||||
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
|
||||
@@ -99,6 +100,29 @@ class ModelHook:
|
||||
return module
|
||||
|
||||
|
||||
class HookFunctionReference:
|
||||
def __init__(self) -> None:
|
||||
"""A container class that maintains mutable references to forward pass functions in a hook chain.
|
||||
|
||||
Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
|
||||
entire forward pass structure.
|
||||
|
||||
Attributes:
|
||||
pre_forward: A callable that processes inputs before the main forward pass.
|
||||
post_forward: A callable that processes outputs after the main forward pass.
|
||||
forward: The current forward function in the hook chain.
|
||||
original_forward: The original forward function, stored when a hook provides a custom new_forward.
|
||||
|
||||
The class enables hook removal by allowing updates to the forward chain through reference modification rather
|
||||
than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to
|
||||
be updated, preserving the execution order of the remaining hooks.
|
||||
"""
|
||||
self.pre_forward = None
|
||||
self.post_forward = None
|
||||
self.forward = None
|
||||
self.original_forward = None
|
||||
|
||||
|
||||
class HookRegistry:
|
||||
def __init__(self, module_ref: torch.nn.Module) -> None:
|
||||
super().__init__()
|
||||
@@ -107,51 +131,71 @@ class HookRegistry:
|
||||
|
||||
self._module_ref = module_ref
|
||||
self._hook_order = []
|
||||
self._fn_refs = []
|
||||
|
||||
def register_hook(self, hook: ModelHook, name: str) -> None:
|
||||
if name in self.hooks.keys():
|
||||
logger.warning(f"Hook with name {name} already exists, replacing it.")
|
||||
|
||||
if hasattr(self._module_ref, "_old_forward"):
|
||||
old_forward = self._module_ref._old_forward
|
||||
else:
|
||||
old_forward = self._module_ref.forward
|
||||
self._module_ref._old_forward = self._module_ref.forward
|
||||
raise ValueError(
|
||||
f"Hook with name {name} already exists in the registry. Please use a different name or "
|
||||
f"first remove the existing hook and then add a new one."
|
||||
)
|
||||
|
||||
self._module_ref = hook.initialize_hook(self._module_ref)
|
||||
|
||||
def create_new_forward(function_reference: HookFunctionReference):
|
||||
def new_forward(module, *args, **kwargs):
|
||||
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
|
||||
output = function_reference.forward(*args, **kwargs)
|
||||
return function_reference.post_forward(module, output)
|
||||
|
||||
return new_forward
|
||||
|
||||
forward = self._module_ref.forward
|
||||
|
||||
fn_ref = HookFunctionReference()
|
||||
fn_ref.pre_forward = hook.pre_forward
|
||||
fn_ref.post_forward = hook.post_forward
|
||||
fn_ref.forward = forward
|
||||
|
||||
if hasattr(hook, "new_forward"):
|
||||
rewritten_forward = hook.new_forward
|
||||
|
||||
def new_forward(module, *args, **kwargs):
|
||||
args, kwargs = hook.pre_forward(module, *args, **kwargs)
|
||||
output = rewritten_forward(module, *args, **kwargs)
|
||||
return hook.post_forward(module, output)
|
||||
else:
|
||||
|
||||
def new_forward(module, *args, **kwargs):
|
||||
args, kwargs = hook.pre_forward(module, *args, **kwargs)
|
||||
output = old_forward(*args, **kwargs)
|
||||
return hook.post_forward(module, output)
|
||||
fn_ref.original_forward = forward
|
||||
fn_ref.forward = functools.update_wrapper(
|
||||
functools.partial(hook.new_forward, self._module_ref), hook.new_forward
|
||||
)
|
||||
|
||||
rewritten_forward = create_new_forward(fn_ref)
|
||||
self._module_ref.forward = functools.update_wrapper(
|
||||
functools.partial(new_forward, self._module_ref), old_forward
|
||||
functools.partial(rewritten_forward, self._module_ref), rewritten_forward
|
||||
)
|
||||
|
||||
hook.fn_ref = fn_ref
|
||||
self.hooks[name] = hook
|
||||
self._hook_order.append(name)
|
||||
self._fn_refs.append(fn_ref)
|
||||
|
||||
def get_hook(self, name: str) -> Optional[ModelHook]:
|
||||
if name not in self.hooks.keys():
|
||||
return None
|
||||
return self.hooks[name]
|
||||
return self.hooks.get(name, None)
|
||||
|
||||
def remove_hook(self, name: str, recurse: bool = True) -> None:
|
||||
if name in self.hooks.keys():
|
||||
num_hooks = len(self._hook_order)
|
||||
hook = self.hooks[name]
|
||||
index = self._hook_order.index(name)
|
||||
fn_ref = self._fn_refs[index]
|
||||
|
||||
old_forward = fn_ref.forward
|
||||
if fn_ref.original_forward is not None:
|
||||
old_forward = fn_ref.original_forward
|
||||
|
||||
if index == num_hooks - 1:
|
||||
self._module_ref.forward = old_forward
|
||||
else:
|
||||
self._fn_refs[index + 1].forward = old_forward
|
||||
|
||||
self._module_ref = hook.deinitalize_hook(self._module_ref)
|
||||
del self.hooks[name]
|
||||
self._hook_order.remove(name)
|
||||
self._hook_order.pop(index)
|
||||
self._fn_refs.pop(index)
|
||||
|
||||
if recurse:
|
||||
for module_name, module in self._module_ref.named_modules():
|
||||
@@ -161,7 +205,7 @@ class HookRegistry:
|
||||
module._diffusers_hook.remove_hook(name, recurse=False)
|
||||
|
||||
def reset_stateful_hooks(self, recurse: bool = True) -> None:
|
||||
for hook_name in self._hook_order:
|
||||
for hook_name in reversed(self._hook_order):
|
||||
hook = self.hooks[hook_name]
|
||||
if hook._is_stateful:
|
||||
hook.reset_state(self._module_ref)
|
||||
@@ -180,9 +224,13 @@ class HookRegistry:
|
||||
return module._diffusers_hook
|
||||
|
||||
def __repr__(self) -> str:
|
||||
hook_repr = ""
|
||||
registry_repr = ""
|
||||
for i, hook_name in enumerate(self._hook_order):
|
||||
hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
|
||||
if self.hooks[hook_name].__class__.__repr__ is not object.__repr__:
|
||||
hook_repr = self.hooks[hook_name].__repr__()
|
||||
else:
|
||||
hook_repr = self.hooks[hook_name].__class__.__name__
|
||||
registry_repr += f" ({i}) {hook_name} - {hook_repr}"
|
||||
if i < len(self._hook_order) - 1:
|
||||
hook_repr += "\n"
|
||||
return f"HookRegistry(\n{hook_repr}\n)"
|
||||
registry_repr += "\n"
|
||||
return f"HookRegistry(\n{registry_repr}\n)"
|
||||
|
||||
314
src/diffusers/hooks/pyramid_attention_broadcast.py
Normal file
314
src/diffusers/hooks/pyramid_attention_broadcast.py
Normal file
@@ -0,0 +1,314 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
from ..utils import logging
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
|
||||
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
|
||||
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PyramidAttentionBroadcastConfig:
|
||||
r"""
|
||||
Configuration for Pyramid Attention Broadcast.
|
||||
|
||||
Args:
|
||||
spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
|
||||
The number of times a specific spatial attention broadcast is skipped before computing the attention states
|
||||
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
|
||||
old attention states will be re-used) before computing the new attention states again.
|
||||
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
|
||||
The number of times a specific temporal attention broadcast is skipped before computing the attention
|
||||
states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
|
||||
(i.e., old attention states will be re-used) before computing the new attention states again.
|
||||
cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
|
||||
The number of times a specific cross-attention broadcast is skipped before computing the attention states
|
||||
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
|
||||
old attention states will be re-used) before computing the new attention states again.
|
||||
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The range of timesteps to skip in the spatial attention layer. The attention computations will be
|
||||
conditionally skipped if the current timestep is within the specified range.
|
||||
temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The range of timesteps to skip in the temporal attention layer. The attention computations will be
|
||||
conditionally skipped if the current timestep is within the specified range.
|
||||
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The range of timesteps to skip in the cross-attention layer. The attention computations will be
|
||||
conditionally skipped if the current timestep is within the specified range.
|
||||
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
|
||||
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
|
||||
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
|
||||
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
|
||||
cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
|
||||
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
|
||||
"""
|
||||
|
||||
spatial_attention_block_skip_range: Optional[int] = None
|
||||
temporal_attention_block_skip_range: Optional[int] = None
|
||||
cross_attention_block_skip_range: Optional[int] = None
|
||||
|
||||
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
|
||||
|
||||
current_timestep_callback: Callable[[], int] = None
|
||||
|
||||
# TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase
|
||||
# so not added for now)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PyramidAttentionBroadcastConfig("
|
||||
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
|
||||
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
|
||||
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
|
||||
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
|
||||
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
|
||||
f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n"
|
||||
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
|
||||
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
|
||||
f" cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n"
|
||||
f" current_timestep_callback={self.current_timestep_callback}\n"
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastState:
|
||||
r"""
|
||||
State for Pyramid Attention Broadcast.
|
||||
|
||||
Attributes:
|
||||
iteration (`int`):
|
||||
The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
|
||||
called before starting a new inference forward pass for PAB to work correctly.
|
||||
cache (`Any`):
|
||||
The cached output from the previous forward pass. This is used to re-use the attention states when the
|
||||
attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.iteration = 0
|
||||
self.cache = None
|
||||
|
||||
def reset(self):
|
||||
self.iteration = 0
|
||||
self.cache = None
|
||||
|
||||
def __repr__(self):
|
||||
cache_repr = ""
|
||||
if self.cache is None:
|
||||
cache_repr = "None"
|
||||
else:
|
||||
cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})"
|
||||
return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})"
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastHook(ModelHook):
|
||||
r"""A hook that applies Pyramid Attention Broadcast to a given module."""
|
||||
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(
|
||||
self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.timestep_skip_range = timestep_skip_range
|
||||
self.block_skip_range = block_skip_range
|
||||
self.current_timestep_callback = current_timestep_callback
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self.state = PyramidAttentionBroadcastState()
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
||||
is_within_timestep_range = (
|
||||
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
|
||||
)
|
||||
should_compute_attention = (
|
||||
self.state.cache is None
|
||||
or self.state.iteration == 0
|
||||
or not is_within_timestep_range
|
||||
or self.state.iteration % self.block_skip_range == 0
|
||||
)
|
||||
|
||||
if should_compute_attention:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
else:
|
||||
output = self.state.cache
|
||||
|
||||
self.state.cache = output
|
||||
self.state.iteration += 1
|
||||
return output
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> None:
|
||||
self.state.reset()
|
||||
return module
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(
|
||||
module: torch.nn.Module,
|
||||
config: PyramidAttentionBroadcastConfig,
|
||||
):
|
||||
r"""
|
||||
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
|
||||
|
||||
PAB is an attention approximation method that leverages the similarity in attention states between timesteps to
|
||||
reduce the computational cost of attention computation. The key takeaway from the paper is that the attention
|
||||
similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal and
|
||||
spatial layers. This allows for the skipping of attention computation in the cross-attention layers more frequently
|
||||
than in the temporal and spatial layers. Applying PAB will, therefore, speedup the inference process.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to apply Pyramid Attention Broadcast to.
|
||||
config (`Optional[PyramidAttentionBroadcastConfig]`, `optional`, defaults to `None`):
|
||||
The configuration to use for Pyramid Attention Broadcast.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> config = PyramidAttentionBroadcastConfig(
|
||||
... spatial_attention_block_skip_range=2,
|
||||
... spatial_attention_timestep_skip_range=(100, 800),
|
||||
... current_timestep_callback=lambda: pipe.current_timestep,
|
||||
... )
|
||||
>>> apply_pyramid_attention_broadcast(pipe.transformer, config)
|
||||
```
|
||||
"""
|
||||
if config.current_timestep_callback is None:
|
||||
raise ValueError(
|
||||
"The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast."
|
||||
)
|
||||
|
||||
if (
|
||||
config.spatial_attention_block_skip_range is None
|
||||
and config.temporal_attention_block_skip_range is None
|
||||
and config.cross_attention_block_skip_range is None
|
||||
):
|
||||
logger.warning(
|
||||
"Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
|
||||
"or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
|
||||
"To avoid this warning, please set one of the above parameters."
|
||||
)
|
||||
config.spatial_attention_block_skip_range = 2
|
||||
|
||||
for name, submodule in module.named_modules():
|
||||
if not isinstance(submodule, _ATTENTION_CLASSES):
|
||||
# PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
|
||||
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
|
||||
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
|
||||
continue
|
||||
_apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)
|
||||
|
||||
|
||||
def _apply_pyramid_attention_broadcast_on_attention_class(
|
||||
name: str, module: Attention, config: PyramidAttentionBroadcastConfig
|
||||
) -> bool:
|
||||
is_spatial_self_attention = (
|
||||
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
|
||||
and config.spatial_attention_block_skip_range is not None
|
||||
and not getattr(module, "is_cross_attention", False)
|
||||
)
|
||||
is_temporal_self_attention = (
|
||||
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
|
||||
and config.temporal_attention_block_skip_range is not None
|
||||
and not getattr(module, "is_cross_attention", False)
|
||||
)
|
||||
is_cross_attention = (
|
||||
any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers)
|
||||
and config.cross_attention_block_skip_range is not None
|
||||
and getattr(module, "is_cross_attention", False)
|
||||
)
|
||||
|
||||
block_skip_range, timestep_skip_range, block_type = None, None, None
|
||||
if is_spatial_self_attention:
|
||||
block_skip_range = config.spatial_attention_block_skip_range
|
||||
timestep_skip_range = config.spatial_attention_timestep_skip_range
|
||||
block_type = "spatial"
|
||||
elif is_temporal_self_attention:
|
||||
block_skip_range = config.temporal_attention_block_skip_range
|
||||
timestep_skip_range = config.temporal_attention_timestep_skip_range
|
||||
block_type = "temporal"
|
||||
elif is_cross_attention:
|
||||
block_skip_range = config.cross_attention_block_skip_range
|
||||
timestep_skip_range = config.cross_attention_timestep_skip_range
|
||||
block_type = "cross"
|
||||
|
||||
if block_skip_range is None or timestep_skip_range is None:
|
||||
logger.info(
|
||||
f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does '
|
||||
f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, "
|
||||
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
|
||||
f"block identifiers in the configuration."
|
||||
)
|
||||
return False
|
||||
|
||||
logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
|
||||
_apply_pyramid_attention_broadcast_hook(
|
||||
module, timestep_skip_range, block_skip_range, config.current_timestep_callback
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _apply_pyramid_attention_broadcast_hook(
|
||||
module: Union[Attention, MochiAttention],
|
||||
timestep_skip_range: Tuple[int, int],
|
||||
block_skip_range: int,
|
||||
current_timestep_callback: Callable[[], int],
|
||||
):
|
||||
r"""
|
||||
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to apply Pyramid Attention Broadcast to.
|
||||
timestep_skip_range (`Tuple[int, int]`):
|
||||
The range of timesteps to skip in the attention layer. The attention computations will be conditionally
|
||||
skipped if the current timestep is within the specified range.
|
||||
block_skip_range (`int`):
|
||||
The number of times a specific attention broadcast is skipped before computing the attention states to
|
||||
re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
|
||||
attention states will be re-used) before computing the new attention states again.
|
||||
current_timestep_callback (`Callable[[], int]`):
|
||||
A callback function that returns the current inference timestep.
|
||||
"""
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
|
||||
registry.register_hook(hook, "pyramid_attention_broadcast")
|
||||
@@ -39,6 +39,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
_import_structure["cache_utils"] = ["CacheMixin"]
|
||||
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_hunyuan"] = [
|
||||
@@ -109,6 +110,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ConsistencyDecoderVAE,
|
||||
VQModel,
|
||||
)
|
||||
from .cache_utils import CacheMixin
|
||||
from .controlnets import (
|
||||
ControlNetModel,
|
||||
ControlNetUnionModel,
|
||||
|
||||
89
src/diffusers/models/cache_utils.py
Normal file
89
src/diffusers/models/cache_utils.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CacheMixin:
|
||||
r"""
|
||||
A class for enable/disabling caching techniques on diffusion models.
|
||||
|
||||
Supported caching techniques:
|
||||
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
|
||||
"""
|
||||
|
||||
_cache_config = None
|
||||
|
||||
@property
|
||||
def is_cache_enabled(self) -> bool:
|
||||
return self._cache_config is not None
|
||||
|
||||
def enable_cache(self, config) -> None:
|
||||
r"""
|
||||
Enable caching techniques on the model.
|
||||
|
||||
Args:
|
||||
config (`Union[PyramidAttentionBroadcastConfig]`):
|
||||
The configuration for applying the caching technique. Currently supported caching techniques are:
|
||||
- [`~hooks.PyramidAttentionBroadcastConfig`]
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
|
||||
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> config = PyramidAttentionBroadcastConfig(
|
||||
... spatial_attention_block_skip_range=2,
|
||||
... spatial_attention_timestep_skip_range=(100, 800),
|
||||
... current_timestep_callback=lambda: pipe.current_timestep,
|
||||
... )
|
||||
>>> pipe.transformer.enable_cache(config)
|
||||
```
|
||||
"""
|
||||
|
||||
from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
|
||||
if isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(config)} is not supported.")
|
||||
|
||||
self._cache_config = config
|
||||
|
||||
def disable_cache(self) -> None:
|
||||
from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
|
||||
|
||||
if self._cache_config is None:
|
||||
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
|
||||
return
|
||||
|
||||
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry.remove_hook("pyramid_attention_broadcast", recurse=True)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
|
||||
|
||||
self._cache_config = None
|
||||
|
||||
def _reset_stateful_cache(self, recurse: bool = True) -> None:
|
||||
from ..hooks import HookRegistry
|
||||
|
||||
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
|
||||
@@ -24,6 +24,7 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_lay
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -156,7 +157,7 @@ class CogVideoXBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
|
||||
"""
|
||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -19,13 +20,14 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
class LatteTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
"""
|
||||
|
||||
@@ -24,6 +24,7 @@ from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import AllegroAttnProcessor2_0, Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -172,7 +173,7 @@ class AllegroTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
"""
|
||||
|
||||
@@ -35,6 +35,7 @@ from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, Ad
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
|
||||
@@ -227,7 +228,7 @@ class FluxTransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class FluxTransformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Flux.
|
||||
|
||||
@@ -25,6 +25,7 @@ from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
CombinedTimestepGuidanceTextProjEmbeddings,
|
||||
CombinedTimestepTextProjEmbeddings,
|
||||
@@ -502,7 +503,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_lay
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -305,7 +306,7 @@ class MochiRoPE(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
|
||||
|
||||
|
||||
@@ -683,6 +683,10 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -815,6 +819,7 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
negative_prompt_attention_mask,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
@@ -892,6 +897,7 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -933,6 +939,8 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
video = self.decode_latents(latents)
|
||||
|
||||
@@ -494,6 +494,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -627,6 +631,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
@@ -705,6 +710,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -763,6 +769,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
# Discard any padding frames that were added for CogVideoX 1.5
|
||||
latents = latents[:, additional_frames:]
|
||||
|
||||
@@ -540,6 +540,10 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -680,6 +684,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
@@ -766,6 +771,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -818,6 +824,8 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
|
||||
@@ -591,6 +591,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -728,6 +732,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
@@ -815,6 +820,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -877,6 +883,8 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
# Discard any padding frames that were added for CogVideoX 1.5
|
||||
latents = latents[:, additional_frames:]
|
||||
|
||||
@@ -564,6 +564,10 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -700,6 +704,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
@@ -786,6 +791,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -844,6 +850,8 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
|
||||
@@ -28,8 +28,7 @@ from transformers import (
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...models import AutoencoderKL, FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
@@ -620,6 +619,10 @@ class FluxPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -775,6 +778,7 @@ class FluxPipeline(
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
@@ -899,6 +903,7 @@ class FluxPipeline(
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
if image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
@@ -957,9 +962,10 @@ class FluxPipeline(
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
@@ -456,6 +456,10 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -577,6 +581,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
@@ -644,6 +649,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
@@ -678,6 +684,8 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
@@ -602,6 +602,10 @@ class LattePipeline(DiffusionPipeline):
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -633,7 +637,7 @@ class LattePipeline(DiffusionPipeline):
|
||||
clean_caption: bool = True,
|
||||
mask_feature: bool = True,
|
||||
enable_temporal_attentions: bool = True,
|
||||
decode_chunk_size: Optional[int] = None,
|
||||
decode_chunk_size: int = 14,
|
||||
) -> Union[LattePipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -729,6 +733,7 @@ class LattePipeline(DiffusionPipeline):
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
@@ -790,6 +795,7 @@ class LattePipeline(DiffusionPipeline):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -850,6 +856,8 @@ class LattePipeline(DiffusionPipeline):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if output_type == "latents":
|
||||
deprecation_message = (
|
||||
"Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead."
|
||||
@@ -858,7 +866,7 @@ class LattePipeline(DiffusionPipeline):
|
||||
output_type = "latent"
|
||||
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents, video_length, decode_chunk_size=14)
|
||||
video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
@@ -21,8 +21,7 @@ from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...loaders import Mochi1LoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKLMochi
|
||||
from ...models.transformers import MochiTransformer3DModel
|
||||
from ...models import AutoencoderKLMochi, MochiTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
is_torch_xla_available,
|
||||
@@ -467,6 +466,10 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -591,6 +594,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
@@ -660,6 +664,9 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need
|
||||
# to make sure we're using the correct non-reversed timestep values.
|
||||
self._current_timestep = 1000 - t
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||
@@ -705,6 +712,8 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
|
||||
@@ -1133,11 +1133,20 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
def maybe_free_model_hooks(self):
|
||||
r"""
|
||||
Function that offloads all components, removes all model hooks that were added when using
|
||||
`enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function
|
||||
is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
|
||||
functions correctly when applying enable_model_cpu_offload.
|
||||
Method that performs the following:
|
||||
- Offloads all components.
|
||||
- Removes all model hooks that were added when using `enable_model_cpu_offload`, and then applies them again.
|
||||
In case the model has not been offloaded, this function is a no-op.
|
||||
- Resets stateful diffusers hooks of denoiser components if they were added with
|
||||
[`~hooks.HookRegistry.register_hook`].
|
||||
|
||||
Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions
|
||||
correctly when applying `enable_model_cpu_offload`.
|
||||
"""
|
||||
for component in self.components.values():
|
||||
if hasattr(component, "_reset_stateful_cache"):
|
||||
component._reset_stateful_cache()
|
||||
|
||||
if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
|
||||
# `enable_model_cpu_offload` has not be called, so silently do nothing
|
||||
return
|
||||
|
||||
@@ -2,6 +2,40 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class HookRegistry(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(*args, **kwargs):
|
||||
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
|
||||
|
||||
|
||||
class AllegroTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -197,6 +231,21 @@ class AutoencoderTiny(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CacheMixin(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
382
tests/hooks/test_hooks.py
Normal file
382
tests/hooks/test_hooks.py
Normal file
@@ -0,0 +1,382 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.hooks import HookRegistry, ModelHook
|
||||
from diffusers.training_utils import free_memory
|
||||
from diffusers.utils.logging import get_logger
|
||||
from diffusers.utils.testing_utils import CaptureLogger, torch_device
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class DummyBlock(torch.nn.Module):
|
||||
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.proj_in = torch.nn.Linear(in_features, hidden_features)
|
||||
self.activation = torch.nn.ReLU()
|
||||
self.proj_out = torch.nn.Linear(hidden_features, out_features)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj_in(x)
|
||||
x = self.activation(x)
|
||||
x = self.proj_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class DummyModel(torch.nn.Module):
|
||||
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = torch.nn.Linear(in_features, hidden_features)
|
||||
self.activation = torch.nn.ReLU()
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
|
||||
)
|
||||
self.linear_2 = torch.nn.Linear(hidden_features, out_features)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear_1(x)
|
||||
x = self.activation(x)
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class AddHook(ModelHook):
|
||||
def __init__(self, value: int):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
|
||||
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
logger.debug("AddHook pre_forward")
|
||||
args = ((x + self.value) if torch.is_tensor(x) else x for x in args)
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module, output):
|
||||
logger.debug("AddHook post_forward")
|
||||
return output
|
||||
|
||||
|
||||
class MultiplyHook(ModelHook):
|
||||
def __init__(self, value: int):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
logger.debug("MultiplyHook pre_forward")
|
||||
args = ((x * self.value) if torch.is_tensor(x) else x for x in args)
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module, output):
|
||||
logger.debug("MultiplyHook post_forward")
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return f"MultiplyHook(value={self.value})"
|
||||
|
||||
|
||||
class StatefulAddHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, value: int):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
self.increment = 0
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
logger.debug("StatefulAddHook pre_forward")
|
||||
add_value = self.value + self.increment
|
||||
self.increment += 1
|
||||
args = ((x + add_value) if torch.is_tensor(x) else x for x in args)
|
||||
return args, kwargs
|
||||
|
||||
def reset_state(self, module):
|
||||
self.increment = 0
|
||||
|
||||
|
||||
class SkipLayerHook(ModelHook):
|
||||
def __init__(self, skip_layer: bool):
|
||||
super().__init__()
|
||||
self.skip_layer = skip_layer
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
logger.debug("SkipLayerHook pre_forward")
|
||||
return args, kwargs
|
||||
|
||||
def new_forward(self, module, *args, **kwargs):
|
||||
logger.debug("SkipLayerHook new_forward")
|
||||
if self.skip_layer:
|
||||
return args[0]
|
||||
return self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
def post_forward(self, module, output):
|
||||
logger.debug("SkipLayerHook post_forward")
|
||||
return output
|
||||
|
||||
|
||||
class HookTests(unittest.TestCase):
|
||||
in_features = 4
|
||||
hidden_features = 8
|
||||
out_features = 4
|
||||
num_layers = 2
|
||||
|
||||
def setUp(self):
|
||||
params = self.get_module_parameters()
|
||||
self.model = DummyModel(**params)
|
||||
self.model.to(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
del self.model
|
||||
gc.collect()
|
||||
free_memory()
|
||||
|
||||
def get_module_parameters(self):
|
||||
return {
|
||||
"in_features": self.in_features,
|
||||
"hidden_features": self.hidden_features,
|
||||
"out_features": self.out_features,
|
||||
"num_layers": self.num_layers,
|
||||
}
|
||||
|
||||
def get_generator(self):
|
||||
return torch.manual_seed(0)
|
||||
|
||||
def test_hook_registry(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry.register_hook(AddHook(1), "add_hook")
|
||||
registry.register_hook(MultiplyHook(2), "multiply_hook")
|
||||
|
||||
registry_repr = repr(registry)
|
||||
expected_repr = (
|
||||
"HookRegistry(\n" " (0) add_hook - AddHook\n" " (1) multiply_hook - MultiplyHook(value=2)\n" ")"
|
||||
)
|
||||
|
||||
self.assertEqual(len(registry.hooks), 2)
|
||||
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
|
||||
self.assertEqual(registry_repr, expected_repr)
|
||||
|
||||
registry.remove_hook("add_hook")
|
||||
|
||||
self.assertEqual(len(registry.hooks), 1)
|
||||
self.assertEqual(registry._hook_order, ["multiply_hook"])
|
||||
|
||||
def test_stateful_hook(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry.register_hook(StatefulAddHook(1), "stateful_add_hook")
|
||||
|
||||
self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)
|
||||
|
||||
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
|
||||
num_repeats = 3
|
||||
|
||||
for i in range(num_repeats):
|
||||
result = self.model(input)
|
||||
if i == 0:
|
||||
output1 = result
|
||||
|
||||
self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)
|
||||
|
||||
registry.reset_stateful_hooks()
|
||||
output2 = self.model(input)
|
||||
|
||||
self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
|
||||
self.assertTrue(torch.allclose(output1, output2))
|
||||
|
||||
def test_inference(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry.register_hook(AddHook(1), "add_hook")
|
||||
registry.register_hook(MultiplyHook(2), "multiply_hook")
|
||||
|
||||
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
|
||||
output1 = self.model(input).mean().detach().cpu().item()
|
||||
|
||||
registry.remove_hook("multiply_hook")
|
||||
new_input = input * 2
|
||||
output2 = self.model(new_input).mean().detach().cpu().item()
|
||||
|
||||
registry.remove_hook("add_hook")
|
||||
new_input = input * 2 + 1
|
||||
output3 = self.model(new_input).mean().detach().cpu().item()
|
||||
|
||||
self.assertAlmostEqual(output1, output2, places=5)
|
||||
self.assertAlmostEqual(output1, output3, places=5)
|
||||
|
||||
def test_skip_layer_hook(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
|
||||
|
||||
input = torch.zeros(1, 4, device=torch_device)
|
||||
output = self.model(input).mean().detach().cpu().item()
|
||||
self.assertEqual(output, 0.0)
|
||||
|
||||
registry.remove_hook("skip_layer_hook")
|
||||
registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
|
||||
output = self.model(input).mean().detach().cpu().item()
|
||||
self.assertNotEqual(output, 0.0)
|
||||
|
||||
def test_skip_layer_internal_block(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
|
||||
input = torch.zeros(1, 4, device=torch_device)
|
||||
|
||||
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
self.model(input).mean().detach().cpu().item()
|
||||
self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))
|
||||
|
||||
registry.remove_hook("skip_layer_hook")
|
||||
output = self.model(input).mean().detach().cpu().item()
|
||||
self.assertNotEqual(output, 0.0)
|
||||
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
|
||||
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
|
||||
output = self.model(input).mean().detach().cpu().item()
|
||||
self.assertNotEqual(output, 0.0)
|
||||
|
||||
def test_invocation_order_stateful_first(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry.register_hook(StatefulAddHook(1), "add_hook")
|
||||
registry.register_hook(AddHook(2), "add_hook_2")
|
||||
registry.register_hook(MultiplyHook(3), "multiply_hook")
|
||||
|
||||
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.setLevel("DEBUG")
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
self.model(input)
|
||||
output = cap_logger.out.replace(" ", "").replace("\n", "")
|
||||
expected_invocation_order_log = (
|
||||
(
|
||||
"MultiplyHook pre_forward\n"
|
||||
"AddHook pre_forward\n"
|
||||
"StatefulAddHook pre_forward\n"
|
||||
"AddHook post_forward\n"
|
||||
"MultiplyHook post_forward\n"
|
||||
)
|
||||
.replace(" ", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
registry.remove_hook("add_hook")
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
self.model(input)
|
||||
output = cap_logger.out.replace(" ", "").replace("\n", "")
|
||||
expected_invocation_order_log = (
|
||||
(
|
||||
"MultiplyHook pre_forward\n"
|
||||
"AddHook pre_forward\n"
|
||||
"AddHook post_forward\n"
|
||||
"MultiplyHook post_forward\n"
|
||||
)
|
||||
.replace(" ", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
def test_invocation_order_stateful_middle(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry.register_hook(AddHook(2), "add_hook")
|
||||
registry.register_hook(StatefulAddHook(1), "add_hook_2")
|
||||
registry.register_hook(MultiplyHook(3), "multiply_hook")
|
||||
|
||||
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.setLevel("DEBUG")
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
self.model(input)
|
||||
output = cap_logger.out.replace(" ", "").replace("\n", "")
|
||||
expected_invocation_order_log = (
|
||||
(
|
||||
"MultiplyHook pre_forward\n"
|
||||
"StatefulAddHook pre_forward\n"
|
||||
"AddHook pre_forward\n"
|
||||
"AddHook post_forward\n"
|
||||
"MultiplyHook post_forward\n"
|
||||
)
|
||||
.replace(" ", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
registry.remove_hook("add_hook")
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
self.model(input)
|
||||
output = cap_logger.out.replace(" ", "").replace("\n", "")
|
||||
expected_invocation_order_log = (
|
||||
("MultiplyHook pre_forward\nStatefulAddHook pre_forward\nMultiplyHook post_forward\n")
|
||||
.replace(" ", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
registry.remove_hook("add_hook_2")
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
self.model(input)
|
||||
output = cap_logger.out.replace(" ", "").replace("\n", "")
|
||||
expected_invocation_order_log = (
|
||||
("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
|
||||
)
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
def test_invocation_order_stateful_last(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry.register_hook(AddHook(1), "add_hook")
|
||||
registry.register_hook(MultiplyHook(2), "multiply_hook")
|
||||
registry.register_hook(StatefulAddHook(3), "add_hook_2")
|
||||
|
||||
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.setLevel("DEBUG")
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
self.model(input)
|
||||
output = cap_logger.out.replace(" ", "").replace("\n", "")
|
||||
expected_invocation_order_log = (
|
||||
(
|
||||
"StatefulAddHook pre_forward\n"
|
||||
"MultiplyHook pre_forward\n"
|
||||
"AddHook pre_forward\n"
|
||||
"AddHook post_forward\n"
|
||||
"MultiplyHook post_forward\n"
|
||||
)
|
||||
.replace(" ", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
registry.remove_hook("add_hook")
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
self.model(input)
|
||||
output = cap_logger.out.replace(" ", "").replace("\n", "")
|
||||
expected_invocation_order_log = (
|
||||
("StatefulAddHook pre_forward\nMultiplyHook pre_forward\nMultiplyHook post_forward\n")
|
||||
.replace(" ", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
@@ -34,13 +34,13 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
|
||||
pipeline_class = AllegroPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
@@ -59,14 +59,14 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = AllegroTransformer3DModel(
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
num_layers=1,
|
||||
num_layers=num_layers,
|
||||
cross_attention_dim=24,
|
||||
sample_width=8,
|
||||
sample_height=8,
|
||||
|
||||
@@ -32,6 +32,7 @@ from diffusers.utils.testing_utils import (
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import (
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
check_qkv_fusion_processors_exist,
|
||||
to_np,
|
||||
@@ -41,7 +42,7 @@ from ..test_pipelines_common import (
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
|
||||
pipeline_class = CogVideoXPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
@@ -60,7 +61,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = CogVideoXTransformer3DModel(
|
||||
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
|
||||
@@ -72,7 +73,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
out_channels=4,
|
||||
time_embed_dim=2,
|
||||
text_embed_dim=32, # Must match with tiny-random-t5
|
||||
num_layers=1,
|
||||
num_layers=num_layers,
|
||||
sample_width=2, # latent width: 2 -> final width: 16
|
||||
sample_height=2, # latent height: 2 -> final height: 16
|
||||
sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
|
||||
|
||||
@@ -19,12 +19,15 @@ from diffusers.utils.testing_utils import (
|
||||
from ..test_pipelines_common import (
|
||||
FluxIPAdapterTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
check_qkv_fusion_processors_exist,
|
||||
)
|
||||
|
||||
|
||||
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
|
||||
class FluxPipelineFastTests(
|
||||
unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin
|
||||
):
|
||||
pipeline_class = FluxPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
@@ -33,13 +36,13 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapte
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = FluxTransformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
num_layers=1,
|
||||
num_single_layers=1,
|
||||
num_layers=num_layers,
|
||||
num_single_layers=num_single_layers,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=2,
|
||||
joint_attention_dim=32,
|
||||
|
||||
@@ -30,13 +30,13 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
|
||||
pipeline_class = HunyuanVideoPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
@@ -55,15 +55,15 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = HunyuanVideoTransformer3DModel(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=10,
|
||||
num_layers=1,
|
||||
num_single_layers=1,
|
||||
num_layers=num_layers,
|
||||
num_single_layers=num_single_layers,
|
||||
num_refiner_layers=1,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
|
||||
@@ -27,6 +27,7 @@ from diffusers import (
|
||||
DDIMScheduler,
|
||||
LattePipeline,
|
||||
LatteTransformer3DModel,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
@@ -38,13 +39,13 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
|
||||
pipeline_class = LattePipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
@@ -54,11 +55,23 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params
|
||||
test_layerwise_casting = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
pab_config = PyramidAttentionBroadcastConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
temporal_attention_block_skip_range=2,
|
||||
cross_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(100, 700),
|
||||
temporal_attention_timestep_skip_range=(100, 800),
|
||||
cross_attention_timestep_skip_range=(100, 800),
|
||||
spatial_attention_block_identifiers=["transformer_blocks"],
|
||||
temporal_attention_block_identifiers=["temporal_transformer_blocks"],
|
||||
cross_attention_block_identifiers=["transformer_blocks"],
|
||||
)
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = LatteTransformer3DModel(
|
||||
sample_size=8,
|
||||
num_layers=1,
|
||||
num_layers=num_layers,
|
||||
patch_size=2,
|
||||
attention_head_dim=8,
|
||||
num_attention_heads=3,
|
||||
|
||||
@@ -24,10 +24,12 @@ from diffusers import (
|
||||
DDIMScheduler,
|
||||
DiffusionPipeline,
|
||||
KolorsPipeline,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
@@ -2322,6 +2324,141 @@ class SDXLOptionalComponentsTesterMixin:
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastTesterMixin:
|
||||
pab_config = PyramidAttentionBroadcastConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(100, 800),
|
||||
spatial_attention_block_identifiers=["transformer_blocks"],
|
||||
)
|
||||
|
||||
def test_pyramid_attention_broadcast_layers(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
num_layers = 0
|
||||
num_single_layers = 0
|
||||
dummy_component_kwargs = {}
|
||||
dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters
|
||||
if "num_layers" in dummy_component_parameters:
|
||||
num_layers = 2
|
||||
dummy_component_kwargs["num_layers"] = num_layers
|
||||
if "num_single_layers" in dummy_component_parameters:
|
||||
num_single_layers = 2
|
||||
dummy_component_kwargs["num_single_layers"] = num_single_layers
|
||||
|
||||
components = self.get_dummy_components(**dummy_component_kwargs)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
self.pab_config.current_timestep_callback = lambda: pipe.current_timestep
|
||||
denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
|
||||
denoiser.enable_cache(self.pab_config)
|
||||
|
||||
expected_hooks = 0
|
||||
if self.pab_config.spatial_attention_block_skip_range is not None:
|
||||
expected_hooks += num_layers + num_single_layers
|
||||
if self.pab_config.temporal_attention_block_skip_range is not None:
|
||||
expected_hooks += num_layers + num_single_layers
|
||||
if self.pab_config.cross_attention_block_skip_range is not None:
|
||||
expected_hooks += num_layers + num_single_layers
|
||||
|
||||
denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
|
||||
count = 0
|
||||
for module in denoiser.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
|
||||
if hook is None:
|
||||
continue
|
||||
count += 1
|
||||
self.assertTrue(
|
||||
isinstance(hook, PyramidAttentionBroadcastHook),
|
||||
"Hook should be of type PyramidAttentionBroadcastHook.",
|
||||
)
|
||||
self.assertTrue(hook.state.cache is None, "Cache should be None at initialization.")
|
||||
self.assertEqual(count, expected_hooks, "Number of hooks should match the expected number.")
|
||||
|
||||
# Perform dummy inference step to ensure state is updated
|
||||
def pab_state_check_callback(pipe, i, t, kwargs):
|
||||
for module in denoiser.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
|
||||
if hook is None:
|
||||
continue
|
||||
self.assertTrue(
|
||||
hook.state.cache is not None,
|
||||
"Cache should have updated during inference.",
|
||||
)
|
||||
self.assertTrue(
|
||||
hook.state.iteration == i + 1,
|
||||
"Hook iteration state should have updated during inference.",
|
||||
)
|
||||
return {}
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 2
|
||||
inputs["callback_on_step_end"] = pab_state_check_callback
|
||||
pipe(**inputs)[0]
|
||||
|
||||
# After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states
|
||||
for module in denoiser.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
|
||||
if hook is None:
|
||||
continue
|
||||
self.assertTrue(
|
||||
hook.state.cache is None,
|
||||
"Cache should be reset to None after inference.",
|
||||
)
|
||||
self.assertTrue(
|
||||
hook.state.iteration == 0,
|
||||
"Iteration should be reset to 0 after inference.",
|
||||
)
|
||||
|
||||
def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2):
|
||||
# We need to use higher tolerance because we are using a random model. With a converged/trained
|
||||
# model, the tolerance can be lower.
|
||||
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
num_layers = 2
|
||||
components = self.get_dummy_components(num_layers=num_layers)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Run inference without PAB
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
output = pipe(**inputs)[0]
|
||||
original_image_slice = output.flatten()
|
||||
original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:]))
|
||||
|
||||
# Run inference with PAB enabled
|
||||
self.pab_config.current_timestep_callback = lambda: pipe.current_timestep
|
||||
denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
|
||||
denoiser.enable_cache(self.pab_config)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
output = pipe(**inputs)[0]
|
||||
image_slice_pab_enabled = output.flatten()
|
||||
image_slice_pab_enabled = np.concatenate((image_slice_pab_enabled[:8], image_slice_pab_enabled[-8:]))
|
||||
|
||||
# Run inference with PAB disabled
|
||||
denoiser.disable_cache()
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
output = pipe(**inputs)[0]
|
||||
image_slice_pab_disabled = output.flatten()
|
||||
image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:]))
|
||||
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_pab_enabled, atol=expected_atol
|
||||
), "PAB outputs should not differ much in specified timestep range."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_pab_disabled, atol=1e-4
|
||||
), "Outputs from normal inference and after disabling cache should not differ."
|
||||
|
||||
|
||||
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
|
||||
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
|
||||
# reference image.
|
||||
|
||||
Reference in New Issue
Block a user