remove central registry based on review
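In short, this diffusers change removes the central `TransformerBlockRegistry` (a class with a module-level dict keyed by the exact model class). Registration now goes through a standalone `register_transformer_block` decorator that stores the `TransformerBlockMetadata` directly on the decorated class as `_diffusers_transformer_block_metadata`; the first-block-cache hooks read that attribute instead of calling `TransformerBlockRegistry.get`, and every call site in the transformer model files swaps one decorator for the other.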
@@ -17,7 +17,6 @@ from typing import Tuple, Union
 
 import torch
 
-from ..models.metadata import TransformerBlockRegistry
 from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
 from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
@@ -72,7 +71,13 @@ class FBCHeadBlockHook(ModelHook):
         self._metadata = None
 
     def initialize_hook(self, module):
-        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
+        unwrapped_module = unwrap_module(module)
+        if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
+            raise ValueError(
+                f"Module {unwrapped_module} does not have any registered metadata. "
+                "Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
+            )
+        self._metadata = unwrapped_module._diffusers_transformer_block_metadata
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
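One behavioral difference worth noting: the old registry only matched the exact class passed to `TransformerBlockRegistry.get`, while `hasattr` resolves through normal Python attribute lookup, so subclasses of a registered block now inherit its metadata. A minimal self-contained sketch of the new lookup (the `Dummy*` names are made up for illustration and are not diffusers classes):

    class DummyMetadata:
        pass

    class DummyBlock:
        # what @register_transformer_block leaves behind on the class
        _diffusers_transformer_block_metadata = DummyMetadata()

    class DummySubBlock(DummyBlock):
        pass  # inherits the attribute through normal lookup

    def lookup(module):
        # mirrors the hook's initialize_hook above
        if not hasattr(module, "_diffusers_transformer_block_metadata"):
            raise ValueError(f"Module {module} does not have any registered metadata.")
        return module._diffusers_transformer_block_metadata

    assert lookup(DummySubBlock()) is DummyBlock._diffusers_transformer_block_metadata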
@@ -150,7 +155,13 @@ class FBCBlockHook(ModelHook):
         self._metadata = None
 
     def initialize_hook(self, module):
-        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
+        unwrapped_module = unwrap_module(module)
+        if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
+            raise ValueError(
+                f"Module {unwrapped_module} does not have any registered metadata. "
+                "Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
+            )
+        self._metadata = unwrapped_module._diffusers_transformer_block_metadata
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -22,7 +22,7 @@ from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from .metadata import TransformerBlockMetadata, register_transformer_block
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
 
 
@@ -259,7 +259,7 @@ class JointTransformerBlock(nn.Module):
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=None,
@@ -44,20 +44,10 @@ class TransformerBlockMetadata:
         return args[index]
 
 
-class TransformerBlockRegistry:
-    _registry = {}
-
-    @classmethod
-    def register(cls, metadata: TransformerBlockMetadata):
-        def inner(model_class: Type):
-            metadata._cls = model_class
-            cls._registry[model_class] = metadata
-            return model_class
-
-        return inner
-
-    @classmethod
-    def get(cls, model_class: Type) -> TransformerBlockMetadata:
-        if model_class not in cls._registry:
-            raise ValueError(f"Model class {model_class} not registered.")
-        return cls._registry[model_class]
+def register_transformer_block(metadata: TransformerBlockMetadata):
+    def inner(model_class: Type):
+        metadata._cls = model_class
+        model_class._diffusers_transformer_block_metadata = metadata
+        return model_class
+
+    return inner
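For context, a hedged usage sketch of the new decorator (assuming this PR's `diffusers.models.metadata` module; `MyBlock` is a hypothetical class, and the metadata fields are copied from the call sites below):

    import torch.nn as nn

    from diffusers.models.metadata import TransformerBlockMetadata, register_transformer_block

    @register_transformer_block(
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=None,
        )
    )
    class MyBlock(nn.Module):  # hypothetical, for illustration only
        pass

    # The metadata now lives on the class itself, and `inner` recorded the
    # class back onto the metadata via `metadata._cls = model_class`.
    meta = MyBlock._diffusers_transformer_block_metadata
    assert meta._cls is MyBlock

This keeps each block's registration self-describing on the class rather than in one shared dict.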
@@ -26,7 +26,7 @@ from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -26,7 +26,7 @@ from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous
@@ -456,7 +456,7 @@ class CogView4TrainingAttnProcessor:
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -34,7 +34,7 @@ from ..attention_processor import (
 )
 from ..cache_utils import CacheMixin
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -44,7 +44,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=1,
         return_encoder_hidden_states_index=0,
@@ -116,7 +116,7 @@ class FluxSingleTransformerBlock(nn.Module):
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=1,
         return_encoder_hidden_states_index=0,
@@ -33,7 +33,7 @@ from ..embeddings import (
     Timesteps,
     get_1d_rotary_pos_embed,
 )
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
@@ -311,7 +311,7 @@ class HunyuanVideoConditionEmbedding(nn.Module):
         return conditioning, token_replace_emb
 
 
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=None,
@@ -496,7 +496,7 @@ class HunyuanVideoRotaryPosEmbed(nn.Module):
         return freqs_cos, freqs_sin
 
 
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -578,7 +578,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -663,7 +663,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -749,7 +749,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -28,7 +28,7 @@ from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
@@ -197,7 +197,7 @@ class LTXVideoRotaryPosEmbed(nn.Module):
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=None,
@@ -27,7 +27,7 @@ from ..attention import FeedForward
 from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, RMSNorm
@@ -117,7 +117,7 @@ class MochiRMSNormZero(nn.Module):
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -27,7 +27,7 @@ from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
@@ -222,7 +222,7 @@ class WanRotaryPosEmbed(nn.Module):
 
 
 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=None,
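Note: this Wan call site passes the metadata positionally (`TransformerBlockMetadata(...)` with no `metadata=` keyword). Since `register_transformer_block` takes `metadata` as its only parameter, the positional and keyword spellings are equivalent.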