
metadata registration with decorators instead of a centralized registry

Aryan
2025-05-14 14:19:48 +02:00
parent 0a44380a36
commit fb229b54bb
11 changed files with 163 additions and 287 deletions
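
This replaces the centralized _register_attention_processors_metadata() and _register_transformer_blocks_metadata() functions in hooks/_helpers.py (deleted below) with a @TransformerBlockRegistry.register(...) decorator applied where each block class is defined, backed by the new models/metadata.py. The per-class skip_block_output_fn helpers go away: hooks now pull a block's inputs straight from args/kwargs via TransformerBlockMetadata._get_parameter_from_args_kwargs, while the return_*_index fields survive so skipped outputs can still be reassembled in the right order.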

src/diffusers/hooks/_helpers.py

@@ -1,271 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Callable, Type

from ..models.attention import BasicTransformerBlock
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
    HunyuanVideoSingleTransformerBlock,
    HunyuanVideoTokenReplaceSingleTransformerBlock,
    HunyuanVideoTokenReplaceTransformerBlock,
    HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock


@dataclass
class AttentionProcessorMetadata:
    skip_processor_output_fn: Callable[[Any], Any]


@dataclass
class TransformerBlockMetadata:
    skip_block_output_fn: Callable[[Any], Any]
    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None


class AttentionProcessorRegistry:
    _registry = {}

    @classmethod
    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]


class TransformerBlockRegistry:
    _registry = {}

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]


def _register_attention_processors_metadata():
    # AttnProcessor2_0
    AttentionProcessorRegistry.register(
        model_class=AttnProcessor2_0,
        metadata=AttentionProcessorMetadata(
            skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
        ),
    )

    # CogView4AttnProcessor
    AttentionProcessorRegistry.register(
        model_class=CogView4AttnProcessor,
        metadata=AttentionProcessorMetadata(
            skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
        ),
    )


def _register_transformer_blocks_metadata():
    # BasicTransformerBlock
    TransformerBlockRegistry.register(
        model_class=BasicTransformerBlock,
        metadata=TransformerBlockMetadata(
            skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock,
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=None,
        ),
    )

    # CogVideoX
    TransformerBlockRegistry.register(
        model_class=CogVideoXBlock,
        metadata=TransformerBlockMetadata(
            skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )

    # CogView4
    TransformerBlockRegistry.register(
        model_class=CogView4TransformerBlock,
        metadata=TransformerBlockMetadata(
            skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )

    # Flux
    TransformerBlockRegistry.register(
        model_class=FluxTransformerBlock,
        metadata=TransformerBlockMetadata(
            skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
            return_hidden_states_index=1,
            return_encoder_hidden_states_index=0,
        ),
    )
    TransformerBlockRegistry.register(
        model_class=FluxSingleTransformerBlock,
        metadata=TransformerBlockMetadata(
            skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
            return_hidden_states_index=1,
            return_encoder_hidden_states_index=0,
        ),
    )

    # HunyuanVideo
    TransformerBlockRegistry.register(
        model_class=HunyuanVideoTransformerBlock,
        metadata=TransformerBlockMetadata(
            skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )
    TransformerBlockRegistry.register(
        model_class=HunyuanVideoSingleTransformerBlock,
        metadata=TransformerBlockMetadata(
            skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )
    TransformerBlockRegistry.register(
        model_class=HunyuanVideoTokenReplaceTransformerBlock,
        metadata=TransformerBlockMetadata(
            skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )
    TransformerBlockRegistry.register(
        model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
        metadata=TransformerBlockMetadata(
            skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )

    # LTXVideo
    TransformerBlockRegistry.register(
        model_class=LTXVideoTransformerBlock,
        metadata=TransformerBlockMetadata(
            skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=None,
        ),
    )

    # Mochi
    TransformerBlockRegistry.register(
        model_class=MochiTransformerBlock,
        metadata=TransformerBlockMetadata(
            skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )

    # Wan
    TransformerBlockRegistry.register(
        model_class=WanTransformerBlock,
        metadata=TransformerBlockMetadata(
            skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=None,
        ),
    )


# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    return hidden_states


def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    if encoder_hidden_states is None and len(args) > 1:
        encoder_hidden_states = args[1]
    return hidden_states, encoder_hidden_states


_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states


def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    return hidden_states


def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    if encoder_hidden_states is None and len(args) > 1:
        encoder_hidden_states = args[1]
    return hidden_states, encoder_hidden_states


def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    if encoder_hidden_states is None and len(args) > 1:
        encoder_hidden_states = args[1]
    return encoder_hidden_states, hidden_states


_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
# fmt: on


_register_attention_processors_metadata()
_register_transformer_blocks_metadata()

src/diffusers/hooks/first_block_cache.py

@@ -17,10 +17,10 @@ from typing import Tuple, Union
 import torch

+from ..models.metadata import TransformerBlockRegistry
 from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
 from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
-from ._helpers import TransformerBlockRegistry
 from .hooks import BaseState, HookRegistry, ModelHook, StateManager
@@ -76,12 +76,7 @@ class FBCHeadBlockHook(ModelHook):
         return module

     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
-        outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
-        if isinstance(outputs_if_skipped, tuple):
-            original_hidden_states = outputs_if_skipped[self._metadata.return_hidden_states_index]
-        else:
-            original_hidden_states = outputs_if_skipped
+        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)

         output = self.fn_ref.original_forward(*args, **kwargs)
         is_output_tuple = isinstance(output, tuple)
@@ -92,7 +87,7 @@ class FBCHeadBlockHook(ModelHook):
             hidden_states_residual = output - original_hidden_states

         shared_state: FBCSharedBlockState = self.state_manager.get_state()
-        hidden_states, encoder_hidden_states = None, None
+        hidden_states = encoder_hidden_states = None

         should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
         shared_state.should_compute = should_compute
@@ -159,13 +154,12 @@ class FBCBlockHook(ModelHook):
         return module

     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
-        outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
-        if not isinstance(outputs_if_skipped, tuple):
-            outputs_if_skipped = (outputs_if_skipped,)
-        original_hidden_states = outputs_if_skipped[self._metadata.return_hidden_states_index]
+        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)

         original_encoder_hidden_states = None
         if self._metadata.return_encoder_hidden_states_index is not None:
-            original_encoder_hidden_states = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index]
+            original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
+                "encoder_hidden_states", args, kwargs
+            )

         shared_state = self.state_manager.get_state()
@@ -185,13 +179,13 @@ class FBCBlockHook(ModelHook):
             shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
             return output

-        output_count = len(outputs_if_skipped)
-        if output_count == 1:
+        if original_encoder_hidden_states is None:
             return_output = original_hidden_states
         else:
-            return_output = [None] * output_count
+            return_output = [None, None]
             return_output[self._metadata.return_hidden_states_index] = original_hidden_states
             return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
             return_output = tuple(return_output)
         return return_output
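
The surviving index metadata matters in the branch above: when the remaining blocks are skipped, the hook rebuilds the block's output tuple in its original order. A minimal sketch of that reconstruction step, where metadata stands in for the registered TransformerBlockMetadata instance:

def rebuild_skipped_output(metadata, hidden_states, encoder_hidden_states):
    # Blocks that only return hidden states (return_encoder_hidden_states_index=None)
    # yield the tensor directly rather than a tuple.
    if encoder_hidden_states is None:
        return hidden_states
    # Two-output blocks are reassembled in signature order; Flux blocks, for example,
    # register return_hidden_states_index=1 and return_encoder_hidden_states_index=0.
    output = [None, None]
    output[metadata.return_hidden_states_index] = hidden_states
    output[metadata.return_encoder_hidden_states_index] = encoder_hidden_states
    return tuple(output)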

src/diffusers/models/attention.py

@@ -22,6 +22,7 @@ from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
+from .metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
@@ -258,6 +259,12 @@ class JointTransformerBlock(nn.Module):
 @maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=None,
+    )
+)
 class BasicTransformerBlock(nn.Module):
     r"""
     A basic Transformer block.

src/diffusers/models/metadata.py

@@ -0,0 +1,63 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Dict, Type


@dataclass
class TransformerBlockMetadata:
    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None

    _cls: Type = None
    _cached_parameter_indices: Dict[str, int] = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        kwargs = kwargs or {}
        if identifier in kwargs:
            return kwargs[identifier]
        if self._cached_parameter_indices is not None:
            return args[self._cached_parameter_indices[identifier]]
        if self._cls is None:
            raise ValueError("Model class is not set for metadata.")
        parameters = list(inspect.signature(self._cls.forward).parameters.keys())
        parameters = parameters[1:]  # skip `self`
        self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
        if identifier not in self._cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
        index = self._cached_parameter_indices[identifier]
        if index >= len(args):
            raise ValueError(f"Expected at least {index + 1} positional arguments but got {len(args)}.")
        return args[index]


class TransformerBlockRegistry:
    _registry = {}

    @classmethod
    def register(cls, metadata: TransformerBlockMetadata):
        def inner(model_class: Type):
            metadata._cls = model_class
            cls._registry[model_class] = metadata
            return model_class

        return inner

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]
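
Taken together, registration and lookup work as sketched below. This is an illustrative example rather than code from the commit, assuming the new module lands at diffusers.models.metadata; MyBlock, x, and ctx are hypothetical stand-ins:

import torch

from diffusers.models.metadata import TransformerBlockMetadata, TransformerBlockRegistry


@TransformerBlockRegistry.register(
    metadata=TransformerBlockMetadata(
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=1,
    )
)
class MyBlock(torch.nn.Module):  # hypothetical block, for illustration only
    def forward(self, hidden_states, encoder_hidden_states=None):
        return hidden_states, encoder_hidden_states


x, ctx = torch.randn(1, 4, 8), torch.randn(1, 2, 8)
metadata = TransformerBlockRegistry.get(MyBlock)

# The lookup is call-style agnostic: positional and keyword arguments resolve to
# the same parameter of MyBlock.forward (indices are cached after first use).
assert metadata._get_parameter_from_args_kwargs("hidden_states", (x, ctx), {}) is x
assert metadata._get_parameter_from_args_kwargs("encoder_hidden_states", (x,), {"encoder_hidden_states": ctx}) is ctx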

src/diffusers/models/transformers/cogvideox_transformer_3d.py

@@ -26,6 +26,7 @@ from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -35,6 +36,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 @maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class CogVideoXBlock(nn.Module):
     r"""
     Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.

src/diffusers/models/transformers/transformer_cogview4.py

@@ -21,10 +21,12 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous
@@ -453,6 +455,13 @@ class CogView4TrainingAttnProcessor:
         return hidden_states, encoder_hidden_states


+@maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class CogView4TransformerBlock(nn.Module):
     def __init__(
         self,

src/diffusers/models/transformers/transformer_flux.py

@@ -34,6 +34,7 @@ from ..attention_processor import (
 )
 from ..cache_utils import CacheMixin
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -43,6 +44,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 @maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=1,
+        return_encoder_hidden_states_index=0,
+    )
+)
 class FluxSingleTransformerBlock(nn.Module):
     def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
         super().__init__()
@@ -109,6 +116,12 @@ class FluxSingleTransformerBlock(nn.Module):
 @maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=1,
+        return_encoder_hidden_states_index=0,
+    )
+)
 class FluxTransformerBlock(nn.Module):
     def __init__(
         self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6

src/diffusers/models/transformers/transformer_hunyuan_video.py

@@ -33,6 +33,7 @@ from ..embeddings import (
     Timesteps,
     get_1d_rotary_pos_embed,
 )
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
@@ -310,6 +311,12 @@ class HunyuanVideoConditionEmbedding(nn.Module):
         return conditioning, token_replace_emb


+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=None,
+    )
+)
 class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
     def __init__(
         self,
@@ -489,6 +496,12 @@ class HunyuanVideoRotaryPosEmbed(nn.Module):
         return freqs_cos, freqs_sin


+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class HunyuanVideoSingleTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -565,6 +578,12 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
         return hidden_states, encoder_hidden_states


+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class HunyuanVideoTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -644,6 +663,12 @@ class HunyuanVideoTransformerBlock(nn.Module):
         return hidden_states, encoder_hidden_states


+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -724,6 +749,12 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
         return hidden_states, encoder_hidden_states


+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
     def __init__(
         self,

src/diffusers/models/transformers/transformer_ltx.py

@@ -28,6 +28,7 @@ from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
@@ -196,6 +197,12 @@ class LTXVideoRotaryPosEmbed(nn.Module):
 @maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=None,
+    )
+)
 class LTXVideoTransformerBlock(nn.Module):
     r"""
     Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

src/diffusers/models/transformers/transformer_mochi.py

@@ -27,6 +27,7 @@ from ..attention import FeedForward
 from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, RMSNorm
@@ -116,6 +117,12 @@ class MochiRMSNormZero(nn.Module):
 @maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    metadata=TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=1,
+    )
+)
 class MochiTransformerBlock(nn.Module):
     r"""
     Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).

src/diffusers/models/transformers/transformer_wan.py

@@ -22,10 +22,12 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
@@ -219,6 +221,13 @@ class WanRotaryPosEmbed(nn.Module):
         return freqs


+@maybe_allow_in_graph
+@TransformerBlockRegistry.register(
+    TransformerBlockMetadata(
+        return_hidden_states_index=0,
+        return_encoder_hidden_states_index=None,
+    )
+)
 class WanTransformerBlock(nn.Module):
     def __init__(
         self,