From b8317da20f023fe6022d34e8ca8bb91264596ea5 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 16 May 2025 12:53:36 +0200
Subject: [PATCH] remove central registry based on review

---
 src/diffusers/hooks/first_block_cache.py      | 17 +++++++++++---
 src/diffusers/models/attention.py             |  4 ++--
 src/diffusers/models/metadata.py              | 22 +++++--------------
 .../transformers/cogvideox_transformer_3d.py  |  4 ++--
 .../transformers/transformer_cogview4.py      |  4 ++--
 .../models/transformers/transformer_flux.py   |  6 ++---
 .../transformers/transformer_hunyuan_video.py | 12 +++++-----
 .../models/transformers/transformer_ltx.py    |  4 ++--
 .../models/transformers/transformer_mochi.py  |  4 ++--
 .../models/transformers/transformer_wan.py    |  4 ++--
 10 files changed, 41 insertions(+), 40 deletions(-)

diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py
index a7a415ca51..e2e27048cc 100644
--- a/src/diffusers/hooks/first_block_cache.py
+++ b/src/diffusers/hooks/first_block_cache.py
@@ -17,7 +17,6 @@ from typing import Tuple, Union
 
 import torch
 
-from ..models.metadata import TransformerBlockRegistry
 from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
 from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
@@ -72,7 +71,13 @@ class FBCHeadBlockHook(ModelHook):
         self._metadata = None
 
     def initialize_hook(self, module):
-        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
+        unwrapped_module = unwrap_module(module)
+        if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
+            raise ValueError(
+                f"Module {unwrapped_module} does not have any registered metadata. "
+                "Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
+            )
+        self._metadata = unwrapped_module._diffusers_transformer_block_metadata
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -150,7 +155,13 @@ class FBCBlockHook(ModelHook):
         self._metadata = None
 
     def initialize_hook(self, module):
-        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
+        unwrapped_module = unwrap_module(module)
+        if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
+            raise ValueError(
+                f"Module {unwrapped_module} does not have any registered metadata. "
+                "Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
+            )
+        self._metadata = unwrapped_module._diffusers_transformer_block_metadata
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
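Note that both FBC hooks now repeat the same lookup-and-validate sequence in
initialize_hook. A small shared helper could keep the two copies in sync; the
sketch below is hypothetical (not part of this patch) and assumes it lives in
first_block_cache.py, where unwrap_module is already imported:

    # Hypothetical helper, not in this patch: both FBCHeadBlockHook and
    # FBCBlockHook perform exactly this attribute lookup and validation.
    def _get_transformer_block_metadata(module):
        unwrapped_module = unwrap_module(module)
        if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
            raise ValueError(
                f"Module {unwrapped_module} does not have any registered metadata. "
                "Make sure to register the metadata using "
                "`diffusers.models.metadata.register_transformer_block`."
            )
        return unwrapped_module._diffusers_transformer_block_metadata
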
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index dfcfc27304..b2bc08beff 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -22,7 +22,7 @@ from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from .metadata import TransformerBlockMetadata, register_transformer_block
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
@@ -259,7 +259,7 @@ class JointTransformerBlock(nn.Module):
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=None,
diff --git a/src/diffusers/models/metadata.py b/src/diffusers/models/metadata.py
index 9b13e52fc0..6da190ac30 100644
--- a/src/diffusers/models/metadata.py
+++ b/src/diffusers/models/metadata.py
@@ -44,20 +44,10 @@ class TransformerBlockMetadata:
         return args[index]
 
 
-class TransformerBlockRegistry:
-    _registry = {}
+def register_transformer_block(metadata: TransformerBlockMetadata):
+    def inner(model_class: Type):
+        metadata._cls = model_class
+        model_class._diffusers_transformer_block_metadata = metadata
+        return model_class
 
-    @classmethod
-    def register(cls, metadata: TransformerBlockMetadata):
-        def inner(model_class: Type):
-            metadata._cls = model_class
-            cls._registry[model_class] = metadata
-            return model_class
-
-        return inner
-
-    @classmethod
-    def get(cls, model_class: Type) -> TransformerBlockMetadata:
-        if model_class not in cls._registry:
-            raise ValueError(f"Model class {model_class} not registered.")
-        return cls._registry[model_class]
+    return inner
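For reference, a minimal usage sketch of the new decorator on a toy block. The
import path is taken from the error message in the hooks above; the sketch
assumes the two index fields shown at the call sites in this patch are enough
to construct TransformerBlockMetadata. The real call sites follow below.

    # Minimal sketch: registration now just attaches the metadata object to
    # the decorated class as a class attribute.
    import torch
    from diffusers.models.metadata import TransformerBlockMetadata, register_transformer_block

    @register_transformer_block(
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=None,
        )
    )
    class ToyBlock(torch.nn.Module):
        def forward(self, hidden_states):
            return hidden_states

    metadata = ToyBlock._diffusers_transformer_block_metadata
    assert metadata.return_hidden_states_index == 0
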
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index d3e596b1af..4561cbf505 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -26,7 +26,7 @@ from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index c3d40b8749..8103b9dd83 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -26,7 +26,7 @@ from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous
@@ -456,7 +456,7 @@ class CogView4TrainingAttnProcessor:
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index f66d5f982b..3be0ba9d16 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -34,7 +34,7 @@ from ..attention_processor import (
 )
 from ..cache_utils import CacheMixin
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -44,7 +44,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=1,
         return_encoder_hidden_states_index=0,
@@ -116,7 +116,7 @@ class FluxSingleTransformerBlock(nn.Module):
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=1,
         return_encoder_hidden_states_index=0,
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index 6e5b107f9a..1554ac129b 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -33,7 +33,7 @@ from ..embeddings import (
     Timesteps,
     get_1d_rotary_pos_embed,
 )
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
@@ -311,7 +311,7 @@ class HunyuanVideoConditionEmbedding(nn.Module):
         return conditioning, token_replace_emb
 
 
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=None,
@@ -496,7 +496,7 @@ class HunyuanVideoRotaryPosEmbed(nn.Module):
         return freqs_cos, freqs_sin
 
 
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -578,7 +578,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -663,7 +663,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -749,7 +749,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
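The per-class indices above encode the ordering of each block's return tuple:
the Flux blocks return (encoder_hidden_states, hidden_states), hence 1/0,
while the HunyuanVideo blocks return (hidden_states, encoder_hidden_states),
hence 0/1. A sketch of how a consumer can use the metadata to normalize
outputs; the helper name is hypothetical, not code from this patch:

    # Sketch: normalize a block's return tuple using its registered metadata.
    # Assumes the block class was decorated with @register_transformer_block.
    def split_block_outputs(block, outputs):
        metadata = block._diffusers_transformer_block_metadata
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        hidden_states = outputs[metadata.return_hidden_states_index]
        encoder_hidden_states = None
        if metadata.return_encoder_hidden_states_index is not None:
            encoder_hidden_states = outputs[metadata.return_encoder_hidden_states_index]
        return hidden_states, encoder_hidden_states
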
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index 8a8409f1bf..042881524e 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -28,7 +28,7 @@ from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
@@ -197,7 +197,7 @@ class LTXVideoRotaryPosEmbed(nn.Module):
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=None,
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py
index 2148142601..f875103c26 100644
--- a/src/diffusers/models/transformers/transformer_mochi.py
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -27,7 +27,7 @@ from ..attention import FeedForward
 from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, RMSNorm
@@ -117,7 +117,7 @@ class MochiRMSNormZero(nn.Module):
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
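One behavioral difference of attribute-based registration worth keeping in
mind: the old registry dict only matched the exact class passed to get(),
whereas a class attribute is inherited, so the hasattr check also succeeds on
unregistered subclasses of a registered block. A toy sketch of the
distinction, not from this patch:

    # Sketch: class attributes are visible on subclasses, unlike dict keys.
    class RegisteredBlock:
        _diffusers_transformer_block_metadata = "metadata"

    class DerivedBlock(RegisteredBlock):
        pass

    assert hasattr(DerivedBlock, "_diffusers_transformer_block_metadata")
    # To insist on per-class registration, check the class dict instead:
    assert "_diffusers_transformer_block_metadata" not in DerivedBlock.__dict__

Whether inheritance is desirable depends on the subclass: it is convenient
when a derived block keeps the parent's return signature, and silently wrong
when it does not.
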
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index ec607c5126..4ab26b90b3 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -27,7 +27,7 @@ from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
@@ -222,7 +222,7 @@ class WanRotaryPosEmbed(nn.Module):
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=None,