mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[core] FasterCache (#10163)
* init
* update
* update
* update
* make style
* update
* fix
* make it work with guidance distilled models
* update
* make fix-copies
* add tests
* update
* apply_faster_cache -> apply_fastercache
* fix
* reorder
* update
* refactor
* update docs
* add fastercache to CacheMixin
* update tests
* Apply suggestions from code review
* make style
* try to fix partial import error
* Apply style fixes
* raise warning
* update

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
docs/source/en/api/cache.md

@@ -38,6 +38,33 @@ config = PyramidAttentionBroadcastConfig(
pipe.transformer.enable_cache(config)
```

## Faster Cache

[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.

FasterCache is a method that speeds up inference in diffusion transformers by:

- Reusing attention states between successive inference steps, due to the high similarity between them
- Skipping unconditional branch prediction in classifier-free guidance by exploiting the redundancy between the unconditional and conditional branch outputs at the same timestep, and approximating the unconditional branch output from the conditional branch output

```python
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=lambda _: 0.3,
    unconditional_batch_skip_range=5,
    unconditional_batch_timestep_skip_range=(-1, 781),
    tensor_format="BFCHW",
)
pipe.transformer.enable_cache(config)
```
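
For guidance-distilled models (Flux, for example) there is no separate unconditional branch to skip, so the CFG-Cache part of FasterCache does not apply; setting `is_guidance_distilled=True` restricts FasterCache to attention-state reuse. A minimal sketch, with illustrative values borrowed from this PR's test configurations:

```python
import torch
from diffusers import FluxPipeline, FasterCacheConfig

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 901),
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=lambda _: 0.5,
    is_guidance_distilled=True,  # no unconditional branch to approximate
)
pipe.transformer.enable_cache(config)
```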

### CacheMixin

[[autodoc]] CacheMixin

@@ -47,3 +74,9 @@ pipe.transformer.enable_cache(config)

[[autodoc]] PyramidAttentionBroadcastConfig

[[autodoc]] apply_pyramid_attention_broadcast

### FasterCacheConfig

[[autodoc]] FasterCacheConfig

[[autodoc]] apply_faster_cache
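
A note on `attention_weight_callback`: the constant `0.3` in the documentation example above is a simple choice, while the paper recommends a weight that gradually increases from 0 to 1 as denoising progresses. A minimal sketch of such a schedule (hypothetical helper; `pipe` is the pipeline from the example above, and 1000 is the usual number of training timesteps):

```python
def make_ramp_weight_callback(pipe, max_timestep=1000):
    def ramp_weight_callback(_module):
        # Timesteps run from ~max_timestep down to 0, so the weight ramps from 0 to 1.
        return 1.0 - pipe.current_timestep / max_timestep

    return ramp_weight_callback


config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=make_ramp_weight_callback(pipe),
    tensor_format="BFCHW",
)
pipe.transformer.enable_cache(config)
```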
src/diffusers/__init__.py

@@ -131,8 +131,10 @@ except OptionalDependencyNotAvailable:
else:
    _import_structure["hooks"].extend(
        [
            "FasterCacheConfig",
            "HookRegistry",
            "PyramidAttentionBroadcastConfig",
            "apply_faster_cache",
            "apply_pyramid_attention_broadcast",
        ]
    )
@@ -703,7 +705,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    except OptionalDependencyNotAvailable:
        from .utils.dummy_pt_objects import *  # noqa F403
    else:
-        from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
+        from .hooks import (
+            FasterCacheConfig,
+            HookRegistry,
+            PyramidAttentionBroadcastConfig,
+            apply_faster_cache,
+            apply_pyramid_attention_broadcast,
+        )
        from .models import (
            AllegroTransformer3DModel,
            AsymmetricAutoencoderKL,
src/diffusers/hooks/__init__.py

@@ -2,6 +2,7 @@ from ..utils import is_torch_available


if is_torch_available():
    from .faster_cache import FasterCacheConfig, apply_faster_cache
    from .group_offloading import apply_group_offloading
    from .hooks import HookRegistry, ModelHook
    from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
src/diffusers/hooks/faster_cache.py (new file, 653 lines)

@@ -0,0 +1,653 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple

import torch

from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
from .hooks import HookRegistry, ModelHook


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
    "^blocks.*attn",
    "^transformer_blocks.*attn",
    "^single_transformer_blocks.*attn",
)
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
    "hidden_states",
    "encoder_hidden_states",
    "timestep",
    "attention_mask",
    "encoder_attention_mask",
)


@dataclass
class FasterCacheConfig:
    r"""
    Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).

    Attributes:
        spatial_attention_block_skip_range (`int`, defaults to `2`):
            Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation
            will be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new
            attention states again.
        temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
            Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation
            will be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new
            attention states again.
        spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
            The timestep range within which the spatial attention computation can be skipped without a significant
            loss in quality. This is to be determined by the user based on the underlying model. The first value in
            the tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
            denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
            timestep 0). For the default values, this means that spatial attention computation skipping becomes
            applicable only after denoising timestep 681 is reached, and continues until the end of the denoising
            process.
        temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
            The timestep range within which the temporal attention computation can be skipped without a significant
            loss in quality. This is to be determined by the user based on the underlying model. The first value in
            the tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
            denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
            timestep 0).
        low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
            The timestep range within which the low frequency weight scaling update is applied. The first value in
            the tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
            function for the update is called only within this range.
        high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
            The timestep range within which the high frequency weight scaling update is applied. The first value in
            the tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
            function for the update is called only within this range.
        alpha_low_frequency (`float`, defaults to `1.1`):
            The weight to scale the low frequency updates by. This is used to approximate the unconditional branch
            from the conditional branch outputs.
        alpha_high_frequency (`float`, defaults to `1.1`):
            The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
            from the conditional branch outputs.
        unconditional_batch_skip_range (`int`, defaults to `5`):
            Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
            computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used)
            before computing the new unconditional branch states again.
        unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
            The timestep range within which the unconditional branch computation can be skipped without a
            significant loss in quality. This is to be determined by the user based on the underlying model. The
            first value in the tuple is the lower bound and the second value is the upper bound.
        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("^blocks.*attn", "^transformer_blocks.*attn", "^single_transformer_blocks.*attn")`):
            The identifiers to match the spatial attention blocks in the model. If the name of the block contains
            any of these identifiers, FasterCache will be applied to that block. This can either be the full layer
            names, partial layer names, or regex patterns. Matching will always be done using a regex match.
        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("^temporal_transformer_blocks.*attn",)`):
            The identifiers to match the temporal attention blocks in the model. If the name of the block contains
            any of these identifiers, FasterCache will be applied to that block. This can either be the full layer
            names, partial layer names, or regex patterns. Matching will always be done using a regex match.
        attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
            The callback function to determine the weight to scale the attention outputs by. This function should
            take the attention module as input and return a float value. This is used to approximate the
            unconditional branch from the conditional branch outputs. If not provided, the default weight is 0.5 for
            all timesteps. Typically, as described in the paper, this weight should gradually increase from 0 to 1
            as the inference progresses. Users are encouraged to experiment and provide custom weight schedules that
            take into account the number of inference steps and underlying model behaviour as denoising progresses.
        low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
            The callback function to determine the weight to scale the low frequency updates by. If not provided,
            the default weight is 1.1 for timesteps within the range specified (as described in the paper).
        high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
            The callback function to determine the weight to scale the high frequency updates by. If not provided,
            the default weight is 1.1 for timesteps within the range specified (as described in the paper).
        tensor_format (`str`, defaults to `"BCFHW"`):
            The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
            used to split individual latent frames in order for low and high frequency components to be computed.
        is_guidance_distilled (`bool`, defaults to `False`):
            Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not
            be applied at the denoiser-level to skip the unconditional branch computation (as there is none).
        current_timestep_callback (`Callable[[], int]`, defaults to `None`):
            The callback function to retrieve the current denoising timestep during inference.
        _unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
            The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
            conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
            split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
            names that contain the batchwise-concatenated unconditional and conditional inputs.
    """

    # In the paper and codebase, these values are hardcoded to 2. However, they can be made configurable
    # after some testing. We default to 2 if these parameters are not provided.
    spatial_attention_block_skip_range: int = 2
    temporal_attention_block_skip_range: Optional[int] = None

    spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
    temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)

    # Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
    low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
    high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)

    # ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
    alpha_low_frequency: float = 1.1
    alpha_high_frequency: float = 1.1

    # n as described in the CFG-Cache explanation in the paper - dependent on the model
    unconditional_batch_skip_range: int = 5
    unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)

    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS

    attention_weight_callback: Callable[[torch.nn.Module], float] = None
    low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
    high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None

    tensor_format: str = "BCFHW"
    is_guidance_distilled: bool = False

    current_timestep_callback: Callable[[], int] = None

    _unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS

    def __repr__(self) -> str:
        return (
            f"FasterCacheConfig(\n"
            f"  spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
            f"  temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
            f"  spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
            f"  temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
            f"  low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n"
            f"  high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n"
            f"  alpha_low_frequency={self.alpha_low_frequency},\n"
            f"  alpha_high_frequency={self.alpha_high_frequency},\n"
            f"  unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n"
            f"  unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n"
            f"  spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
            f"  temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
            f"  tensor_format={self.tensor_format},\n"
            f")"
        )


class FasterCacheDenoiserState:
    r"""
    State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module.
    """

    def __init__(self) -> None:
        self.iteration: int = 0
        self.low_frequency_delta: torch.Tensor = None
        self.high_frequency_delta: torch.Tensor = None

    def reset(self):
        self.iteration = 0
        self.low_frequency_delta = None
        self.high_frequency_delta = None


class FasterCacheBlockState:
    r"""
    State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is
    applied to will have an instance of this state.
    """

    def __init__(self) -> None:
        self.iteration: int = 0
        self.batch_size: int = None
        self.cache: Tuple[torch.Tensor, torch.Tensor] = None

    def reset(self):
        self.iteration = 0
        self.batch_size = None
        self.cache = None


class FasterCacheDenoiserHook(ModelHook):
    _is_stateful = True

    def __init__(
        self,
        unconditional_batch_skip_range: int,
        unconditional_batch_timestep_skip_range: Tuple[int, int],
        tensor_format: str,
        is_guidance_distilled: bool,
        uncond_cond_input_kwargs_identifiers: List[str],
        current_timestep_callback: Callable[[], int],
        low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
        high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
    ) -> None:
        super().__init__()

        self.unconditional_batch_skip_range = unconditional_batch_skip_range
        self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
        # We can't easily detect which args are to be split into unconditional and conditional branches. We
        # can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
        # If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
        # contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
        self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
        self.tensor_format = tensor_format
        self.is_guidance_distilled = is_guidance_distilled

        self.current_timestep_callback = current_timestep_callback
        self.low_frequency_weight_callback = low_frequency_weight_callback
        self.high_frequency_weight_callback = high_frequency_weight_callback

    def initialize_hook(self, module):
        self.state = FasterCacheDenoiserState()
        return module

    @staticmethod
    def _get_cond_input(input: torch.Tensor) -> torch.Tensor:
        # Note: this method assumes that the input tensor is batchwise-concatenated, with unconditional inputs
        # followed by conditional inputs.
        _, cond = input.chunk(2, dim=0)
        return cond

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        # Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
        # requirements for skipping the unconditional branch are met as described in the paper.
        # We skip the unconditional branch only if the following conditions are met:
        #   1. We have completed at least one iteration of the denoiser.
        #   2. The current timestep is within the range specified by the user. This is the optimal timestep range
        #      where approximating the unconditional branch from the computation of the conditional branch is
        #      possible without a significant loss in quality.
        #   3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
        #      we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
        is_within_timestep_range = (
            self.unconditional_batch_timestep_skip_range[0]
            < self.current_timestep_callback()
            < self.unconditional_batch_timestep_skip_range[1]
        )
        should_skip_uncond = (
            self.state.iteration > 0
            and is_within_timestep_range
            and self.state.iteration % self.unconditional_batch_skip_range != 0
            and not self.is_guidance_distilled
        )

        if should_skip_uncond:
            is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
            if is_any_kwarg_uncond:
                logger.debug("FasterCache - Skipping unconditional branch computation")
                args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
                kwargs = {
                    k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
                    for k, v in kwargs.items()
                }

        output = self.fn_ref.original_forward(*args, **kwargs)

        if self.is_guidance_distilled:
            self.state.iteration += 1
            return output

        if torch.is_tensor(output):
            hidden_states = output
        elif isinstance(output, (tuple, Transformer2DModelOutput)):
            hidden_states = output[0]

        batch_size = hidden_states.size(0)

        if should_skip_uncond:
            self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
                module
            )
            self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
                module
            )

            if self.tensor_format == "BCFHW":
                hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                hidden_states = hidden_states.flatten(0, 1)

            low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())

            # Approximate the unconditional branch outputs as described in Equations 9 and 10 of the paper
            low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
            high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
            uncond_freq = low_freq_uncond + high_freq_uncond

            uncond_states = torch.fft.ifftshift(uncond_freq)
            uncond_states = torch.fft.ifft2(uncond_states).real

            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                uncond_states = uncond_states.unflatten(0, (batch_size, -1))
                hidden_states = hidden_states.unflatten(0, (batch_size, -1))
            if self.tensor_format == "BCFHW":
                uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
                hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

            # Concatenate the approximated unconditional and predicted conditional branches
            uncond_states = uncond_states.to(hidden_states.dtype)
            hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
        else:
            uncond_states, cond_states = hidden_states.chunk(2, dim=0)
            if self.tensor_format == "BCFHW":
                uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
                cond_states = cond_states.permute(0, 2, 1, 3, 4)
            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                uncond_states = uncond_states.flatten(0, 1)
                cond_states = cond_states.flatten(0, 1)

            low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
            low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
            self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
            self.state.high_frequency_delta = high_freq_uncond - high_freq_cond

        self.state.iteration += 1
        if torch.is_tensor(output):
            output = hidden_states
        elif isinstance(output, tuple):
            output = (hidden_states, *output[1:])
        else:
            output.sample = hidden_states

        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        self.state.reset()
        return module


class FasterCacheBlockHook(ModelHook):
    _is_stateful = True

    def __init__(
        self,
        block_skip_range: int,
        timestep_skip_range: Tuple[int, int],
        is_guidance_distilled: bool,
        weight_callback: Callable[[torch.nn.Module], float],
        current_timestep_callback: Callable[[], int],
    ) -> None:
        super().__init__()

        self.block_skip_range = block_skip_range
        self.timestep_skip_range = timestep_skip_range
        self.is_guidance_distilled = is_guidance_distilled

        self.weight_callback = weight_callback
        self.current_timestep_callback = current_timestep_callback

    def initialize_hook(self, module):
        self.state = FasterCacheBlockState()
        return module

    def _compute_approximated_attention_output(
        self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
    ) -> torch.Tensor:
        if t_2_output.size(0) != batch_size:
            # The cached t_2_output contains batchwise-concatenated unconditional-conditional branch outputs.
            # Just take the conditional branch outputs.
            assert t_2_output.size(0) == 2 * batch_size
            t_2_output = t_2_output[batch_size:]
        if t_output.size(0) != batch_size:
            # The cached t_output contains batchwise-concatenated unconditional-conditional branch outputs.
            # Just take the conditional branch outputs.
            assert t_output.size(0) == 2 * batch_size
            t_output = t_output[batch_size:]
        return t_output + (t_output - t_2_output) * weight

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        batch_size = [
            *[arg.size(0) for arg in args if torch.is_tensor(arg)],
            *[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
        ][0]
        if self.state.batch_size is None:
            # Will be updated on the first forward pass through the denoiser
            self.state.batch_size = batch_size

        # If we have to skip due to the skip conditions, then let's skip as expected.
        # But we can't skip if the denoiser wants to infer both the unconditional and conditional branches. This
        # is because the expected output shapes of the attention layer will not match if we only return values
        # from the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the
        # true unconditional-conditional batch size) is the same as the current batch size, we don't perform the
        # layer skip. Otherwise, we conditionally skip the layer based on the conditions below.
        is_within_timestep_range = (
            self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
        )
        if not is_within_timestep_range:
            should_skip_attention = False
        else:
            should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
            should_skip_attention = not should_compute_attention
            if should_skip_attention:
                should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size

        if should_skip_attention:
            logger.debug("FasterCache - Skipping attention and using approximation")
            if torch.is_tensor(self.state.cache[-1]):
                t_2_output, t_output = self.state.cache
                weight = self.weight_callback(module)
                output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
            else:
                # The cache contains multiple tensors from the past N iterations (N=2 for FasterCache). We need to
                # handle all of them. Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...]
                # for simplicity. In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...]] where
                # each inner list is the output from a forward pass of the block. We need to compute the
                # approximated output for each of these tensors. The zip(*state.cache) operation gives us
                # [(A_1, A_2), (B_1, B_2), (C_1, C_2), ...], which allows us to compute the approximated attention
                # output for each tensor in the cache.
                output = ()
                for t_2_output, t_output in zip(*self.state.cache):
                    result = self._compute_approximated_attention_output(
                        t_2_output, t_output, self.weight_callback(module), batch_size
                    )
                    output += (result,)
        else:
            logger.debug("FasterCache - Computing attention")
            output = self.fn_ref.original_forward(*args, **kwargs)

        # The following condition for getting hidden_states should suffice, since Diffusers blocks either return a
        # single hidden_states tensor or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to
        # handle both cases.
        if torch.is_tensor(output):
            cache_output = output
            if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
                # The output here can be either the unconditional-conditional branch outputs or just the
                # conditional branch outputs. This is determined at the higher-level denoiser module. We only want
                # to cache the conditional branch outputs.
                cache_output = cache_output.chunk(2, dim=0)[1]
        else:
            # Cache all return values and perform the same operation as above
            cache_output = ()
            for out in output:
                if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
                    out = out.chunk(2, dim=0)[1]
                cache_output += (out,)

        if self.state.cache is None:
            self.state.cache = [cache_output, cache_output]
        else:
            self.state.cache = [self.state.cache[-1], cache_output]

        self.state.iteration += 1
        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        self.state.reset()
        return module


def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
    r"""
    Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given module.

    Args:
        module (`torch.nn.Module`):
            The module to apply FasterCache to (typically a denoiser, e.g. a transformer model).
        config (`FasterCacheConfig`):
            The configuration to use for FasterCache.

    Example:
    ```python
    >>> import torch
    >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache

    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> config = FasterCacheConfig(
    ...     spatial_attention_block_skip_range=2,
    ...     spatial_attention_timestep_skip_range=(-1, 681),
    ...     low_frequency_weight_update_timestep_range=(99, 641),
    ...     high_frequency_weight_update_timestep_range=(-1, 301),
    ...     spatial_attention_block_identifiers=["transformer_blocks"],
    ...     attention_weight_callback=lambda _: 0.3,
    ...     tensor_format="BFCHW",
    ... )
    >>> apply_faster_cache(pipe.transformer, config)
    ```
    """

    logger.warning(
        "FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
        "The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
        "https://github.com/huggingface/diffusers/issues."
    )

    if config.attention_weight_callback is None:
        # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
        # In the paper, a weight that gradually increases from 0 to 1 as inference progresses is recommended,
        # but this varies from model to model. The user must provide a weight callback to use a different
        # weight schedule. Defaulting to 0.5 works well in practice for most cases.
        logger.warning(
            "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
        )
        config.attention_weight_callback = lambda _: 0.5

    if config.low_frequency_weight_callback is None:
        logger.debug(
            "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
        )

        def low_frequency_weight_callback(module: torch.nn.Module) -> float:
            is_within_range = (
                config.low_frequency_weight_update_timestep_range[0]
                < config.current_timestep_callback()
                < config.low_frequency_weight_update_timestep_range[1]
            )
            return config.alpha_low_frequency if is_within_range else 1.0

        config.low_frequency_weight_callback = low_frequency_weight_callback

    if config.high_frequency_weight_callback is None:
        logger.debug(
            "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
        )

        def high_frequency_weight_callback(module: torch.nn.Module) -> float:
            is_within_range = (
                config.high_frequency_weight_update_timestep_range[0]
                < config.current_timestep_callback()
                < config.high_frequency_weight_update_timestep_range[1]
            )
            return config.alpha_high_frequency if is_within_range else 1.0

        config.high_frequency_weight_callback = high_frequency_weight_callback

    supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"]  # TODO(aryan): Support BSC for LTX Video
    if config.tensor_format not in supported_tensor_formats:
        raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")

    _apply_faster_cache_on_denoiser(module, config)

    for name, submodule in module.named_modules():
        if not isinstance(submodule, _ATTENTION_CLASSES):
            continue
        if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
            _apply_faster_cache_on_attention_class(name, submodule, config)


def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
    hook = FasterCacheDenoiserHook(
        config.unconditional_batch_skip_range,
        config.unconditional_batch_timestep_skip_range,
        config.tensor_format,
        config.is_guidance_distilled,
        config._unconditional_conditional_input_kwargs_identifiers,
        config.current_timestep_callback,
        config.low_frequency_weight_callback,
        config.high_frequency_weight_callback,
    )
    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)


def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
    is_spatial_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
        and config.spatial_attention_block_skip_range is not None
        and not getattr(module, "is_cross_attention", False)
    )
    is_temporal_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
        and config.temporal_attention_block_skip_range is not None
        and not module.is_cross_attention
    )

    block_skip_range, timestep_skip_range, block_type = None, None, None
    if is_spatial_self_attention:
        block_skip_range = config.spatial_attention_block_skip_range
        timestep_skip_range = config.spatial_attention_timestep_skip_range
        block_type = "spatial"
    elif is_temporal_self_attention:
        block_skip_range = config.temporal_attention_block_skip_range
        timestep_skip_range = config.temporal_attention_timestep_skip_range
        block_type = "temporal"

    if block_skip_range is None or timestep_skip_range is None:
        logger.debug(
            f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
            f"not match any of the required criteria for spatial or temporal attention layers. Note, "
            f"however, that this layer may still be valid for applying PAB. Please specify the correct "
            f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` "
            f"function to apply FasterCache to this layer."
        )
        return

    logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
    hook = FasterCacheBlockHook(
        block_skip_range,
        timestep_skip_range,
        config.is_guidance_distilled,
        config.attention_weight_callback,
        config.current_timestep_callback,
    )
    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)


# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
@torch.no_grad()
def _split_low_high_freq(x):
    fft = torch.fft.fft2(x)
    fft_shifted = torch.fft.fftshift(fft)
    height, width = x.shape[-2:]
    radius = min(height, width) // 5

    y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width))
    center_x, center_y = width // 2, height // 2
    mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2

    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
    high_freq_mask = ~low_freq_mask

    low_freq_fft = fft_shifted * low_freq_mask
    high_freq_fft = fft_shifted * high_freq_mask

    return low_freq_fft, high_freq_fft
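A note on `_split_low_high_freq`: it separates low from high frequencies by masking a centered disk of radius `min(height, width) // 5` in the shifted 2D spectrum, and the two components sum back to the full spectrum exactly. A standalone sketch of that round-trip (the helper body is inlined from above; `indexing="ij"` matches the default `torch.meshgrid` behaviour):

```python
import torch


def split_low_high_freq(x):
    # Same logic as _split_low_high_freq above, inlined for a self-contained check.
    fft = torch.fft.fft2(x)
    fft_shifted = torch.fft.fftshift(fft)
    height, width = x.shape[-2:]
    radius = min(height, width) // 5

    y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
    center_x, center_y = width // 2, height // 2
    mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2

    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
    return fft_shifted * low_freq_mask, fft_shifted * ~low_freq_mask


frames = torch.randn(2, 3, 32, 32)  # (batch * frames, channels, height, width)
low_freq, high_freq = split_low_high_freq(frames)

# The masks are complementary, so low + high is the full shifted spectrum and
# inverting it recovers the input up to floating point error.
recon = torch.fft.ifft2(torch.fft.ifftshift(low_freq + high_freq)).real
assert torch.allclose(recon, frames, atol=1e-5)
```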
src/diffusers/hooks/pyramid_attention_broadcast.py

@@ -26,8 +26,8 @@ from .hooks import HookRegistry, ModelHook

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
_ATTENTION_CLASSES = (Attention, MochiAttention)

_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@@ -87,7 +87,7 @@ class PyramidAttentionBroadcastConfig:

    def __repr__(self) -> str:
        return (
-            f"PyramidAttentionBroadcastConfig("
+            f"PyramidAttentionBroadcastConfig(\n"
            f"  spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
            f"  temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
            f"  cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
@@ -175,10 +175,7 @@ class PyramidAttentionBroadcastHook(ModelHook):
        return module


-def apply_pyramid_attention_broadcast(
-    module: torch.nn.Module,
-    config: PyramidAttentionBroadcastConfig,
-):
+def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
    r"""
    Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.

@@ -311,4 +308,4 @@ def _apply_pyramid_attention_broadcast_hook(
    """
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
-    registry.register_hook(hook, "pyramid_attention_broadcast")
+    registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)
src/diffusers/models/cache_utils.py

@@ -24,6 +24,7 @@ class CacheMixin:

    Supported caching techniques:
    - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
    - [FasterCache](https://huggingface.co/papers/2410.19355)
    """

    _cache_config = None
@@ -59,17 +60,31 @@ class CacheMixin:
        ```
        """

-        from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
+        from ..hooks import (
+            FasterCacheConfig,
+            PyramidAttentionBroadcastConfig,
+            apply_faster_cache,
+            apply_pyramid_attention_broadcast,
+        )

        if self.is_cache_enabled:
            raise ValueError(
                f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
            )

        if isinstance(config, PyramidAttentionBroadcastConfig):
            apply_pyramid_attention_broadcast(self, config)
        elif isinstance(config, FasterCacheConfig):
            apply_faster_cache(self, config)
        else:
            raise ValueError(f"Cache config {type(config)} is not supported.")

        self._cache_config = config

    def disable_cache(self) -> None:
-        from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
+        from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
+        from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
+        from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK

        if self._cache_config is None:
            logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
@@ -77,7 +92,11 @@ class CacheMixin:

        if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
            registry = HookRegistry.check_if_exists_or_initialize(self)
-            registry.remove_hook("pyramid_attention_broadcast", recurse=True)
+            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
        elif isinstance(self._cache_config, FasterCacheConfig):
            registry = HookRegistry.check_if_exists_or_initialize(self)
            registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
            registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
        else:
            raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
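With these changes, caching is symmetric at the model level: `enable_cache` validates the config and registers the hooks, and `disable_cache` removes them so a different technique can be applied afterwards. A minimal usage sketch (reusing the `pipe` and `config` from the documentation example above; the prompt is illustrative):

```python
pipe.transformer.enable_cache(config)
video_cached = pipe("A cat playing the piano", num_inference_steps=50).frames[0]

# Remove the FasterCache hooks; a new cache config can now be enabled.
pipe.transformer.disable_cache()
video_baseline = pipe("A cat playing the piano", num_inference_steps=50).frames[0]
```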
src/diffusers/models/embeddings.py

@@ -336,7 +336,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
            " `from_numpy` is no longer required."
            "  Pass `output_type='pt'` to use the new version now."
        )
-        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
        return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")
src/diffusers/models/modeling_utils.py

@@ -37,7 +37,6 @@ from torch import Tensor, nn
from typing_extensions import Self

from .. import __version__
-from ..hooks import apply_group_offloading, apply_layerwise_casting
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
@@ -504,6 +503,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            non_blocking (`bool`, *optional*, defaults to `False`):
                If `True`, the weight casting operations are non-blocking.
        """
+        from ..hooks import apply_layerwise_casting

        user_provided_patterns = True
        if skip_modules_pattern is None:
@@ -570,6 +570,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            ...     )
            ```
        """
+        from ..hooks import apply_group_offloading

        if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
            msg = (
                "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
src/diffusers/pipelines/latte/pipeline_latte.py

@@ -817,7 +817,7 @@ class LattePipeline(DiffusionPipeline):

                # predict noise model_output
                noise_pred = self.transformer(
-                    latent_model_input,
+                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=current_timestep,
                    enable_temporal_attentions=enable_temporal_attentions,
src/diffusers/utils/dummy_pt_objects.py

@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends


class FasterCacheConfig(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class HookRegistry(metaclass=DummyObject):
    _backends = ["torch"]

@@ -32,6 +47,10 @@ class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


def apply_faster_cache(*args, **kwargs):
    requires_backends(apply_faster_cache, ["torch"])


def apply_pyramid_attention_broadcast(*args, **kwargs):
    requires_backends(apply_pyramid_attention_broadcast, ["torch"])
tests/pipelines/cogvideo/test_cogvideox.py

@@ -31,6 +31,7 @@ from diffusers.utils.testing_utils import (

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
    FasterCacheTesterMixin,
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
@@ -42,7 +43,9 @@ from ..test_pipelines_common import (
enable_full_determinism()


-class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
+class CogVideoXPipelineFastTests(
+    PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+):
    pipeline_class = CogVideoXPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
tests/pipelines/flux/test_pipeline_flux.py

@@ -7,7 +7,13 @@ import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel

-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers import (
+    AutoencoderKL,
+    FasterCacheConfig,
+    FlowMatchEulerDiscreteScheduler,
+    FluxPipeline,
+    FluxTransformer2DModel,
+)
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    nightly,
@@ -18,6 +24,7 @@ from diffusers.utils.testing_utils import (
)

from ..test_pipelines_common import (
    FasterCacheTesterMixin,
    FluxIPAdapterTesterMixin,
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
@@ -27,7 +34,11 @@ from ..test_pipelines_common import (


class FluxPipelineFastTests(
-    unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin
+    unittest.TestCase,
+    PipelineTesterMixin,
+    FluxIPAdapterTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+    FasterCacheTesterMixin,
):
    pipeline_class = FluxPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
@@ -38,6 +49,14 @@ class FluxPipelineFastTests(
    test_layerwise_casting = True
    test_group_offloading = True

    faster_cache_config = FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 901),
        unconditional_batch_skip_range=2,
        attention_weight_callback=lambda _: 0.5,
        is_guidance_distilled=True,
    )

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        torch.manual_seed(0)
        transformer = FluxTransformer2DModel(
tests/pipelines/hunyuan_video/test_hunyuan_video.py

@@ -21,6 +21,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, LlamaConf

from diffusers import (
    AutoencoderKLHunyuanVideo,
    FasterCacheConfig,
    FlowMatchEulerDiscreteScheduler,
    HunyuanVideoPipeline,
    HunyuanVideoTransformer3DModel,
@@ -30,13 +31,20 @@ from diffusers.utils.testing_utils import (
    torch_device,
)

-from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
+from ..test_pipelines_common import (
+    FasterCacheTesterMixin,
+    PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+    to_np,
+)


enable_full_determinism()


-class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
+class HunyuanVideoPipelineFastTests(
+    PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+):
    pipeline_class = HunyuanVideoPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
    batch_params = frozenset(["prompt"])
@@ -56,6 +64,14 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca
    test_layerwise_casting = True
    test_group_offloading = True

    faster_cache_config = FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 901),
        unconditional_batch_skip_range=2,
        attention_weight_callback=lambda _: 0.5,
        is_guidance_distilled=True,
    )

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        torch.manual_seed(0)
        transformer = HunyuanVideoTransformer3DModel(
tests/pipelines/latte/test_latte.py

@@ -25,6 +25,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    FasterCacheConfig,
    LattePipeline,
    LatteTransformer3DModel,
    PyramidAttentionBroadcastConfig,
@@ -40,13 +41,20 @@ from diffusers.utils.testing_utils import (
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
+from ..test_pipelines_common import (
+    FasterCacheTesterMixin,
+    PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+    to_np,
+)


enable_full_determinism()


-class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
+class LattePipelineFastTests(
+    PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+):
    pipeline_class = LattePipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -69,6 +77,15 @@ class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTeste
        cross_attention_block_identifiers=["transformer_blocks"],
    )

    faster_cache_config = FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        temporal_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 901),
        temporal_attention_timestep_skip_range=(-1, 901),
        unconditional_batch_skip_range=2,
        attention_weight_callback=lambda _: 0.5,
    )

    def get_dummy_components(self, num_layers: int = 1):
        torch.manual_seed(0)
        transformer = LatteTransformer3DModel(
tests/pipelines/mochi/test_mochi.py

@@ -33,13 +33,13 @@ from diffusers.utils.testing_utils import (
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np


enable_full_determinism()


-class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
    pipeline_class = MochiPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -59,13 +59,13 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    test_layerwise_casting = True
    test_group_offloading = True

-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 2):
        torch.manual_seed(0)
        transformer = MochiTransformer3DModel(
            patch_size=2,
            num_attention_heads=2,
            attention_head_dim=8,
-            num_layers=2,
+            num_layers=num_layers,
            pooled_projection_dim=16,
            in_channels=12,
            out_channels=None,
tests/pipelines/test_pipelines_common.py

@@ -23,13 +23,16 @@ from diffusers import (
    ConsistencyDecoderVAE,
    DDIMScheduler,
    DiffusionPipeline,
    FasterCacheConfig,
    KolorsPipeline,
    PyramidAttentionBroadcastConfig,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
    apply_faster_cache,
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
@@ -2551,6 +2554,167 @@ class PyramidAttentionBroadcastTesterMixin:
        ), "Outputs from normal inference and after disabling cache should not differ."


class FasterCacheTesterMixin:
    faster_cache_config = FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 901),
        unconditional_batch_skip_range=2,
        attention_weight_callback=lambda _: 0.5,
    )

    def test_faster_cache_basic_warning_or_errors_raised(self):
        components = self.get_dummy_components()

        logger = logging.get_logger("diffusers.hooks.faster_cache")
        logger.setLevel(logging.INFO)

        # Check that a warning is raised when no attention_weight_callback is provided
        pipe = self.pipeline_class(**components)
        with CaptureLogger(logger) as cap_logger:
            config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None)
            apply_faster_cache(pipe.transformer, config)
        self.assertTrue("No `attention_weight_callback` provided when enabling FasterCache" in cap_logger.out)

        # Check that an error is raised when an unsupported tensor format is used
        pipe = self.pipeline_class(**components)
        with self.assertRaises(ValueError):
            config = FasterCacheConfig(spatial_attention_block_skip_range=2, tensor_format="BFHWC")
            apply_faster_cache(pipe.transformer, config)

    def test_faster_cache_inference(self, expected_atol: float = 0.1):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        def create_pipe():
            torch.manual_seed(0)
            num_layers = 2
            components = self.get_dummy_components(num_layers=num_layers)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(device)
            pipe.set_progress_bar_config(disable=None)
            return pipe

        def run_forward(pipe):
            torch.manual_seed(0)
            inputs = self.get_dummy_inputs(device)
            inputs["num_inference_steps"] = 4
            return pipe(**inputs)[0]

        # Run inference without FasterCache
        pipe = create_pipe()
        output = run_forward(pipe).flatten()
        original_image_slice = np.concatenate((output[:8], output[-8:]))

        # Run inference with FasterCache enabled
        self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
        pipe = create_pipe()
        pipe.transformer.enable_cache(self.faster_cache_config)
        output = run_forward(pipe).flatten()
        image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))

        # Run inference with FasterCache disabled
        pipe.transformer.disable_cache()
        output = run_forward(pipe).flatten()
        image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:]))

        assert np.allclose(
            original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol
        ), "FasterCache outputs should not differ much in specified timestep range."
        assert np.allclose(
            original_image_slice, image_slice_faster_cache_disabled, atol=1e-4
        ), "Outputs from normal inference and after disabling cache should not differ."

    def test_faster_cache_state(self):
        from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK

        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        num_layers = 0
        num_single_layers = 0
        dummy_component_kwargs = {}
        dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters
        if "num_layers" in dummy_component_parameters:
            num_layers = 2
            dummy_component_kwargs["num_layers"] = num_layers
        if "num_single_layers" in dummy_component_parameters:
            num_single_layers = 2
            dummy_component_kwargs["num_single_layers"] = num_single_layers

        components = self.get_dummy_components(**dummy_component_kwargs)
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
        pipe.transformer.enable_cache(self.faster_cache_config)

        expected_hooks = 0
        if self.faster_cache_config.spatial_attention_block_skip_range is not None:
            expected_hooks += num_layers + num_single_layers
        if self.faster_cache_config.temporal_attention_block_skip_range is not None:
            expected_hooks += num_layers + num_single_layers

        # Check that the faster_cache denoiser hook is attached
        denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
        self.assertTrue(
            hasattr(denoiser, "_diffusers_hook")
            and isinstance(denoiser._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK), FasterCacheDenoiserHook),
            "Hook should be of type FasterCacheDenoiserHook.",
        )

        # Check that all blocks have the faster_cache block hook attached
        count = 0
        for name, module in denoiser.named_modules():
            if hasattr(module, "_diffusers_hook"):
                if name == "":
                    # Skip the root denoiser module
                    continue
                count += 1
                self.assertTrue(
                    isinstance(module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK), FasterCacheBlockHook),
                    "Hook should be of type FasterCacheBlockHook.",
                )
        self.assertEqual(count, expected_hooks, "Number of hooks should match expected number.")

        # Perform inference to ensure that states are updated correctly
        def faster_cache_state_check_callback(pipe, i, t, kwargs):
            for name, module in denoiser.named_modules():
                if not hasattr(module, "_diffusers_hook"):
                    continue
                if name == "":
                    # Root denoiser module
                    state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state
                    if not self.faster_cache_config.is_guidance_distilled:
                        self.assertTrue(state.low_frequency_delta is not None, "Low frequency delta should be set.")
                        self.assertTrue(state.high_frequency_delta is not None, "High frequency delta should be set.")
                else:
                    # Internal blocks
                    state = module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK).state
                    self.assertTrue(state.cache is not None and len(state.cache) == 2, "Cache should be set.")
                self.assertTrue(state.iteration == i + 1, "Hook iteration state should have updated during inference.")
            return {}

        inputs = self.get_dummy_inputs(device)
        inputs["num_inference_steps"] = 4
        inputs["callback_on_step_end"] = faster_cache_state_check_callback
        _ = pipe(**inputs)[0]

        # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states
        for name, module in denoiser.named_modules():
            if not hasattr(module, "_diffusers_hook"):
                continue

            if name == "":
                # Root denoiser module
                state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state
                self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.")
                self.assertTrue(state.low_frequency_delta is None, "Low frequency delta should be reset to None.")
                self.assertTrue(state.high_frequency_delta is None, "High frequency delta should be reset to None.")
            else:
                # Internal blocks
                state = module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK).state
                self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.")
                self.assertTrue(state.batch_size is None, "Batch size should be reset to None.")
                self.assertTrue(state.cache is None, "Cache should be reset to None.")


# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.