From e271af9495435016e2af1230e66ea242e624c720 Mon Sep 17 00:00:00 2001
From: Edna <88869424+Ednaordinary@users.noreply.github.com>
Date: Mon, 9 Jun 2025 21:03:10 -0600
Subject: [PATCH] working state (normalization)

---
 src/diffusers/models/normalization.py | 119 +++++++++++++++++++++++++-
 1 file changed, 116 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 4a512c5cb1..f2b71bb688 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -171,6 +171,46 @@ class AdaLayerNormZero(nn.Module):
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
 
 
+class AdaLayerNormZeroPruned(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
+        super().__init__()
+        if num_embeddings is not None:
+            self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+        else:
+            self.emb = None
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        elif norm_type == "fp32_layer_norm":
+            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timestep: Optional[torch.Tensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        hidden_dtype: Optional[torch.dtype] = None,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        if self.emb is not None:
+            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
 class AdaLayerNormZeroSingle(nn.Module):
     r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).
@@ -203,6 +243,35 @@ class AdaLayerNormZeroSingle(nn.Module):
         return x, gate_msa
 
 
+class AdaLayerNormZeroSinglePruned(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        norm_type (`str`, defaults to `"layer_norm"`): The type of normalization layer to use.
+    """
+
+    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+        super().__init__()
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'."
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        shift_msa, scale_msa, gate_msa = emb.squeeze(0).chunk(3, dim=0)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa
+
+
 class LuminaRMSNormZero(nn.Module):
     """
     Norm layer adaptive RMS normalization zero.
@@ -237,7 +306,7 @@ class AdaLayerNormSingle(nn.Module):
     r"""
     Norm layer adaptive layer norm single (adaLN-single).
 
-    As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3).
+    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
 
     Parameters:
         embedding_dim (`int`): The size of each embedding vector.
@@ -305,6 +374,50 @@ class AdaGroupNorm(nn.Module):
         return x
 
 
+class AdaLayerNormContinuousPruned(nn.Module):
+    r"""
+    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+    Args:
+        embedding_dim (`int`): Embedding dimension to use during projection.
+        conditioning_embedding_dim (`int`): Dimension of the input condition.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        eps (`float`, defaults to 1e-5): Epsilon factor.
+        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+        # However, this is how it was implemented in the original code, and it's rather likely you should
+        # set `elementwise_affine` to False.
+        elementwise_affine=True,
+        eps=1e-5,
+        bias=True,
+        norm_type="layer_norm",
+    ):
+        super().__init__()
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
+        shift, scale = torch.chunk(emb.squeeze(0).to(x.dtype), 2, dim=0)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class AdaLayerNormContinuous(nn.Module):
     r"""
     Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
@@ -510,7 +623,7 @@ else:
 
 class RMSNorm(nn.Module):
     r"""
-    RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
+    RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.
 
     Args:
         dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
@@ -600,7 +713,7 @@ class MochiRMSNorm(nn.Module):
 
 class GlobalResponseNorm(nn.Module):
     r"""
-    Global response normalization as introduced in ConvNeXt-v2 (https://huggingface.co/papers/2301.00808).
+    Global response normalization as introduced in ConvNeXt-v2 (https://arxiv.org/abs/2301.00808).
 
     Args:
         dim (`int`): Number of dimensions to use for the `gamma` and `beta`.
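
Usage note (editor's addition, not part of the patch): each *Pruned class drops the SiLU/Linear modulation projection of its non-pruned counterpart and instead consumes an already-projected conditioning embedding through `emb`, splitting it with `squeeze(0)` followed by `chunk(..., dim=0)`. The sketch below shows one way the new classes could be exercised once the patch is applied; the `(1, n_chunks, dim)` layout chosen for `emb` is an assumption inferred from that squeeze/chunk pattern, not something the patch documents.

import torch

from diffusers.models.normalization import (  # available once this patch is applied
    AdaLayerNormContinuousPruned,
    AdaLayerNormZeroPruned,
    AdaLayerNormZeroSinglePruned,
)

dim, seq_len = 64, 128
x = torch.randn(1, seq_len, dim)  # hidden states: (batch, seq_len, dim)

# AdaLayerNormZeroPruned: emb.squeeze(0).chunk(6, dim=0) yields shift/scale/gate
# for attention plus shift/scale/gate for the MLP, each of shape (1, dim).
norm_zero = AdaLayerNormZeroPruned(dim)
emb_zero = torch.randn(1, 6, dim)  # assumed layout: (1, 6, dim)
x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm_zero(x, emb=emb_zero)

# AdaLayerNormZeroSinglePruned: three chunks (shift/scale/gate for attention only).
norm_single = AdaLayerNormZeroSinglePruned(dim)
emb_single = torch.randn(1, 3, dim)
x_single, gate_msa_single = norm_single(x, emb=emb_single)

# AdaLayerNormContinuousPruned: two chunks (shift/scale), applied after the norm.
norm_out = AdaLayerNormContinuousPruned(dim, conditioning_embedding_dim=dim, elementwise_affine=False)
emb_out = torch.randn(1, 2, dim)
x_out = norm_out(x, emb_out)

print(x_mod.shape, x_single.shape, x_out.shape)  # each torch.Size([1, 128, 64])

Because the projection now lives outside the module, the caller is responsible for supplying `emb` with the right number of chunks (6, 3, or 2 respectively), presumably produced by the pruned model's own conditioning pathway.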