Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

First Block Cache (#11180)

* update
* modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code)
* remove debug logs
* update
* cache context for different batches of data
* fix hs residual bug for single return outputs; support ltx
* fix controlnet flux
* support flux, ltx i2v, ltx condition
* update
* update
* Update docs/source/en/api/cache.md
* Update src/diffusers/hooks/hooks.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* address review comments pt. 1
* address review comments pt. 2
* cache context refactor; address review pt. 3
* address review comments
* metadata registration with decorators instead of centralized
* support cogvideox
* support mochi
* fix
* remove unused function
* remove central registry based on review
* update

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
@@ -28,3 +28,9 @@ Cache methods speed up diffusion transformers by storing and reusing intermediate

[[autodoc]] FasterCacheConfig

[[autodoc]] apply_faster_cache

### FirstBlockCacheConfig

[[autodoc]] FirstBlockCacheConfig

[[autodoc]] apply_first_block_cache
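The new config plugs into the existing `CacheMixin.enable_cache` dispatch (see the `cache_utils.py` changes further down in this diff). A minimal usage sketch; the checkpoint id, prompt, and step count are illustrative and not part of this PR:

```python
import torch
from diffusers import FluxPipeline, FirstBlockCacheConfig

# Placeholder checkpoint; any model whose transformer inherits CacheMixin
# (Flux, LTX, HunyuanVideo, CogVideoX, CogView4, Mochi, Wan, ...) works the same way.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Skip the remaining transformer blocks whenever the first block's residual
# changes by less than ~5% (relative absmean) compared to the previous step.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.05))
image = pipe("a photo of a cat holding a sign", num_inference_steps=28).images[0]

pipe.transformer.disable_cache()  # remove the hooks again
```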
@@ -133,9 +133,11 @@ else:
    _import_structure["hooks"].extend(
        [
            "FasterCacheConfig",
            "FirstBlockCacheConfig",
            "HookRegistry",
            "PyramidAttentionBroadcastConfig",
            "apply_faster_cache",
            "apply_first_block_cache",
            "apply_pyramid_attention_broadcast",
        ]
    )

@@ -751,9 +753,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    else:
        from .hooks import (
            FasterCacheConfig,
            FirstBlockCacheConfig,
            HookRegistry,
            PyramidAttentionBroadcastConfig,
            apply_faster_cache,
            apply_first_block_cache,
            apply_pyramid_attention_broadcast,
        )
        from .models import (
@@ -1,8 +1,23 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_torch_available


if is_torch_available():
    from .faster_cache import FasterCacheConfig, apply_faster_cache
    from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
    from .group_offloading import apply_group_offloading
    from .hooks import HookRegistry, ModelHook
    from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook

src/diffusers/hooks/_common.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..models.attention_processor import Attention, MochiAttention


_ATTENTION_CLASSES = (Attention, MochiAttention)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
    {
        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
    }
)

src/diffusers/hooks/_helpers.py (new file, 264 lines)
@@ -0,0 +1,264 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type


@dataclass
class AttentionProcessorMetadata:
    skip_processor_output_fn: Callable[[Any], Any]


@dataclass
class TransformerBlockMetadata:
    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None

    _cls: Type = None
    _cached_parameter_indices: Dict[str, int] = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        kwargs = kwargs or {}
        if identifier in kwargs:
            return kwargs[identifier]
        if self._cached_parameter_indices is not None:
            return args[self._cached_parameter_indices[identifier]]
        if self._cls is None:
            raise ValueError("Model class is not set for metadata.")
        parameters = list(inspect.signature(self._cls.forward).parameters.keys())
        parameters = parameters[1:]  # skip `self`
        self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
        if identifier not in self._cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
        index = self._cached_parameter_indices[identifier]
        if index >= len(args):
            raise ValueError(f"Expected {index} arguments but got {len(args)}.")
        return args[index]
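How the helper resolves arguments, sketched with a toy class (the `ToyBlock` below is made up for illustration; real block classes are registered via `TransformerBlockRegistry.register`): keyword arguments win immediately, otherwise the positional index is derived once from the registered class's `forward` signature and cached.

```python
import torch
from diffusers.hooks._helpers import TransformerBlockMetadata


class ToyBlock(torch.nn.Module):  # made-up class, used only for this illustration
    def forward(self, hidden_states, encoder_hidden_states=None, temb=None):
        return hidden_states


metadata = TransformerBlockMetadata(return_hidden_states_index=0)
metadata._cls = ToyBlock  # normally set by TransformerBlockRegistry.register()

x = torch.randn(1, 4, 8)
# Keyword arguments are returned directly...
assert metadata._get_parameter_from_args_kwargs("hidden_states", (), {"hidden_states": x}) is x
# ...otherwise the index is looked up in ToyBlock.forward's signature (skipping `self`) and cached.
assert metadata._get_parameter_from_args_kwargs("hidden_states", (x,), {}) is x
```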
class AttentionProcessorRegistry:
    _registry = {}
    # TODO(aryan): this is only required for the time being because we need to do the registrations
    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
    # import errors because of the models imported in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
        cls._register()
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
        cls._register()
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

    @classmethod
    def _register(cls):
        if cls._is_registered:
            return
        cls._is_registered = True
        _register_attention_processors_metadata()


class TransformerBlockRegistry:
    _registry = {}
    # TODO(aryan): this is only required for the time being because we need to do the registrations
    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
    # import errors because of the models imported in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
        cls._register()
        metadata._cls = model_class
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        cls._register()
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

    @classmethod
    def _register(cls):
        if cls._is_registered:
            return
        cls._is_registered = True
        _register_transformer_blocks_metadata()


def _register_attention_processors_metadata():
    from ..models.attention_processor import AttnProcessor2_0
    from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor

    # AttnProcessor2_0
    AttentionProcessorRegistry.register(
        model_class=AttnProcessor2_0,
        metadata=AttentionProcessorMetadata(
            skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
        ),
    )

    # CogView4AttnProcessor
    AttentionProcessorRegistry.register(
        model_class=CogView4AttnProcessor,
        metadata=AttentionProcessorMetadata(
            skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
        ),
    )


def _register_transformer_blocks_metadata():
    from ..models.attention import BasicTransformerBlock
    from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
    from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
    from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
    from ..models.transformers.transformer_hunyuan_video import (
        HunyuanVideoSingleTransformerBlock,
        HunyuanVideoTokenReplaceSingleTransformerBlock,
        HunyuanVideoTokenReplaceTransformerBlock,
        HunyuanVideoTransformerBlock,
    )
    from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
    from ..models.transformers.transformer_mochi import MochiTransformerBlock
    from ..models.transformers.transformer_wan import WanTransformerBlock

    # BasicTransformerBlock
    TransformerBlockRegistry.register(
        model_class=BasicTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=None,
        ),
    )

    # CogVideoX
    TransformerBlockRegistry.register(
        model_class=CogVideoXBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )

    # CogView4
    TransformerBlockRegistry.register(
        model_class=CogView4TransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )

    # Flux
    TransformerBlockRegistry.register(
        model_class=FluxTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=1,
            return_encoder_hidden_states_index=0,
        ),
    )
    TransformerBlockRegistry.register(
        model_class=FluxSingleTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=1,
            return_encoder_hidden_states_index=0,
        ),
    )

    # HunyuanVideo
    TransformerBlockRegistry.register(
        model_class=HunyuanVideoTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )
    TransformerBlockRegistry.register(
        model_class=HunyuanVideoSingleTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )
    TransformerBlockRegistry.register(
        model_class=HunyuanVideoTokenReplaceTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )
    TransformerBlockRegistry.register(
        model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )

    # LTXVideo
    TransformerBlockRegistry.register(
        model_class=LTXVideoTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=None,
        ),
    )

    # Mochi
    TransformerBlockRegistry.register(
        model_class=MochiTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )

    # Wan
    TransformerBlockRegistry.register(
        model_class=WanTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=None,
        ),
    )


# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    return hidden_states


def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    if encoder_hidden_states is None and len(args) > 1:
        encoder_hidden_states = args[1]
    return hidden_states, encoder_hidden_states


_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
# fmt: on

src/diffusers/hooks/first_block_cache.py (new file, 227 lines)
@@ -0,0 +1,227 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Tuple, Union

import torch

from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseState, HookRegistry, ModelHook, StateManager


logger = get_logger(__name__)  # pylint: disable=invalid-name

_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"


@dataclass
class FirstBlockCacheConfig:
    r"""
    Configuration for [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).

    Args:
        threshold (`float`, defaults to `0.05`):
            The threshold to determine whether or not a forward pass through all layers of the model is required. A
            higher threshold usually results in a forward pass through a lower number of layers and faster inference,
            but might lead to poorer generation quality. A lower threshold may not result in significant generation
            speedup. The threshold is compared against the absmean difference of the residuals between the current and
            cached outputs from the first transformer block. If the difference is below the threshold, the forward
            pass is skipped.
    """

    threshold: float = 0.05


class FBCSharedBlockState(BaseState):
    def __init__(self) -> None:
        super().__init__()

        self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        self.head_block_residual: torch.Tensor = None
        self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        self.should_compute: bool = True

    def reset(self):
        self.tail_block_residuals = None
        self.should_compute = True


class FBCHeadBlockHook(ModelHook):
    _is_stateful = True

    def __init__(self, state_manager: StateManager, threshold: float):
        self.state_manager = state_manager
        self.threshold = threshold
        self._metadata = None

    def initialize_hook(self, module):
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)

        output = self.fn_ref.original_forward(*args, **kwargs)
        is_output_tuple = isinstance(output, tuple)

        if is_output_tuple:
            hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
        else:
            hidden_states_residual = output - original_hidden_states

        shared_state: FBCSharedBlockState = self.state_manager.get_state()
        hidden_states = encoder_hidden_states = None
        should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
        shared_state.should_compute = should_compute

        if not should_compute:
            # Apply caching
            if is_output_tuple:
                hidden_states = (
                    shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
                )
            else:
                hidden_states = shared_state.tail_block_residuals[0] + output

            if self._metadata.return_encoder_hidden_states_index is not None:
                assert is_output_tuple
                encoder_hidden_states = (
                    shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
                )

            if is_output_tuple:
                return_output = [None] * len(output)
                return_output[self._metadata.return_hidden_states_index] = hidden_states
                return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
                return_output = tuple(return_output)
            else:
                return_output = hidden_states
            output = return_output
        else:
            if is_output_tuple:
                head_block_output = [None] * len(output)
                head_block_output[0] = output[self._metadata.return_hidden_states_index]
                head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
            else:
                head_block_output = output
            shared_state.head_block_output = head_block_output
            shared_state.head_block_residual = hidden_states_residual

        return output

    def reset_state(self, module):
        self.state_manager.reset()
        return module

    @torch.compiler.disable
    def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
        shared_state = self.state_manager.get_state()
        if shared_state.head_block_residual is None:
            return True
        prev_hidden_states_residual = shared_state.head_block_residual
        absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
        prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
        diff = (absmean / prev_hidden_states_absmean).item()
        return diff > self.threshold
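The check above is a relative-change test: the residual the first block produces at the current step is compared against the residual cached from the previous step, and the remaining blocks only run when the relative absmean change exceeds `threshold`. A standalone sketch of the same arithmetic with made-up numbers:

```python
import torch

threshold = 0.05
prev_residual = torch.tensor([1.0, -1.0, 1.0, -1.0])      # cached from the previous step
curr_residual = torch.tensor([1.02, -0.99, 1.01, -1.03])  # produced at the current step

absmean = (curr_residual - prev_residual).abs().mean()    # 0.0175
prev_absmean = prev_residual.abs().mean()                 # 1.0
diff = (absmean / prev_absmean).item()                    # 0.0175

should_compute = diff > threshold                         # False -> skip the remaining blocks
print(diff, should_compute)
```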
class FBCBlockHook(ModelHook):
    def __init__(self, state_manager: StateManager, is_tail: bool = False):
        super().__init__()
        self.state_manager = state_manager
        self.is_tail = is_tail
        self._metadata = None

    def initialize_hook(self, module):
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
        original_encoder_hidden_states = None
        if self._metadata.return_encoder_hidden_states_index is not None:
            original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                "encoder_hidden_states", args, kwargs
            )

        shared_state = self.state_manager.get_state()

        if shared_state.should_compute:
            output = self.fn_ref.original_forward(*args, **kwargs)
            if self.is_tail:
                hidden_states_residual = encoder_hidden_states_residual = None
                if isinstance(output, tuple):
                    hidden_states_residual = (
                        output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
                    )
                    encoder_hidden_states_residual = (
                        output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
                    )
                else:
                    hidden_states_residual = output - shared_state.head_block_output
                shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
            return output

        if original_encoder_hidden_states is None:
            return_output = original_hidden_states
        else:
            return_output = [None, None]
            return_output[self._metadata.return_hidden_states_index] = original_hidden_states
            return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
            return_output = tuple(return_output)
        return return_output


def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
    state_manager = StateManager(FBCSharedBlockState, (), {})
    remaining_blocks = []

    for name, submodule in module.named_children():
        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
            continue
        for index, block in enumerate(submodule):
            remaining_blocks.append((f"{name}.{index}", block))

    head_block_name, head_block = remaining_blocks.pop(0)
    tail_block_name, tail_block = remaining_blocks.pop(-1)

    logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
    _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)

    for name, block in remaining_blocks:
        logger.debug(f"Applying FBCBlockHook to '{name}'")
        _apply_fbc_block_hook(block, state_manager)

    logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
    _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)


def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = FBCHeadBlockHook(state_manager, threshold)
    registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)


def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = FBCBlockHook(state_manager, is_tail)
    registry.register_hook(hook, _FBC_BLOCK_HOOK)
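`apply_first_block_cache` can also be called directly on a denoiser whose block `ModuleList` uses one of the recognized names (`transformer_blocks`, `blocks`, `single_transformer_blocks`, `layers`, `temporal_transformer_blocks`). A sketch; the checkpoint id is an assumption, and because the hooks are stateful, forward passes have to run inside a cache context:

```python
import torch
from diffusers import LTXVideoTransformer3DModel
from diffusers.hooks import FirstBlockCacheConfig, apply_first_block_cache

# Assumed checkpoint layout (transformer subfolder); adjust to your model.
transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
)
apply_first_block_cache(transformer, FirstBlockCacheConfig(threshold=0.05))

# FBC state is kept per context name; pipelines that call `cache_context`
# internally handle this automatically.
with transformer.cache_context("cond"):
    ...  # transformer(...) calls for the conditional batch go here
```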
@@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple
import torch

from ..utils.logging import get_logger
from ..utils.torch_utils import unwrap_module


logger = get_logger(__name__)  # pylint: disable=invalid-name


class BaseState:
    def reset(self, *args, **kwargs) -> None:
        raise NotImplementedError(
            "BaseState::reset is not implemented. Please implement this method in the derived class."
        )


class StateManager:
    def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
        self._state_cls = state_cls
        self._init_args = init_args if init_args is not None else ()
        self._init_kwargs = init_kwargs if init_kwargs is not None else {}
        self._state_cache = {}
        self._current_context = None

    def get_state(self):
        if self._current_context is None:
            raise ValueError("No context is set. Please set a context before retrieving the state.")
        if self._current_context not in self._state_cache.keys():
            self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
        return self._state_cache[self._current_context]

    def set_context(self, name: str) -> None:
        self._current_context = name

    def reset(self, *args, **kwargs) -> None:
        for name, state in list(self._state_cache.items()):
            state.reset(*args, **kwargs)
            self._state_cache.pop(name)
        self._current_context = None
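`StateManager` is what lets one set of hooks serve several batches of data: each context name lazily gets its own state object, and `reset` clears all of them. A toy sketch (the `CounterState` class below is made up for illustration):

```python
from diffusers.hooks.hooks import BaseState, StateManager


class CounterState(BaseState):  # made-up state class
    def __init__(self):
        self.calls = 0

    def reset(self):
        self.calls = 0


manager = StateManager(CounterState, (), {})
manager.set_context("cond")
manager.get_state().calls += 1

manager.set_context("uncond")
assert manager.get_state().calls == 0  # a fresh state object per context

manager.set_context("cond")
assert manager.get_state().calls == 1  # the "cond" state was preserved

manager.reset()  # resets and discards every per-context state
```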
class ModelHook:
    r"""
    A hook that contains callbacks to be executed just before and after the forward method of a model.

@@ -99,6 +132,14 @@ class ModelHook:
            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
        return module

    def _set_context(self, module: torch.nn.Module, name: str) -> None:
        # Iterate over all attributes of the hook to see if any of them have the type `StateManager`.
        # If so, call `set_context` on them.
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if isinstance(attr, StateManager):
                attr.set_context(name)
        return module


class HookFunctionReference:
    def __init__(self) -> None:

@@ -211,9 +252,10 @@ class HookRegistry:
                hook.reset_state(self._module_ref)

        if recurse:
            for module_name, module in self._module_ref.named_modules():
            for module_name, module in unwrap_module(self._module_ref).named_modules():
                if module_name == "":
                    continue
                module = unwrap_module(module)
                if hasattr(module, "_diffusers_hook"):
                    module._diffusers_hook.reset_stateful_hooks(recurse=False)

@@ -223,6 +265,19 @@ class HookRegistry:
        module._diffusers_hook = cls(module)
        return module._diffusers_hook

    def _set_context(self, name: Optional[str] = None) -> None:
        for hook_name in reversed(self._hook_order):
            hook = self.hooks[hook_name]
            if hook._is_stateful:
                hook._set_context(self._module_ref, name)

        for module_name, module in unwrap_module(self._module_ref).named_modules():
            if module_name == "":
                continue
            module = unwrap_module(module)
            if hasattr(module, "_diffusers_hook"):
                module._diffusers_hook._set_context(name)

    def __repr__(self) -> str:
        registry_repr = ""
        for i, hook_name in enumerate(self._hook_order):
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager

from ..utils.logging import get_logger

@@ -25,6 +27,7 @@ class CacheMixin:
    Supported caching techniques:
        - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
        - [FasterCache](https://huggingface.co/papers/2410.19355)
        - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
    """

    _cache_config = None

@@ -62,8 +65,10 @@ class CacheMixin:

        from ..hooks import (
            FasterCacheConfig,
            FirstBlockCacheConfig,
            PyramidAttentionBroadcastConfig,
            apply_faster_cache,
            apply_first_block_cache,
            apply_pyramid_attention_broadcast,
        )

@@ -72,31 +77,36 @@ class CacheMixin:
                f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
            )

        if isinstance(config, PyramidAttentionBroadcastConfig):
            apply_pyramid_attention_broadcast(self, config)
        elif isinstance(config, FasterCacheConfig):
        if isinstance(config, FasterCacheConfig):
            apply_faster_cache(self, config)
        elif isinstance(config, FirstBlockCacheConfig):
            apply_first_block_cache(self, config)
        elif isinstance(config, PyramidAttentionBroadcastConfig):
            apply_pyramid_attention_broadcast(self, config)
        else:
            raise ValueError(f"Cache config {type(config)} is not supported.")

        self._cache_config = config

    def disable_cache(self) -> None:
        from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
        from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
        from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
        from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
        from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK

        if self._cache_config is None:
            logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
            return

        if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
            registry = HookRegistry.check_if_exists_or_initialize(self)
            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
        elif isinstance(self._cache_config, FasterCacheConfig):
            registry = HookRegistry.check_if_exists_or_initialize(self)
        registry = HookRegistry.check_if_exists_or_initialize(self)
        if isinstance(self._cache_config, FasterCacheConfig):
            registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
            registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
        elif isinstance(self._cache_config, FirstBlockCacheConfig):
            registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
            registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
        elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
        else:
            raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")

@@ -106,3 +116,15 @@ class CacheMixin:
        from ..hooks import HookRegistry

        HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)

    @contextmanager
    def cache_context(self, name: str):
        r"""Context manager that provides additional methods for cache management."""
        from ..hooks import HookRegistry

        registry = HookRegistry.check_if_exists_or_initialize(self)
        registry._set_context(name)

        yield

        registry._set_context(None)
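The pipeline changes later in this diff wrap each denoiser call in `cache_context("cond")` / `cache_context("uncond")` (or a shared `"cond_uncond"` for batched CFG) so that each batch keeps its own FBC state. A minimal sketch with a toy module (the `TinyDenoiser` class is made up; any `CacheMixin` model behaves the same way):

```python
import torch
from diffusers.models.cache_utils import CacheMixin


class TinyDenoiser(torch.nn.Module, CacheMixin):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)


model = TinyDenoiser()
# On entry, every stateful hook in the module tree sees set_context("cond");
# on exit the context is cleared again with set_context(None).
with model.cache_context("cond"):
    _ = model(torch.randn(1, 4))
```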
@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
            )
            block_samples = block_samples + (hidden_states,)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        single_block_samples = ()
        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )

            else:
                hidden_states = block(
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
            single_block_samples = single_block_samples + (hidden_states,)

        # controlnet block
        controlnet_block_samples = ()
@@ -21,6 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin

@@ -453,6 +454,7 @@ class CogView4TrainingAttnProcessor:
        return hidden_states, encoder_hidden_states


@maybe_allow_in_graph
class CogView4TransformerBlock(nn.Module):
    def __init__(
        self,
@@ -79,10 +79,14 @@ class FluxSingleTransformerBlock(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        text_seq_len = encoder_hidden_states.shape[1]
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

@@ -100,7 +104,8 @@ class FluxSingleTransformerBlock(nn.Module):
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return hidden_states
        encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
        return encoder_hidden_states, hidden_states


@maybe_allow_in_graph

@@ -507,20 +512,21 @@ class FluxTransformer2DModel(
                )
            else:
                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )

            else:
                hidden_states = block(
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,

@@ -530,12 +536,7 @@ class FluxTransformer2DModel(
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                    + controlnet_single_block_samples[index_block // interval_control]
                )

        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)
@@ -22,6 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin

@@ -249,6 +250,7 @@ class WanRotaryPosEmbed(nn.Module):
        return freqs_cos, freqs_sin


@maybe_allow_in_graph
class WanTransformerBlock(nn.Module):
    def __init__(
        self,
@@ -718,14 +718,15 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        timestep = t.expand(latent_model_input.shape[0])

        # predict noise model_output
        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            encoder_hidden_states=prompt_embeds,
            timestep=timestep,
            image_rotary_emb=image_rotary_emb,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]
        with self.transformer.cache_context("cond_uncond"):
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                image_rotary_emb=image_rotary_emb,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
        noise_pred = noise_pred.float()

        # perform guidance

@@ -784,14 +784,15 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        timestep = t.expand(latent_model_input.shape[0])

        # predict noise model_output
        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            encoder_hidden_states=prompt_embeds,
            timestep=timestep,
            image_rotary_emb=image_rotary_emb,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]
        with self.transformer.cache_context("cond_uncond"):
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                image_rotary_emb=image_rotary_emb,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
        noise_pred = noise_pred.float()

        # perform guidance

@@ -831,15 +831,16 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
        timestep = t.expand(latent_model_input.shape[0])

        # predict noise model_output
        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            encoder_hidden_states=prompt_embeds,
            timestep=timestep,
            ofs=ofs_emb,
            image_rotary_emb=image_rotary_emb,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]
        with self.transformer.cache_context("cond_uncond"):
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                ofs=ofs_emb,
                image_rotary_emb=image_rotary_emb,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
        noise_pred = noise_pred.float()

        # perform guidance

@@ -799,14 +799,15 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
        timestep = t.expand(latent_model_input.shape[0])

        # predict noise model_output
        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            encoder_hidden_states=prompt_embeds,
            timestep=timestep,
            image_rotary_emb=image_rotary_emb,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]
        with self.transformer.cache_context("cond_uncond"):
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                image_rotary_emb=image_rotary_emb,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
        noise_pred = noise_pred.float()

        # perform guidance
@@ -619,22 +619,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latents.shape[0])

        noise_pred_cond = self.transformer(
            hidden_states=latent_model_input,
            encoder_hidden_states=prompt_embeds,
            timestep=timestep,
            original_size=original_size,
            target_size=target_size,
            crop_coords=crops_coords_top_left,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]

        # perform guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond = self.transformer(
        with self.transformer.cache_context("cond"):
            noise_pred_cond = self.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=negative_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                original_size=original_size,
                target_size=target_size,

@@ -643,6 +631,19 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                return_dict=False,
            )[0]

        # perform guidance
        if self.do_classifier_free_guidance:
            with self.transformer.cache_context("uncond"):
                noise_pred_uncond = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=negative_prompt_embeds,
                    timestep=timestep,
                    original_size=original_size,
                    target_size=target_size,
                    crop_coords=crops_coords_top_left,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
        else:
            noise_pred = noise_pred_cond
@@ -912,32 +912,35 @@ class FluxPipeline(
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        noise_pred = self.transformer(
            hidden_states=latents,
            timestep=timestep / 1000,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids,
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
        )[0]

        if do_true_cfg:
            if negative_image_embeds is not None:
                self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
            neg_noise_pred = self.transformer(
        with self.transformer.cache_context("cond"):
            noise_pred = self.transformer(
                hidden_states=latents,
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=negative_pooled_prompt_embeds,
                encoder_hidden_states=negative_prompt_embeds,
                txt_ids=negative_text_ids,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

        if do_true_cfg:
            if negative_image_embeds is not None:
                self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds

            with self.transformer.cache_context("uncond"):
                neg_noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=negative_pooled_prompt_embeds,
                    encoder_hidden_states=negative_prompt_embeds,
                    txt_ids=negative_text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
            noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

        # compute the previous noisy sample x_t -> x_t-1
@@ -693,28 +693,30 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_attention_mask,
            pooled_projections=pooled_prompt_embeds,
            guidance=guidance,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]

        if do_true_cfg:
            neg_noise_pred = self.transformer(
        with self.transformer.cache_context("cond"):
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=negative_prompt_embeds,
                encoder_attention_mask=negative_prompt_attention_mask,
                pooled_projections=negative_pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                pooled_projections=pooled_prompt_embeds,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]

        if do_true_cfg:
            with self.transformer.cache_context("uncond"):
                neg_noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=negative_prompt_embeds,
                    encoder_attention_mask=negative_prompt_attention_mask,
                    pooled_projections=negative_pooled_prompt_embeds,
                    guidance=guidance,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
            noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

        # compute the previous noisy sample x_t -> x_t-1
@@ -757,18 +757,19 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latent_model_input.shape[0])

        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            encoder_hidden_states=prompt_embeds,
            timestep=timestep,
            encoder_attention_mask=prompt_attention_mask,
            num_frames=latent_num_frames,
            height=latent_height,
            width=latent_width,
            rope_interpolation_scale=rope_interpolation_scale,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]
        with self.transformer.cache_context("cond_uncond"):
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                encoder_attention_mask=prompt_attention_mask,
                num_frames=latent_num_frames,
                height=latent_height,
                width=latent_width,
                rope_interpolation_scale=rope_interpolation_scale,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
        noise_pred = noise_pred.float()

        if self.do_classifier_free_guidance:
@@ -1177,15 +1177,16 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
        if is_conditioning_image_or_video:
            timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)

        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            encoder_hidden_states=prompt_embeds,
            timestep=timestep,
            encoder_attention_mask=prompt_attention_mask,
            video_coords=video_coords,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]
        with self.transformer.cache_context("cond_uncond"):
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                encoder_attention_mask=prompt_attention_mask,
                video_coords=video_coords,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]

        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
@@ -830,18 +830,19 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
        timestep = t.expand(latent_model_input.shape[0])
        timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)

        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            encoder_hidden_states=prompt_embeds,
            timestep=timestep,
            encoder_attention_mask=prompt_attention_mask,
            num_frames=latent_num_frames,
            height=latent_height,
            width=latent_width,
            rope_interpolation_scale=rope_interpolation_scale,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]
        with self.transformer.cache_context("cond_uncond"):
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                encoder_attention_mask=prompt_attention_mask,
                num_frames=latent_num_frames,
                height=latent_height,
                width=latent_width,
                rope_interpolation_scale=rope_interpolation_scale,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
        noise_pred = noise_pred.float()

        if self.do_classifier_free_guidance:
@@ -671,14 +671,15 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            encoder_hidden_states=prompt_embeds,
            timestep=timestep,
            encoder_attention_mask=prompt_attention_mask,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]
        with self.transformer.cache_context("cond_uncond"):
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                encoder_attention_mask=prompt_attention_mask,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
        # Mochi CFG + Sampling runs in FP32
        noise_pred = noise_pred.to(torch.float32)
@@ -533,22 +533,24 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        latent_model_input = latents.to(transformer_dtype)
        timestep = t.expand(latents.shape[0])

        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=prompt_embeds,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]

        if self.do_classifier_free_guidance:
            noise_uncond = self.transformer(
        with self.transformer.cache_context("cond"):
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=negative_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]

        if self.do_classifier_free_guidance:
            with self.transformer.cache_context("uncond"):
                noise_uncond = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=negative_prompt_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
            noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

        # compute the previous noisy sample x_t -> x_t-1
@@ -17,6 +17,21 @@ class FasterCacheConfig(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class FirstBlockCacheConfig(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class HookRegistry(metaclass=DummyObject):
    _backends = ["torch"]

@@ -51,6 +66,10 @@ def apply_faster_cache(*args, **kwargs):
    requires_backends(apply_faster_cache, ["torch"])


def apply_first_block_cache(*args, **kwargs):
    requires_backends(apply_first_block_cache, ["torch"])


def apply_pyramid_attention_broadcast(*args, **kwargs):
    requires_backends(apply_pyramid_attention_broadcast, ["torch"])
@@ -92,6 +92,11 @@ def is_compiled_module(module) -> bool:
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)


def unwrap_module(module):
    """Unwraps a module if it was compiled with torch.compile()"""
    return module._orig_mod if is_compiled_module(module) else module
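`unwrap_module` is what keeps the hooks and registries working when the denoiser has been wrapped by `torch.compile`: lookups are keyed on the original block class, not on the `OptimizedModule` wrapper. A small sketch:

```python
import torch
from diffusers.utils.torch_utils import unwrap_module

layer = torch.nn.Linear(4, 4)
compiled = torch.compile(layer)  # returns an OptimizedModule wrapper

assert unwrap_module(compiled) is layer  # the wrapper's _orig_mod is returned
assert unwrap_module(layer) is layer     # non-compiled modules pass through unchanged
```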
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
    """Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).

@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,

@@ -45,7 +46,11 @@ enable_full_determinism()


class CogVideoXPipelineFastTests(
    PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    unittest.TestCase,
):
    pipeline_class = CogVideoXPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}

@@ -24,6 +24,7 @@ from diffusers.utils.testing_utils import (

from ..test_pipelines_common import (
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    FluxIPAdapterTesterMixin,
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,

@@ -33,11 +34,12 @@ from ..test_pipelines_common import (


class FluxPipelineFastTests(
    unittest.TestCase,
    PipelineTesterMixin,
    FluxIPAdapterTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    unittest.TestCase,
):
    pipeline_class = FluxPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])

@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (

from ..test_pipelines_common import (
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    to_np,

@@ -43,7 +44,11 @@ enable_full_determinism()


class HunyuanVideoPipelineFastTests(
    PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    unittest.TestCase,
):
    pipeline_class = HunyuanVideoPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])

@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np


enable_full_determinism()


class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
    pipeline_class = LTXPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
    def get_dummy_components(self, num_layers: int = 1):
        torch.manual_seed(0)
        transformer = LTXVideoTransformer3DModel(
            in_channels=8,

@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            num_attention_heads=4,
            attention_head_dim=8,
            cross_attention_dim=32,
            num_layers=1,
            num_layers=num_layers,
            caption_channels=32,
        )

@@ -32,13 +32,15 @@ from diffusers.utils.testing_utils import (
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np


enable_full_determinism()


class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
class MochiPipelineFastTests(
    PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
):
    pipeline_class = MochiPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

@@ -33,6 +33,7 @@ from diffusers import (
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin

@@ -2648,7 +2649,7 @@ class FasterCacheTesterMixin:
        self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
        pipe = create_pipe()
        pipe.transformer.enable_cache(self.faster_cache_config)
        output = run_forward(pipe).flatten().flatten()
        output = run_forward(pipe).flatten()
        image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))

        # Run inference with FasterCache disabled

@@ -2755,6 +2756,55 @@ class FasterCacheTesterMixin:
        self.assertTrue(state.cache is None, "Cache should be reset to None.")


# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
# of the box once there is better cache support/implementation
class FirstBlockCacheTesterMixin:
    # threshold is intentionally set higher than usual values since we're testing with random unconverged models
    # that will not satisfy the expected properties of the denoiser for caching to be effective
    first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)

    def test_first_block_cache_inference(self, expected_atol: float = 0.1):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        def create_pipe():
            torch.manual_seed(0)
            num_layers = 2
            components = self.get_dummy_components(num_layers=num_layers)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(device)
            pipe.set_progress_bar_config(disable=None)
            return pipe

        def run_forward(pipe):
            torch.manual_seed(0)
            inputs = self.get_dummy_inputs(device)
            inputs["num_inference_steps"] = 4
            return pipe(**inputs)[0]

        # Run inference without FirstBlockCache
        pipe = create_pipe()
        output = run_forward(pipe).flatten()
        original_image_slice = np.concatenate((output[:8], output[-8:]))

        # Run inference with FirstBlockCache enabled
        pipe = create_pipe()
        pipe.transformer.enable_cache(self.first_block_cache_config)
        output = run_forward(pipe).flatten()
        image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))

        # Run inference with FirstBlockCache disabled
        pipe.transformer.disable_cache()
        output = run_forward(pipe).flatten()
        image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))

        assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
            "FirstBlockCache outputs should not differ much."
        )
        assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
            "Outputs from normal inference and after disabling cache should not differ."
        )


# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.