From 6c56954fa876cd0aef5054d1eb0dc3ad684ebaa3 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 20 Dec 2025 02:40:38 +0100 Subject: [PATCH] Use RMSNorm implementation closer to original for LTX 2.0 video VAE --- .../autoencoders/autoencoder_kl_ltx2.py | 49 +++++++++++++++---- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index 6e7b4d324f..df59e2d748 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -29,6 +29,38 @@ from ..normalization import RMSNorm from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution +class PerChannelRMSNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + + For each element along the chosen dimension, this layer normalizes the tensor + by the root-mean-square of its values across that dimension: + + y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) + """ + + def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None: + """ + Args: + dim: Dimension along which to compute the RMS (typically channels). + eps: Small constant added for numerical stability. + """ + super().__init__() + self.channel_dim = channel_dim + self.eps = eps + + def forward(self, x: torch.Tensor, channel_dim: Optional[int] = None) -> torch.Tensor: + """ + Apply RMS normalization along the configured dimension. + """ + channel_dim = channel_dim or self.channel_dim + # Compute mean of squared values along `dim`, keep dimensions for broadcasting. + mean_sq = torch.mean(x**2, dim=self.channel_dim, keepdim=True) + # Normalize by the root-mean-square (RMS). + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + # Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime class LTX2VideoCausalConv3d(nn.Module): def __init__( @@ -120,7 +152,7 @@ class LTX2VideoResnetBlock3d(nn.Module): self.nonlinearity = get_activation(non_linearity) - self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.norm1 = PerChannelRMSNorm() self.conv1 = LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, @@ -128,7 +160,7 @@ class LTX2VideoResnetBlock3d(nn.Module): spatial_padding_mode=spatial_padding_mode, ) - self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.norm2 = PerChannelRMSNorm() self.dropout = nn.Dropout(dropout) self.conv2 = LTX2VideoCausalConv3d( in_channels=out_channels, @@ -165,8 +197,7 @@ class LTX2VideoResnetBlock3d(nn.Module): ) -> torch.Tensor: hidden_states = inputs - # Normalize over the channels dimension (dim 1), which is not the last dim - hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.norm1(hidden_states) if self.scale_shift_table is not None: temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] @@ -183,7 +214,7 @@ class LTX2VideoResnetBlock3d(nn.Module): )[None] hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] - hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.norm2(hidden_states) if self.scale_shift_table is not None: hidden_states = hidden_states * (1 + scale_2) + shift_2 @@ -746,7 +777,7 @@ class LTX2VideoEncoder3d(nn.Module): ) # out - self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.norm_out = PerChannelRMSNorm() self.conv_act = nn.SiLU() self.conv_out = LTX2VideoCausalConv3d( in_channels=output_channel, @@ -788,7 +819,7 @@ class LTX2VideoEncoder3d(nn.Module): hidden_states = self.mid_block(hidden_states, causal=causal) - hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.norm_out(hidden_states) hidden_states = self.conv_act(hidden_states) hidden_states = self.conv_out(hidden_states, causal=causal) @@ -900,7 +931,7 @@ class LTX2VideoDecoder3d(nn.Module): self.up_blocks.append(up_block) # out - self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.norm_out = PerChannelRMSNorm() self.conv_act = nn.SiLU() self.conv_out = LTX2VideoCausalConv3d( in_channels=output_channel, @@ -942,7 +973,7 @@ class LTX2VideoDecoder3d(nn.Module): for up_block in self.up_blocks: hidden_states = up_block(hidden_states, temb, causal=causal) - hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.norm_out(hidden_states) if self.time_embedder is not None: temb = self.time_embedder(