1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

remove central registry based on review

This commit is contained in:
Aryan
2025-05-16 12:53:36 +02:00
parent a5fe2bd4fd
commit b8317da20f
10 changed files with 41 additions and 40 deletions

View File

@@ -17,7 +17,6 @@ from typing import Tuple, Union
import torch
from ..models.metadata import TransformerBlockRegistry
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
@@ -72,7 +71,13 @@ class FBCHeadBlockHook(ModelHook):
self._metadata = None
def initialize_hook(self, module):
    """Cache the transformer-block metadata attached to ``module``'s class.

    The metadata is read from the ``_diffusers_transformer_block_metadata``
    attribute of the object returned by ``unwrap_module(module)`` and stored
    on ``self._metadata``. No central-registry lookup is performed: the
    attribute on the module is the single source of truth.

    Args:
        module: The module this hook is being attached to.

    Returns:
        ``module``, unchanged.

    Raises:
        ValueError: If the unwrapped module carries no
            ``_diffusers_transformer_block_metadata`` attribute.
    """
    unwrapped_module = unwrap_module(module)
    # Fail early with an actionable message instead of an AttributeError below.
    if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
        raise ValueError(
            f"Module {unwrapped_module} does not have any registered metadata. "
            "Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
        )
    self._metadata = unwrapped_module._diffusers_transformer_block_metadata
    return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -150,7 +155,13 @@ class FBCBlockHook(ModelHook):
self._metadata = None
def initialize_hook(self, module):
    """Cache the transformer-block metadata attached to ``module``'s class.

    The metadata is read from the ``_diffusers_transformer_block_metadata``
    attribute of the object returned by ``unwrap_module(module)`` and stored
    on ``self._metadata``. No central-registry lookup is performed: the
    attribute on the module is the single source of truth.

    Args:
        module: The module this hook is being attached to.

    Returns:
        ``module``, unchanged.

    Raises:
        ValueError: If the unwrapped module carries no
            ``_diffusers_transformer_block_metadata`` attribute.
    """
    unwrapped_module = unwrap_module(module)
    # Fail early with an actionable message instead of an AttributeError below.
    if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
        raise ValueError(
            f"Module {unwrapped_module} does not have any registered metadata. "
            "Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
        )
    self._metadata = unwrapped_module._diffusers_transformer_block_metadata
    return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):

View File

@@ -22,7 +22,7 @@ from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .metadata import TransformerBlockMetadata, TransformerBlockRegistry
from .metadata import TransformerBlockMetadata, register_transformer_block
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
@@ -259,7 +259,7 @@ class JointTransformerBlock(nn.Module):
@maybe_allow_in_graph
@TransformerBlockRegistry.register(
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,

View File

@@ -44,20 +44,10 @@ class TransformerBlockMetadata:
return args[index]
def register_transformer_block(metadata: "TransformerBlockMetadata"):
    """Build a class decorator that attaches ``metadata`` to a transformer block.

    The returned decorator records the decorated class on ``metadata._cls`` and
    stores ``metadata`` on the class as ``_diffusers_transformer_block_metadata``,
    then returns the class unchanged otherwise.

    Args:
        metadata: The metadata object describing the block.

    Returns:
        A decorator taking a class and returning that same class.
    """

    def inner(model_class: Type):
        # NOTE: ``metadata`` is mutated in place, so one metadata instance
        # should decorate exactly one class.
        metadata._cls = model_class
        model_class._diffusers_transformer_block_metadata = metadata
        return model_class

    return inner


class TransformerBlockRegistry:
    """Class-keyed lookup table of transformer block metadata.

    Kept as a backward-compatible entry point next to
    ``register_transformer_block``. Registering through this class also
    attaches the metadata to the class itself, so both lookup styles agree.
    """

    # Maps a model class to the metadata it was registered with.
    _registry = {}

    @classmethod
    def register(cls, metadata: "TransformerBlockMetadata"):
        """Return a class decorator that registers ``metadata`` for the class."""

        def inner(model_class: Type):
            # Sets ``metadata._cls`` and the class-level metadata attribute.
            register_transformer_block(metadata)(model_class)
            cls._registry[model_class] = metadata
            return model_class

        return inner

    @classmethod
    def get(cls, model_class: Type) -> "TransformerBlockMetadata":
        """Return the metadata registered for ``model_class``.

        Raises:
            ValueError: If ``model_class`` was never registered.
        """
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

View File

@@ -26,7 +26,7 @@ from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
@TransformerBlockRegistry.register(
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,

View File

@@ -26,7 +26,7 @@ from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
@@ -456,7 +456,7 @@ class CogView4TrainingAttnProcessor:
@maybe_allow_in_graph
@TransformerBlockRegistry.register(
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,

View File

@@ -34,7 +34,7 @@ from ..attention_processor import (
)
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -44,7 +44,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
@TransformerBlockRegistry.register(
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
@@ -116,7 +116,7 @@ class FluxSingleTransformerBlock(nn.Module):
@maybe_allow_in_graph
@TransformerBlockRegistry.register(
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,

View File

@@ -33,7 +33,7 @@ from ..embeddings import (
Timesteps,
get_1d_rotary_pos_embed,
)
from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
@@ -311,7 +311,7 @@ class HunyuanVideoConditionEmbedding(nn.Module):
return conditioning, token_replace_emb
@TransformerBlockRegistry.register(
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
@@ -496,7 +496,7 @@ class HunyuanVideoRotaryPosEmbed(nn.Module):
return freqs_cos, freqs_sin
@TransformerBlockRegistry.register(
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
@@ -578,7 +578,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
@TransformerBlockRegistry.register(
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
@@ -663,7 +663,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
@TransformerBlockRegistry.register(
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
@@ -749,7 +749,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
@TransformerBlockRegistry.register(
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,

View File

@@ -28,7 +28,7 @@ from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection
from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
@@ -197,7 +197,7 @@ class LTXVideoRotaryPosEmbed(nn.Module):
@maybe_allow_in_graph
@TransformerBlockRegistry.register(
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,

View File

@@ -27,7 +27,7 @@ from ..attention import FeedForward
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, RMSNorm
@@ -117,7 +117,7 @@ class MochiRMSNormZero(nn.Module):
@maybe_allow_in_graph
@TransformerBlockRegistry.register(
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,

View File

@@ -27,7 +27,7 @@ from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
@@ -222,7 +222,7 @@ class WanRotaryPosEmbed(nn.Module):
@maybe_allow_in_graph
@TransformerBlockRegistry.register(
@register_transformer_block(
TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,