1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Use RMSNorm implementation closer to original for LTX 2.0 video VAE

This commit is contained in:
Daniel Gu
2025-12-20 02:40:38 +01:00
parent b1cf6ff8a9
commit 6c56954fa8

View File

@@ -29,6 +29,38 @@ from ..normalization import RMSNorm
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
class PerChannelRMSNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
For each element along the chosen dimension, this layer normalizes the tensor
by the root-mean-square of its values across that dimension:
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
"""
def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None:
"""
Args:
dim: Dimension along which to compute the RMS (typically channels).
eps: Small constant added for numerical stability.
"""
super().__init__()
self.channel_dim = channel_dim
self.eps = eps
def forward(self, x: torch.Tensor, channel_dim: Optional[int] = None) -> torch.Tensor:
"""
Apply RMS normalization along the configured dimension.
"""
channel_dim = channel_dim or self.channel_dim
# Compute mean of squared values along `dim`, keep dimensions for broadcasting.
mean_sq = torch.mean(x**2, dim=self.channel_dim, keepdim=True)
# Normalize by the root-mean-square (RMS).
rms = torch.sqrt(mean_sq + self.eps)
return x / rms
# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime
class LTX2VideoCausalConv3d(nn.Module):
def __init__(
@@ -120,7 +152,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
self.nonlinearity = get_activation(non_linearity)
self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.norm1 = PerChannelRMSNorm()
self.conv1 = LTX2VideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
@@ -128,7 +160,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
spatial_padding_mode=spatial_padding_mode,
)
self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.norm2 = PerChannelRMSNorm()
self.dropout = nn.Dropout(dropout)
self.conv2 = LTX2VideoCausalConv3d(
in_channels=out_channels,
@@ -165,8 +197,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
) -> torch.Tensor:
hidden_states = inputs
# Normalize over the channels dimension (dim 1), which is not the last dim
hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.norm1(hidden_states)
if self.scale_shift_table is not None:
temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
@@ -183,7 +214,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
)[None]
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.norm2(hidden_states)
if self.scale_shift_table is not None:
hidden_states = hidden_states * (1 + scale_2) + shift_2
@@ -746,7 +777,7 @@ class LTX2VideoEncoder3d(nn.Module):
)
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.norm_out = PerChannelRMSNorm()
self.conv_act = nn.SiLU()
self.conv_out = LTX2VideoCausalConv3d(
in_channels=output_channel,
@@ -788,7 +819,7 @@ class LTX2VideoEncoder3d(nn.Module):
hidden_states = self.mid_block(hidden_states, causal=causal)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states, causal=causal)
@@ -900,7 +931,7 @@ class LTX2VideoDecoder3d(nn.Module):
self.up_blocks.append(up_block)
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.norm_out = PerChannelRMSNorm()
self.conv_act = nn.SiLU()
self.conv_out = LTX2VideoCausalConv3d(
in_channels=output_channel,
@@ -942,7 +973,7 @@ class LTX2VideoDecoder3d(nn.Module):
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states, temb, causal=causal)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.norm_out(hidden_states)
if self.time_embedder is not None:
temb = self.time_embedder(