Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Use RMSNorm implementation closer to original for LTX 2.0 video VAE
@@ -29,6 +29,38 @@ from ..normalization import RMSNorm
 from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
 
 
+class PerChannelRMSNorm(nn.Module):
+    """
+    Per-pixel (per-location) RMS normalization layer.
+
+    At each location, the input is normalized by the root-mean-square of its values across the
+    chosen dimension:
+
+        y = x / sqrt(mean(x^2, dim=channel_dim, keepdim=True) + eps)
+    """
+
+    def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None:
+        """
+        Args:
+            channel_dim: Dimension along which to compute the RMS (typically the channel dimension).
+            eps: Small constant added for numerical stability.
+        """
+        super().__init__()
+        self.channel_dim = channel_dim
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor, channel_dim: Optional[int] = None) -> torch.Tensor:
+        """
+        Apply RMS normalization along `channel_dim`, defaulting to the dimension set at construction.
+        """
+        channel_dim = self.channel_dim if channel_dim is None else channel_dim
+        # Compute the mean of squared values along `channel_dim`, keeping dimensions for broadcasting.
+        mean_sq = torch.mean(x**2, dim=channel_dim, keepdim=True)
+        # Normalize by the root-mean-square (RMS).
+        rms = torch.sqrt(mean_sq + self.eps)
+        return x / rms
+
+
 # Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime
 class LTX2VideoCausalConv3d(nn.Module):
     def __init__(
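A brief illustration (editorial, not part of the diff) of what PerChannelRMSNorm computes on a video latent laid out as (batch, channels, frames, height, width), with channels on dim 1; the tensor shape below is arbitrary:

# Minimal sketch: per-location RMS normalization over the channel dimension.
import torch

x = torch.randn(1, 128, 4, 16, 16)  # (B, C, T, H, W), illustrative shape
eps = 1e-8
rms = torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + eps)
y = x / rms

# Every spatio-temporal location now has (approximately) unit RMS across its channels.
print(torch.allclose(torch.sqrt(torch.mean(y**2, dim=1)), torch.ones(1, 4, 16, 16), atol=1e-4))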
@@ -120,7 +152,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
 
         self.nonlinearity = get_activation(non_linearity)
 
-        self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
+        self.norm1 = PerChannelRMSNorm()
         self.conv1 = LTX2VideoCausalConv3d(
             in_channels=in_channels,
             out_channels=out_channels,
@@ -128,7 +160,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
             spatial_padding_mode=spatial_padding_mode,
         )
 
-        self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
+        self.norm2 = PerChannelRMSNorm()
         self.dropout = nn.Dropout(dropout)
         self.conv2 = LTX2VideoCausalConv3d(
             in_channels=out_channels,
@@ -165,8 +197,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
     ) -> torch.Tensor:
         hidden_states = inputs
 
-        # Normalize over the channels dimension (dim 1), which is not the last dim
-        hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        hidden_states = self.norm1(hidden_states)
 
         if self.scale_shift_table is not None:
             temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
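The movedim round trip above existed only because the previous RMSNorm normalized over the last dimension, so channels had to be moved there and back. PerChannelRMSNorm normalizes over dim 1 in place. A small sketch (editorial, not part of the diff) showing the two paths agree, assuming the old norm's elementwise affine scale is disabled so it reduces to x / sqrt(mean(x**2, dim=-1, keepdim=True) + eps):

import torch

x = torch.randn(2, 64, 4, 8, 8)  # (B, C, T, H, W), illustrative shape
eps = 1e-8

# Old path: move channels to the last dim, normalize over it, move them back.
xc = x.movedim(1, -1)
old = (xc / torch.sqrt(torch.mean(xc**2, dim=-1, keepdim=True) + eps)).movedim(-1, 1)

# New path: normalize over the channel dim (dim 1) directly.
new = x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + eps)

print(torch.allclose(old, new, atol=1e-6))  # same normalization, no reshuffling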
@@ -183,7 +214,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
             )[None]
             hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
 
-        hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        hidden_states = self.norm2(hidden_states)
 
         if self.scale_shift_table is not None:
             hidden_states = hidden_states * (1 + scale_2) + shift_2
@@ -746,7 +777,7 @@ class LTX2VideoEncoder3d(nn.Module):
         )
 
         # out
-        self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
+        self.norm_out = PerChannelRMSNorm()
         self.conv_act = nn.SiLU()
         self.conv_out = LTX2VideoCausalConv3d(
             in_channels=output_channel,
@@ -788,7 +819,7 @@ class LTX2VideoEncoder3d(nn.Module):
 
         hidden_states = self.mid_block(hidden_states, causal=causal)
 
-        hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        hidden_states = self.norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
         hidden_states = self.conv_out(hidden_states, causal=causal)
 
@@ -900,7 +931,7 @@ class LTX2VideoDecoder3d(nn.Module):
             self.up_blocks.append(up_block)
 
         # out
-        self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
+        self.norm_out = PerChannelRMSNorm()
         self.conv_act = nn.SiLU()
         self.conv_out = LTX2VideoCausalConv3d(
             in_channels=output_channel,
@@ -942,7 +973,7 @@ class LTX2VideoDecoder3d(nn.Module):
         for up_block in self.up_blocks:
             hidden_states = up_block(hidden_states, temb, causal=causal)
 
-        hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        hidden_states = self.norm_out(hidden_states)
 
         if self.time_embedder is not None:
             temb = self.time_embedder(