diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
deleted file mode 100644
index 9043ffc418..0000000000
--- a/src/diffusers/hooks/_helpers.py
+++ /dev/null
@@ -1,271 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Any, Callable, Type
-
-from ..models.attention import BasicTransformerBlock
-from ..models.attention_processor import AttnProcessor2_0
-from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
-from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock
-from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
-from ..models.transformers.transformer_hunyuan_video import (
-    HunyuanVideoSingleTransformerBlock,
-    HunyuanVideoTokenReplaceSingleTransformerBlock,
-    HunyuanVideoTokenReplaceTransformerBlock,
-    HunyuanVideoTransformerBlock,
-)
-from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
-from ..models.transformers.transformer_mochi import MochiTransformerBlock
-from ..models.transformers.transformer_wan import WanTransformerBlock
-
-
-@dataclass
-class AttentionProcessorMetadata:
-    skip_processor_output_fn: Callable[[Any], Any]
-
-
-@dataclass
-class TransformerBlockMetadata:
-    skip_block_output_fn: Callable[[Any], Any]
-    return_hidden_states_index: int = None
-    return_encoder_hidden_states_index: int = None
-
-
-class AttentionProcessorRegistry:
-    _registry = {}
-
-    @classmethod
-    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
-        cls._registry[model_class] = metadata
-
-    @classmethod
-    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
-        if model_class not in cls._registry:
-            raise ValueError(f"Model class {model_class} not registered.")
-        return cls._registry[model_class]
-
-
-class TransformerBlockRegistry:
-    _registry = {}
-
-    @classmethod
-    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
-        cls._registry[model_class] = metadata
-
-    @classmethod
-    def get(cls, model_class: Type) -> TransformerBlockMetadata:
-        if model_class not in cls._registry:
-            raise ValueError(f"Model class {model_class} not registered.")
-        return cls._registry[model_class]
-
-
-def _register_attention_processors_metadata():
-    # AttnProcessor2_0
-    AttentionProcessorRegistry.register(
-        model_class=AttnProcessor2_0,
-        metadata=AttentionProcessorMetadata(
-            skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
-        ),
-    )
-
-    # CogView4AttnProcessor
-    AttentionProcessorRegistry.register(
-        model_class=CogView4AttnProcessor,
-        metadata=AttentionProcessorMetadata(
-            skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
-        ),
-    )
-
-
-def _register_transformer_blocks_metadata():
-    # BasicTransformerBlock
-    TransformerBlockRegistry.register(
-        model_class=BasicTransformerBlock,
-        metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock,
-            return_hidden_states_index=0,
-            return_encoder_hidden_states_index=None,
-        ),
-    )
-
-    # CogVideoX
-    TransformerBlockRegistry.register(
-        model_class=CogVideoXBlock,
-        metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
-            return_hidden_states_index=0,
-            return_encoder_hidden_states_index=1,
-        ),
-    )
-
-    # CogView4
-    TransformerBlockRegistry.register(
-        model_class=CogView4TransformerBlock,
-        metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
-            return_hidden_states_index=0,
-            return_encoder_hidden_states_index=1,
-        ),
-    )
-
-    # Flux
-    TransformerBlockRegistry.register(
-        model_class=FluxTransformerBlock,
-        metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
-            return_hidden_states_index=1,
-            return_encoder_hidden_states_index=0,
-        ),
-    )
-    TransformerBlockRegistry.register(
-        model_class=FluxSingleTransformerBlock,
-        metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
-            return_hidden_states_index=1,
-            return_encoder_hidden_states_index=0,
-        ),
-    )
-
-    # HunyuanVideo
-    TransformerBlockRegistry.register(
-        model_class=HunyuanVideoTransformerBlock,
-        metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
-            return_hidden_states_index=0,
-            return_encoder_hidden_states_index=1,
-        ),
-    )
-    TransformerBlockRegistry.register(
-        model_class=HunyuanVideoSingleTransformerBlock,
-        metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
-            return_hidden_states_index=0,
-            return_encoder_hidden_states_index=1,
-        ),
-    )
-    TransformerBlockRegistry.register(
-        model_class=HunyuanVideoTokenReplaceTransformerBlock,
-        metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
-            return_hidden_states_index=0,
-            return_encoder_hidden_states_index=1,
-        ),
-    )
-    TransformerBlockRegistry.register(
-        model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
-        metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
-            return_hidden_states_index=0,
-            return_encoder_hidden_states_index=1,
-        ),
-    )
-
-    # LTXVideo
-    TransformerBlockRegistry.register(
-        model_class=LTXVideoTransformerBlock,
-        metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
-            return_hidden_states_index=0,
-            return_encoder_hidden_states_index=None,
-        ),
-    )
-
-    # Mochi
-    TransformerBlockRegistry.register(
-        model_class=MochiTransformerBlock,
-        metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
-            return_hidden_states_index=0,
-            return_encoder_hidden_states_index=1,
-        ),
-    )
-
-    # Wan
-    TransformerBlockRegistry.register(
-        model_class=WanTransformerBlock,
-        metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
-            return_hidden_states_index=0,
-            return_encoder_hidden_states_index=None,
-        ),
-    )
-
-
-# fmt: off
-def _skip_attention___ret___hidden_states(self, *args, **kwargs):
-    hidden_states = kwargs.get("hidden_states", None)
-    if hidden_states is None and len(args) > 0:
-        hidden_states = args[0]
-    return hidden_states
-
-
-def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
-    hidden_states = kwargs.get("hidden_states", None)
-    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
-    if hidden_states is None and len(args) > 0:
-        hidden_states = args[0]
-    if encoder_hidden_states is None and len(args) > 1:
-        encoder_hidden_states = args[1]
-    return hidden_states, encoder_hidden_states
-
-
-_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
-_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
-
-
-def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
-    hidden_states = kwargs.get("hidden_states", None)
-    if hidden_states is None and len(args) > 0:
-        hidden_states = args[0]
-    return hidden_states
-
-
-def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
-    hidden_states = kwargs.get("hidden_states", None)
-    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
-    if hidden_states is None and len(args) > 0:
-        hidden_states = args[0]
-    if encoder_hidden_states is None and len(args) > 1:
-        encoder_hidden_states = args[1]
-    return hidden_states, encoder_hidden_states
-
-
-def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
-    hidden_states = kwargs.get("hidden_states", None)
-    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
-    if hidden_states is None and len(args) > 0:
-        hidden_states = args[0]
-    if encoder_hidden_states is None and len(args) > 1:
-        encoder_hidden_states = args[1]
-    return encoder_hidden_states, hidden_states
-
-
-_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
-_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
-_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
-_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
-_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
-# fmt: on
-
-
-_register_attention_processors_metadata()
-_register_transformer_blocks_metadata()
diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py
index 31ee08c34d..a7a415ca51 100644
--- a/src/diffusers/hooks/first_block_cache.py
+++ b/src/diffusers/hooks/first_block_cache.py
@@ -17,10 +17,10 @@ from typing import Tuple, Union
 
 import torch
 
+from ..models.metadata import TransformerBlockRegistry
 from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
 from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
-from ._helpers import TransformerBlockRegistry
 from .hooks import BaseState, HookRegistry, ModelHook, StateManager
 
 
@@ -76,12 +76,7 @@ class FBCHeadBlockHook(ModelHook):
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
-        outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
-
-        if isinstance(outputs_if_skipped, tuple):
-            original_hidden_states = outputs_if_skipped[self._metadata.return_hidden_states_index]
-        else:
-            original_hidden_states = outputs_if_skipped
+        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
 
         output = self.fn_ref.original_forward(*args, **kwargs)
         is_output_tuple = isinstance(output, tuple)
@@ -92,7 +87,7 @@
             hidden_states_residual = output - original_hidden_states
 
         shared_state: FBCSharedBlockState = self.state_manager.get_state()
-        hidden_states, encoder_hidden_states = None, None
+        hidden_states = encoder_hidden_states = None
         should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
 
         shared_state.should_compute = should_compute
@@ -159,13 +154,12 @@ class FBCBlockHook(ModelHook):
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
-        outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
-        if not isinstance(outputs_if_skipped, tuple):
-            outputs_if_skipped = (outputs_if_skipped,)
-        original_hidden_states = outputs_if_skipped[self._metadata.return_hidden_states_index]
+        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
         original_encoder_hidden_states = None
         if self._metadata.return_encoder_hidden_states_index is not None:
-            original_encoder_hidden_states = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index]
+            original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
+                "encoder_hidden_states", args, kwargs
+            )
 
         shared_state = self.state_manager.get_state()
 
@@ -185,13 +179,13 @@
                 shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
             return output
 
-        output_count = len(outputs_if_skipped)
-        if output_count == 1:
+        if original_encoder_hidden_states is None:
             return_output = original_hidden_states
         else:
-            return_output = [None] * output_count
+            return_output = [None, None]
             return_output[self._metadata.return_hidden_states_index] = original_hidden_states
             return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
+            return_output = tuple(return_output)
         return return_output
 
 
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 93b11c2b43..dfcfc27304 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -22,6 +22,7 @@ from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
+from .metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
 
 
@@ -258,6 +259,12 @@ class JointTransformerBlock(nn.Module):
 
 
 @maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=None,
+    )
+)
 class BasicTransformerBlock(nn.Module):
     r"""
     A basic Transformer block.
diff --git a/src/diffusers/models/metadata.py b/src/diffusers/models/metadata.py
new file mode 100644
index 0000000000..9b13e52fc0
--- /dev/null
+++ b/src/diffusers/models/metadata.py
@@ -0,0 +1,63 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Dict, Type
+
+
+@dataclass
+class TransformerBlockMetadata:
+    return_hidden_states_index: int = None
+    return_encoder_hidden_states_index: int = None
+
+    _cls: Type = None
+    _cached_parameter_indices: Dict[str, int] = None
+
+    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
+        kwargs = kwargs or {}
+        if identifier in kwargs:
+            return kwargs[identifier]
+        if self._cached_parameter_indices is not None:
+            return args[self._cached_parameter_indices[identifier]]
+        if self._cls is None:
+            raise ValueError("Model class is not set for metadata.")
+        parameters = list(inspect.signature(self._cls.forward).parameters.keys())
+        parameters = parameters[1:]  # skip `self`
+        self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
+        if identifier not in self._cached_parameter_indices:
+            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
+        index = self._cached_parameter_indices[identifier]
+        if index >= len(args):
+            raise ValueError(f"Expected {index} arguments but got {len(args)}.")
+        return args[index]
+
+
+class TransformerBlockRegistry:
+    _registry = {}
+
+    @classmethod
+    def register(cls, metadata: TransformerBlockMetadata):
+        def inner(model_class: Type):
+            metadata._cls = model_class
+            cls._registry[model_class] = metadata
+            return model_class
+
+        return inner
+
+    @classmethod
+    def get(cls, model_class: Type) -> TransformerBlockMetadata:
+        if model_class not in cls._registry:
+            raise ValueError(f"Model class {model_class} not registered.")
+        return cls._registry[model_class]
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index 6b4f38dc04..d3e596b1af 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -26,6 +26,7 @@ from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -35,6 +36,12 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 @maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class CogVideoXBlock(nn.Module):
     r"""
     Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index aef368f91a..c3d40b8749 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -21,10 +21,12 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous
@@ -453,6 +455,13 @@ class CogView4TrainingAttnProcessor:
         return hidden_states, encoder_hidden_states
 
 
+@maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class CogView4TransformerBlock(nn.Module):
     def __init__(
         self,
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index c9abe06b42..f66d5f982b 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -34,6 +34,7 @@ from ..attention_processor import (
 )
 from ..cache_utils import CacheMixin
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -43,6 +44,12 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 @maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=1,
+        return_encoder_hidden_states_index=0,
+    )
+)
 class FluxSingleTransformerBlock(nn.Module):
     def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
         super().__init__()
@@ -109,6 +116,12 @@ class FluxSingleTransformerBlock(nn.Module):
 
 
 @maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=1,
+        return_encoder_hidden_states_index=0,
+    )
+)
 class FluxTransformerBlock(nn.Module):
     def __init__(
         self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index d9100b2f54..6e5b107f9a 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -33,6 +33,7 @@ from ..embeddings import (
     Timesteps,
     get_1d_rotary_pos_embed,
 )
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
@@ -310,6 +311,12 @@ class HunyuanVideoConditionEmbedding(nn.Module):
         return conditioning, token_replace_emb
 
 
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=None,
+    )
+)
 class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
     def __init__(
         self,
@@ -489,6 +496,12 @@ class HunyuanVideoRotaryPosEmbed(nn.Module):
         return freqs_cos, freqs_sin
 
 
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class HunyuanVideoSingleTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -565,6 +578,12 @@
         return hidden_states, encoder_hidden_states
 
 
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class HunyuanVideoTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -644,6 +663,12 @@
         return hidden_states, encoder_hidden_states
 
 
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -724,6 +749,12 @@
         return hidden_states, encoder_hidden_states
 
 
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
     def __init__(
         self,
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index 2ae2418098..8a8409f1bf 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -28,6 +28,7 @@ from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
@@ -196,6 +197,12 @@ class LTXVideoRotaryPosEmbed(nn.Module):
 
 
 @maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=None,
+    )
+)
 class LTXVideoTransformerBlock(nn.Module):
     r"""
     Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py
index e6532f080d..2148142601 100644
--- a/src/diffusers/models/transformers/transformer_mochi.py
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -27,6 +27,7 @@ from ..attention import FeedForward
 from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, RMSNorm
@@ -116,6 +117,12 @@ class MochiRMSNormZero(nn.Module):
 
 
 @maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class MochiTransformerBlock(nn.Module):
     r"""
     Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index c78d72dc4a..ec607c5126 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -22,10 +22,12 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
@@ -219,6 +221,13 @@
         return freqs
 
 
+@maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=None,
+    )
+)
 class WanTransformerBlock(nn.Module):
     def __init__(
         self,
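
Usage sketch for the new registry. The refactor above replaces the centralized tables in `_helpers.py` with registration at each block's definition site, and hooks now recover named `forward()` arguments through `TransformerBlockMetadata._get_parameter_from_args_kwargs` instead of calling a per-class `skip_block_output_fn`. The sketch below assumes a diffusers checkout that includes this change (module path taken from the new file, `src/diffusers/models/metadata.py`); `MyBlock` and the tensor shapes are hypothetical, only the registry API comes from the diff:

    import torch

    from diffusers.models.metadata import TransformerBlockMetadata, TransformerBlockRegistry


    # Hypothetical block; in the diff, registration happens on the real block classes.
    @TransformerBlockRegistry.register(
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        )
    )
    class MyBlock(torch.nn.Module):
        def forward(self, hidden_states, encoder_hidden_states=None, temb=None):
            return hidden_states, encoder_hidden_states


    metadata = TransformerBlockRegistry.get(MyBlock)

    # A hook can pull a named forward() argument out of (args, kwargs) without
    # running the block, whether it was passed positionally or by keyword; the
    # name-to-index mapping is computed once from the forward signature and
    # cached on the metadata object.
    args = (torch.randn(1, 4, 8),)
    kwargs = {"encoder_hidden_states": torch.randn(1, 2, 8)}
    hidden_states = metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
    encoder_hidden_states = metadata._get_parameter_from_args_kwargs("encoder_hidden_states", args, kwargs)

This is also why `FBCBlockHook` can rebuild a skipped block's output as a fixed two-element tuple ordered by `return_hidden_states_index`/`return_encoder_hidden_states_index`, rather than probing the removed `skip_block_output_fn` helpers for the output shape.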