From 269cf7b40d3b5100637990907627b2254bf1897a Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 Dec 2025 10:51:34 +0100 Subject: [PATCH 1/6] Initial implementation of LTX 2.0 video VAE --- scripts/convert_ltx2_to_diffusers.py | 137 +- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoders/autoencoder_kl_ltx2.py | 1437 +++++++++++++++++ 5 files changed, 1577 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 312559dbee..dfec0262de 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -8,7 +8,7 @@ import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers import LTX2VideoTransformer3DModel +from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel from diffusers.utils.import_utils import is_accelerate_available @@ -35,6 +35,32 @@ LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = { "k_norm": "norm_k", } +LTX_2_0_VIDEO_VAE_RENAME_DICT = { + # Encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.1", + "down_blocks.3": "down_blocks.1.downsamplers.0", + "down_blocks.4": "down_blocks.2", + "down_blocks.5": "down_blocks.2.downsamplers.0", + "down_blocks.6": "down_blocks.3", + "down_blocks.7": "down_blocks.3.downsamplers.0", + "down_blocks.8": "mid_block", + # Decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + # Common + # For all 3D ResNets + "res_blocks": "resnets", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} + def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: state_dict[new_key] = state_dict.pop(old_key) @@ -68,6 +94,11 @@ LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = { "adaln_single": convert_ltx2_transformer_adaln_single, } +LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_inplace, + "per_channel_statistics.mean-of-stds": remove_keys_inplace, +} + def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": @@ -180,6 +211,102 @@ def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) return transformer +def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 2048), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": True, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 2048), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": True, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vae = AutoencoderKLLTX2Video.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) @@ -312,7 +439,13 @@ def main(args): combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename) if args.vae or args.full_pipeline: - pass + if args.vae_filename is not None: + original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) + elif combined_ckpt is not None: + original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) + vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version) + if not args.full_pipeline: + vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) if args.audio_vae or args.full_pipeline: pass diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 97ba02e2d0..71cad3425f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -194,6 +194,7 @@ else: "AutoencoderKLHunyuanVideo", "AutoencoderKLHunyuanVideo15", "AutoencoderKLLTXVideo", + "AutoencoderKLLTX2Video", "AutoencoderKLMagvit", "AutoencoderKLMochi", "AutoencoderKLQwenImage", @@ -928,6 +929,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, + AutoencoderKLLTX2Video, AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index b387bd817c..3f4e49015b 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -41,6 +41,7 @@ if is_torch_available(): _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"] _import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] + _import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"] @@ -153,6 +154,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, + AutoencoderKLLTX2Video, AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 56df27f93c..ca0cac1a57 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -10,6 +10,7 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo +from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py new file mode 100644 index 0000000000..9f65c9980d --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -0,0 +1,1437 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoCausalConv3d +class LTXVideoCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + padding_mode: str = "zeros", + is_causal: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.is_causal = is_causal + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + + dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + height_pad = self.kernel_size[1] // 2 + width_pad = self.kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + self.kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + padding=padding, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + time_kernel_size = self.kernel_size[0] + + if self.is_causal: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) + else: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +# Like LTXVideoResnetBlock3d, but uses a normal Conv3d instead of a causal Conv3d for the conv_shortcut +class LTX2VideoResnetBlock3d(nn.Module): + r""" + A 3D ResNet block used in the LTX 2.0 audiovisual model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + elementwise_affine (`bool`, defaults to `False`): + Whether to enable elementwise affinity in the normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + eps: float = 1e-6, + elementwise_affine: bool = False, + non_linearity: str = "swish", + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.conv1 = LTXVideoCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) + + self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.dropout = nn.Dropout(dropout) + self.conv2 = LTXVideoCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) + + self.norm3 = None + self.conv_shortcut = None + if in_channels != out_channels: + self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) + # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d + self.conv_shortcut = nn.Conv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1 + ) + # self.conv_shortcut = LTXVideoCausalConv3d( + # in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal + # ) + + self.per_channel_scale1 = None + self.per_channel_scale2 = None + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + + self.scale_shift_table = None + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) + + def forward( + self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None + ) -> torch.Tensor: + hidden_states = inputs + + # Normalize over the channels dimension (dim 1), which is not the last dim + hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] + shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale_1) + shift_1 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if self.per_channel_scale1 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] + + hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + hidden_states = hidden_states * (1 + scale_2) + shift_2 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.per_channel_scale2 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...] + + if self.norm3 is not None: + inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) + + if self.conv_shortcut is not None: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoDownsampler3d +class LTXVideoDownsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + is_causal: bool = True, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels + + out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) + + self.conv = LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) + + residual = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + residual = residual.unflatten(1, (-1, self.group_size)) + residual = residual.mean(dim=2) + + hidden_states = self.conv(hidden_states) + hidden_states = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + hidden_states = hidden_states + residual + + return hidden_states + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoUpsampler3d +class LTXVideoUpsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + is_causal: bool = True, + residual: bool = False, + upscale_factor: int = 1, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.residual = residual + self.upscale_factor = upscale_factor + + out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor + + self.conv = LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.residual: + residual = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor + residual = residual.repeat(1, repeats, 1, 1, 1) + residual = residual[:, :, self.stride[0] - 1 :] + + hidden_states = self.conv(hidden_states) + hidden_states = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + + if self.residual: + hidden_states = hidden_states + residual + + return hidden_states + + +# Like LTX 1.0 LTXVideo095DownBlock3D, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoDownBlock3D(nn.Module): + r""" + Down block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + downsample_type: str = "conv", + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList() + + if downsample_type == "conv": + self.downsamplers.append( + LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + is_causal=is_causal, + ) + ) + elif downsample_type == "spatial": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal + ) + ) + elif downsample_type == "temporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal + ) + ) + elif downsample_type == "spatiotemporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal + ) + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d +# Like LTX 1.0 LTXVideoMidBlock3d, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoMidBlock3d(nn.Module): + r""" + A middle block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + ) -> None: + super().__init__() + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + r"""Forward method of the `LTXMidBlock3D` class.""" + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + return hidden_states + + +# Like LTXVideoUpBlock3d but with no conv_in and the updated LTX2VideoResnetBlock3d +class LTX2VideoUpBlock3d(nn.Module): + r""" + Up block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + upsample_residual: bool = False, + upscale_factor: int = 1, + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + self.conv_in = None + if in_channels != out_channels: + self.conv_in = LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + + self.upsamplers = None + if spatio_temporal_scale: + self.upsamplers = nn.ModuleList( + [ + LTXVideoUpsampler3d( + out_channels * upscale_factor, + stride=(2, 2, 2), + is_causal=is_causal, + residual=upsample_residual, + upscale_factor=upscale_factor, + ) + ] + ) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=out_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + if self.conv_in is not None: + hidden_states = self.conv_in(hidden_states, temb, generator) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + return hidden_states + + +# Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is +# different, as is the layers_per_block (the 2.0 VAE is bigger) +class LTXVideoEncoder3d(nn.Module): + r""" + The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + in_channels (`int`, defaults to 3): + Number of input channels. + out_channels (`int`, defaults to 128): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, True)`: + Whether a block should contain spatio-temporal downscaling layers or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`): + The number of layers per block. + downsample_type (`Tuple[str, ...]`, defaults to `("spatial", "temporal", "spatiotemporal", "spatiotemporal")`): + The spatiotemporal downsampling pattern per block. Per-layer values can be + - `"spatial"` (downsample spatial dims by 2x) + - `"temporal"` (downsample temporal dim by 2x) + - `"spatiotemporal"` (downsample both spatial and temporal dims by 2x) + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 128, + block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: Tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True), + layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2), + downsample_type: Tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + ): + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.in_channels = in_channels * patch_size**2 + + output_channel = out_channels + + self.conv_in = LTXVideoCausalConv3d( + in_channels=self.in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + is_causal=is_causal, + ) + + # down blocks + num_block_out_channels = len(block_out_channels) + self.down_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + output_channel = block_out_channels[i] + + if down_block_types[i] == "LTX2VideoDownBlock3D": + down_block = LTX2VideoDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + downsample_type=downsample_type[i], + ) + else: + raise ValueError(f"Unknown down block type: {down_block_types[i]}") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[-1], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + ) + + # out + self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.conv_act = nn.SiLU() + self.conv_out = LTXVideoCausalConv3d( + in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `LTXVideoEncoder3d` class.""" + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = hidden_states.reshape( + batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p + ) + # Thanks for driving me insane with the weird patching order :( + hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + last_channel = hidden_states[:, -1:] + last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) + hidden_states = torch.cat([hidden_states, last_channel], dim=1) + + return hidden_states + + +# Like LTX 1.0 LTXVideoDecoder3d, but has only 3 symmetric up blocks which are causal and residual with upsample_factor 2 +class LTXVideoDecoder3d(nn.Module): + r""" + The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, defaults to 128): + Number of latent channels. + out_channels (`int`, defaults to 3): + Number of output channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal upscaling layers or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `False`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + timestep_conditioning (`bool`, defaults to `False`): + Whether to condition the model on timesteps. + """ + + def __init__( + self, + in_channels: int = 128, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (256, 512, 1024), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True), + layers_per_block: Tuple[int, ...] = (5, 5, 5, 5), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + inject_noise: Tuple[bool, ...] = (False, False, False), + timestep_conditioning: bool = False, + upsample_residual: Tuple[bool, ...] = (True, True, True), + upsample_factor: Tuple[bool, ...] = (2, 2, 2), + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.out_channels = out_channels * patch_size**2 + + block_out_channels = tuple(reversed(block_out_channels)) + spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) + layers_per_block = tuple(reversed(layers_per_block)) + inject_noise = tuple(reversed(inject_noise)) + upsample_residual = tuple(reversed(upsample_residual)) + upsample_factor = tuple(reversed(upsample_factor)) + output_channel = block_out_channels[0] + + self.conv_in = LTXVideoCausalConv3d( + in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[0], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + inject_noise=inject_noise[0], + timestep_conditioning=timestep_conditioning, + ) + + # up blocks + num_block_out_channels = len(block_out_channels) + self.up_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel // upsample_factor[i] + output_channel = block_out_channels[i] // upsample_factor[i] + + up_block = LTX2VideoUpBlock3d( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i + 1], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + inject_noise=inject_noise[i + 1], + timestep_conditioning=timestep_conditioning, + upsample_residual=upsample_residual[i], + upscale_factor=upsample_factor[i], + ) + + self.up_blocks.append(up_block) + + # out + self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.conv_act = nn.SiLU() + self.conv_out = LTXVideoCausalConv3d( + in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal + ) + + # timestep embedding + self.time_embedder = None + self.scale_shift_table = None + self.timestep_scale_multiplier = None + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if self.timestep_scale_multiplier is not None: + temb = temb * self.timestep_scale_multiplier + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb) + else: + hidden_states = self.mid_block(hidden_states, temb) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1)) + temb = temb + self.scale_shift_table[None, ..., None, None, None] + shift, scale = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return hidden_states + + +class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [LTX](https://huggingface.co/Lightricks/LTX-Video). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Args: + in_channels (`int`, defaults to `3`): + Number of input channels. + out_channels (`int`, defaults to `3`): + Number of output channels. + latent_channels (`int`, defaults to `128`): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal downscaling or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + scaling_factor (`float`, *optional*, defaults to `1.0`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + encoder_causal (`bool`, defaults to `True`): + Whether the encoder should behave causally (future frames depend only on past frames) or not. + decoder_causal (`bool`, defaults to `False`): + Whether the decoder should behave causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 128, + block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: Tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + decoder_block_out_channels: Tuple[int, ...] = (256, 512, 1024), + layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2), + decoder_layers_per_block: Tuple[int, ...] = (5, 5, 5, 5), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True), + decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True), + decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False), + downsample_type: Tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + upsample_residual: Tuple[bool, ...] = (True, True, True), + upsample_factor: Tuple[int, ...] = (2, 2, 2), + timestep_conditioning: bool = False, + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + scaling_factor: float = 1.0, + encoder_causal: bool = True, + decoder_causal: bool = True, + spatial_compression_ratio: int = None, + temporal_compression_ratio: int = None, + ) -> None: + super().__init__() + + self.encoder = LTXVideoEncoder3d( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=block_out_channels, + down_block_types=down_block_types, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + downsample_type=downsample_type, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=encoder_causal, + ) + self.decoder = LTXVideoDecoder3d( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=decoder_block_out_channels, + spatio_temporal_scaling=decoder_spatio_temporal_scaling, + layers_per_block=decoder_layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=decoder_causal, + timestep_conditioning=timestep_conditioning, + inject_noise=decoder_inject_noise, + upsample_residual=upsample_residual, + upsample_factor=upsample_factor, + ) + + latents_mean = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.ones((latent_channels,), requires_grad=False) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # This can be configured based on the amount of GPU memory available. + # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. + # Setting it to higher values results in higher memory usage. + self.num_sample_frames_batch_size = 16 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + self.tile_sample_min_num_frames = 16 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + self.tile_sample_stride_num_frames = 8 + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_min_num_frames: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + tile_sample_stride_num_frames: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + enc = self.encoder(x) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: + return self._temporal_tiled_decode(z, temb, return_dict=return_dict) + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, temb, return_dict=return_dict) + + dec = self.decoder(z, temb) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + if temb is not None: + decoded_slices = [ + self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1)) + ] + else: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z, temb).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + time = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile) + else: + tile = self.encoder(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, temb, return_dict=True).sample + else: + decoded = self.decoder(tile, temb) + if i > 0: + decoded = decoded[:, :, :-1, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :] + result_row.append(tile) + else: + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + temb: Optional[torch.Tensor] = None, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, temb) + if not return_dict: + return (dec.sample,) + return dec From baf23e2da3f0816d1ebe870ccd66249fa3e5ceaa Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 Dec 2025 11:14:45 +0100 Subject: [PATCH 2/6] Explicitly specify temporal and spatial VAE scale factors when converting --- scripts/convert_ltx2_to_diffusers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index dfec0262de..85fa169af3 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -241,6 +241,8 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "resnet_norm_eps": 1e-6, "encoder_causal": True, "decoder_causal": True, + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, }, } rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT @@ -274,6 +276,8 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "resnet_norm_eps": 1e-6, "encoder_causal": True, "decoder_causal": True, + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, }, } rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT From 5b950d6fefae4035d835e539c7b2676008ba43fc Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 Dec 2025 11:30:15 +0100 Subject: [PATCH 3/6] Add initial LTX 2.0 video VAE tests --- src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index 9f65c9980d..755b92c10a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -633,7 +633,7 @@ class LTX2VideoUpBlock3d(nn.Module): # Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is # different, as is the layers_per_block (the 2.0 VAE is bigger) -class LTXVideoEncoder3d(nn.Module): +class LTX2VideoEncoder3d(nn.Module): r""" The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent representation. @@ -779,7 +779,7 @@ class LTXVideoEncoder3d(nn.Module): # Like LTX 1.0 LTXVideoDecoder3d, but has only 3 symmetric up blocks which are causal and residual with upsample_factor 2 -class LTXVideoDecoder3d(nn.Module): +class LTX2VideoDecoder3d(nn.Module): r""" The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample. @@ -1011,7 +1011,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig ) -> None: super().__init__() - self.encoder = LTXVideoEncoder3d( + self.encoder = LTX2VideoEncoder3d( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, @@ -1024,7 +1024,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig resnet_norm_eps=resnet_norm_eps, is_causal=encoder_causal, ) - self.decoder = LTXVideoDecoder3d( + self.decoder = LTX2VideoDecoder3d( in_channels=latent_channels, out_channels=out_channels, block_out_channels=decoder_block_out_channels, From 491aae08d84d66a3db73f2fdeca96f109f28c4a7 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 Dec 2025 11:39:09 +0100 Subject: [PATCH 4/6] Add initial LTX 2.0 video VAE tests (part 2) --- .../test_models_autoencoder_ltx2_video.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/models/autoencoders/test_models_autoencoder_ltx2_video.py diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py new file mode 100644 index 0000000000..703ba54f89 --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py @@ -0,0 +1,102 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AutoencoderKLLTX2Video + +from ...testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): + model_class = AutoencoderKLLTX2Video + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_ltx_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 8, + "block_out_channels": (8, 8, 8, 8), + "decoder_block_out_channels": (16, 32, 64), + "layers_per_block": (1, 1, 1, 1, 1), + "decoder_layers_per_block": (1, 1, 1, 1), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 1, + "patch_size_t": 1, + "encoder_causal": True, + "decoder_causal": True, + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + input_dict = {"sample": image} + return input_dict + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_ltx_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "LTX2VideoEncoder3d", + "LTX2VideoDecoder3d", + "LTX2VideoDownBlock3D", + "LTX2VideoMidBlock3d", + "LTX2VideoUpBlock3d", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.") + def test_forward_with_norm_groups(self): + pass From a748975a7c9a658b218694e10df6f9694e48078a Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 19 Dec 2025 07:02:38 +0100 Subject: [PATCH 5/6] Get diffusers implementation on par with official LTX 2.0 video VAE implementation --- scripts/convert_ltx2_to_diffusers.py | 8 +- .../autoencoders/autoencoder_kl_ltx2.py | 276 +++++++++++------- .../test_models_autoencoder_ltx2_video.py | 5 +- 3 files changed, 174 insertions(+), 115 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 85fa169af3..25a04e7893 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -240,7 +240,9 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "patch_size_t": 1, "resnet_norm_eps": 1e-6, "encoder_causal": True, - "decoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", "spatial_compression_ratio": 32, "temporal_compression_ratio": 8, }, @@ -275,7 +277,9 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "patch_size_t": 1, "resnet_norm_eps": 1e-6, "encoder_causal": True, - "decoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", "spatial_compression_ratio": 32, "temporal_compression_ratio": 8, }, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index 755b92c10a..6e7b4d324f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -29,8 +29,8 @@ from ..normalization import RMSNorm from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution -# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoCausalConv3d -class LTXVideoCausalConv3d(nn.Module): +# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime +class LTX2VideoCausalConv3d(nn.Module): def __init__( self, in_channels: int, @@ -39,14 +39,12 @@ class LTXVideoCausalConv3d(nn.Module): stride: Union[int, Tuple[int, int, int]] = 1, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, - padding_mode: str = "zeros", - is_causal: bool = True, + spatial_padding_mode: str = "zeros", ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels - self.is_causal = is_causal self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) @@ -63,13 +61,13 @@ class LTXVideoCausalConv3d(nn.Module): dilation=dilation, groups=groups, padding=padding, - padding_mode=padding_mode, + padding_mode=spatial_padding_mode, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: time_kernel_size = self.kernel_size[0] - if self.is_causal: + if causal: pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) else: @@ -81,7 +79,8 @@ class LTXVideoCausalConv3d(nn.Module): return hidden_states -# Like LTXVideoResnetBlock3d, but uses a normal Conv3d instead of a causal Conv3d for the conv_shortcut +# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding +# mode is configurable class LTX2VideoResnetBlock3d(nn.Module): r""" A 3D ResNet block used in the LTX 2.0 audiovisual model. @@ -111,9 +110,9 @@ class LTX2VideoResnetBlock3d(nn.Module): eps: float = 1e-6, elementwise_affine: bool = False, non_linearity: str = "swish", - is_causal: bool = True, inject_noise: bool = False, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -122,14 +121,20 @@ class LTX2VideoResnetBlock3d(nn.Module): self.nonlinearity = get_activation(non_linearity) self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) - self.conv1 = LTXVideoCausalConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + self.conv1 = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, ) self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) self.dropout = nn.Dropout(dropout) - self.conv2 = LTXVideoCausalConv3d( - in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + self.conv2 = LTX2VideoCausalConv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, ) self.norm3 = None @@ -140,9 +145,6 @@ class LTX2VideoResnetBlock3d(nn.Module): self.conv_shortcut = nn.Conv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1 ) - # self.conv_shortcut = LTXVideoCausalConv3d( - # in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal - # ) self.per_channel_scale1 = None self.per_channel_scale2 = None @@ -155,7 +157,11 @@ class LTX2VideoResnetBlock3d(nn.Module): self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) def forward( - self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None + self, + inputs: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: hidden_states = inputs @@ -168,7 +174,7 @@ class LTX2VideoResnetBlock3d(nn.Module): hidden_states = hidden_states * (1 + scale_1) + shift_1 hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states, causal=causal) if self.per_channel_scale1 is not None: spatial_shape = hidden_states.shape[-2:] @@ -184,7 +190,7 @@ class LTX2VideoResnetBlock3d(nn.Module): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states, causal=causal) if self.per_channel_scale2 is not None: spatial_shape = hidden_states.shape[-2:] @@ -203,15 +209,14 @@ class LTX2VideoResnetBlock3d(nn.Module): return hidden_states -# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoDownsampler3d +# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d class LTXVideoDownsampler3d(nn.Module): def __init__( self, in_channels: int, out_channels: int, stride: Union[int, Tuple[int, int, int]] = 1, - is_causal: bool = True, - padding_mode: str = "zeros", + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -220,16 +225,15 @@ class LTXVideoDownsampler3d(nn.Module): out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) - self.conv = LTXVideoCausalConv3d( + self.conv = LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, - is_causal=is_causal, - padding_mode=padding_mode, + spatial_padding_mode=spatial_padding_mode, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) residual = ( @@ -241,7 +245,7 @@ class LTXVideoDownsampler3d(nn.Module): residual = residual.unflatten(1, (-1, self.group_size)) residual = residual.mean(dim=2) - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states, causal=causal) hidden_states = ( hidden_states.unflatten(4, (-1, self.stride[2])) .unflatten(3, (-1, self.stride[1])) @@ -253,16 +257,15 @@ class LTXVideoDownsampler3d(nn.Module): return hidden_states -# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoUpsampler3d +# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d class LTXVideoUpsampler3d(nn.Module): def __init__( self, in_channels: int, stride: Union[int, Tuple[int, int, int]] = 1, - is_causal: bool = True, residual: bool = False, upscale_factor: int = 1, - padding_mode: str = "zeros", + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -272,16 +275,15 @@ class LTXVideoUpsampler3d(nn.Module): out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor - self.conv = LTXVideoCausalConv3d( + self.conv = LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, - is_causal=is_causal, - padding_mode=padding_mode, + spatial_padding_mode=spatial_padding_mode, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape if self.residual: @@ -293,7 +295,7 @@ class LTXVideoUpsampler3d(nn.Module): residual = residual.repeat(1, repeats, 1, 1, 1) residual = residual[:, :, self.stride[0] - 1 :] - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states, causal=causal) hidden_states = hidden_states.reshape( batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width ) @@ -342,8 +344,8 @@ class LTX2VideoDownBlock3D(nn.Module): resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, - is_causal: bool = True, downsample_type: str = "conv", + spatial_padding_mode: str = "zeros", ): super().__init__() @@ -358,7 +360,7 @@ class LTX2VideoDownBlock3D(nn.Module): dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -369,30 +371,39 @@ class LTX2VideoDownBlock3D(nn.Module): if downsample_type == "conv": self.downsamplers.append( - LTXVideoCausalConv3d( + LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=(2, 2, 2), - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) ) elif downsample_type == "spatial": self.downsamplers.append( LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, ) ) elif downsample_type == "temporal": self.downsamplers.append( LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, ) ) elif downsample_type == "spatiotemporal": self.downsamplers.append( LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, ) ) @@ -403,18 +414,19 @@ class LTX2VideoDownBlock3D(nn.Module): hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: r"""Forward method of the `LTXDownBlock3D` class.""" for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) else: - hidden_states = resnet(hidden_states, temb, generator) + hidden_states = resnet(hidden_states, temb, generator, causal=causal) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, causal=causal) return hidden_states @@ -449,9 +461,9 @@ class LTX2VideoMidBlock3d(nn.Module): dropout: float = 0.0, resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", - is_causal: bool = True, inject_noise: bool = False, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -468,9 +480,9 @@ class LTX2VideoMidBlock3d(nn.Module): dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -482,6 +494,7 @@ class LTX2VideoMidBlock3d(nn.Module): hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: r"""Forward method of the `LTXMidBlock3D` class.""" @@ -497,9 +510,9 @@ class LTX2VideoMidBlock3d(nn.Module): for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) else: - hidden_states = resnet(hidden_states, temb, generator) + hidden_states = resnet(hidden_states, temb, generator, causal=causal) return hidden_states @@ -540,11 +553,11 @@ class LTX2VideoUpBlock3d(nn.Module): resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, - is_causal: bool = True, inject_noise: bool = False, timestep_conditioning: bool = False, upsample_residual: bool = False, upscale_factor: int = 1, + spatial_padding_mode: str = "zeros", ): super().__init__() @@ -562,9 +575,9 @@ class LTX2VideoUpBlock3d(nn.Module): dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) self.upsamplers = None @@ -574,9 +587,9 @@ class LTX2VideoUpBlock3d(nn.Module): LTXVideoUpsampler3d( out_channels * upscale_factor, stride=(2, 2, 2), - is_causal=is_causal, residual=upsample_residual, upscale_factor=upscale_factor, + spatial_padding_mode=spatial_padding_mode, ) ] ) @@ -590,9 +603,9 @@ class LTX2VideoUpBlock3d(nn.Module): dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -604,9 +617,10 @@ class LTX2VideoUpBlock3d(nn.Module): hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: if self.conv_in is not None: - hidden_states = self.conv_in(hidden_states, temb, generator) + hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal) if self.time_embedder is not None: temb = self.time_embedder( @@ -620,13 +634,13 @@ class LTX2VideoUpBlock3d(nn.Module): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, causal=causal) for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) else: - hidden_states = resnet(hidden_states, temb, generator) + hidden_states = resnet(hidden_states, temb, generator, causal=causal) return hidden_states @@ -682,21 +696,23 @@ class LTX2VideoEncoder3d(nn.Module): patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, is_causal: bool = True, + spatial_padding_mode: str = "zeros", ): super().__init__() self.patch_size = patch_size self.patch_size_t = patch_size_t self.in_channels = in_channels * patch_size**2 + self.is_causal = is_causal output_channel = out_channels - self.conv_in = LTXVideoCausalConv3d( + self.conv_in = LTX2VideoCausalConv3d( in_channels=self.in_channels, out_channels=output_channel, kernel_size=3, stride=1, - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) # down blocks @@ -713,8 +729,8 @@ class LTX2VideoEncoder3d(nn.Module): num_layers=layers_per_block[i], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal, downsample_type=downsample_type[i], + spatial_padding_mode=spatial_padding_mode, ) else: raise ValueError(f"Unknown down block type: {down_block_types[i]}") @@ -726,19 +742,23 @@ class LTX2VideoEncoder3d(nn.Module): in_channels=output_channel, num_layers=layers_per_block[-1], resnet_eps=resnet_norm_eps, - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) # out self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() - self.conv_out = LTXVideoCausalConv3d( - in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=out_channels + 1, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, ) self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: r"""The forward method of the `LTXVideoEncoder3d` class.""" p = self.patch_size @@ -748,28 +768,29 @@ class LTX2VideoEncoder3d(nn.Module): post_patch_num_frames = num_frames // p_t post_patch_height = height // p post_patch_width = width // p + causal = causal or self.is_causal hidden_states = hidden_states.reshape( batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p ) # Thanks for driving me insane with the weird patching order :( hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) - hidden_states = self.conv_in(hidden_states) + hidden_states = self.conv_in(hidden_states, causal=causal) if torch.is_grad_enabled() and self.gradient_checkpointing: for down_block in self.down_blocks: - hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states, None, None, causal) - hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, None, None, causal) else: for down_block in self.down_blocks: - hidden_states = down_block(hidden_states) + hidden_states = down_block(hidden_states, causal=causal) - hidden_states = self.mid_block(hidden_states) + hidden_states = self.mid_block(hidden_states, causal=causal) hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) hidden_states = self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) last_channel = hidden_states[:, -1:] last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) @@ -817,17 +838,19 @@ class LTX2VideoDecoder3d(nn.Module): patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, - is_causal: bool = True, + is_causal: bool = False, inject_noise: Tuple[bool, ...] = (False, False, False), timestep_conditioning: bool = False, upsample_residual: Tuple[bool, ...] = (True, True, True), upsample_factor: Tuple[bool, ...] = (2, 2, 2), + spatial_padding_mode: str = "reflect", ) -> None: super().__init__() self.patch_size = patch_size self.patch_size_t = patch_size_t self.out_channels = out_channels * patch_size**2 + self.is_causal = is_causal block_out_channels = tuple(reversed(block_out_channels)) spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) @@ -837,17 +860,21 @@ class LTX2VideoDecoder3d(nn.Module): upsample_factor = tuple(reversed(upsample_factor)) output_channel = block_out_channels[0] - self.conv_in = LTXVideoCausalConv3d( - in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal + self.conv_in = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, ) self.mid_block = LTX2VideoMidBlock3d( in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, - is_causal=is_causal, inject_noise=inject_noise[0], timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) # up blocks @@ -863,11 +890,11 @@ class LTX2VideoDecoder3d(nn.Module): num_layers=layers_per_block[i + 1], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal, inject_noise=inject_noise[i + 1], timestep_conditioning=timestep_conditioning, upsample_residual=upsample_residual[i], upscale_factor=upsample_factor[i], + spatial_padding_mode=spatial_padding_mode, ) self.up_blocks.append(up_block) @@ -875,8 +902,12 @@ class LTX2VideoDecoder3d(nn.Module): # out self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() - self.conv_out = LTXVideoCausalConv3d( - in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, ) # timestep embedding @@ -890,22 +921,26 @@ class LTX2VideoDecoder3d(nn.Module): self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: - hidden_states = self.conv_in(hidden_states) + def forward( + self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, causal: Optional[bool] = None, + ) -> torch.Tensor: + causal = causal or self.is_causal + + hidden_states = self.conv_in(hidden_states, causal=causal) if self.timestep_scale_multiplier is not None: temb = temb * self.timestep_scale_multiplier if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb, None, causal) for up_block in self.up_blocks: - hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb, None, causal) else: - hidden_states = self.mid_block(hidden_states, temb) + hidden_states = self.mid_block(hidden_states, temb, causal=causal) for up_block in self.up_blocks: - hidden_states = up_block(hidden_states, temb) + hidden_states = up_block(hidden_states, temb, causal=causal) hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) @@ -923,7 +958,7 @@ class LTX2VideoDecoder3d(nn.Module): hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) p = self.patch_size p_t = self.patch_size_t @@ -1006,6 +1041,8 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig scaling_factor: float = 1.0, encoder_causal: bool = True, decoder_causal: bool = True, + encoder_spatial_padding_mode: str = "zeros", + decoder_spatial_padding_mode: str = "reflect", spatial_compression_ratio: int = None, temporal_compression_ratio: int = None, ) -> None: @@ -1023,6 +1060,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, is_causal=encoder_causal, + spatial_padding_mode=encoder_spatial_padding_mode, ) self.decoder = LTX2VideoDecoder3d( in_channels=latent_channels, @@ -1038,6 +1076,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig inject_noise=decoder_inject_noise, upsample_residual=upsample_residual, upsample_factor=upsample_factor, + spatial_padding_mode=decoder_spatial_padding_mode, ) latents_mean = torch.zeros((latent_channels,), requires_grad=False) @@ -1120,22 +1159,22 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames - def _encode(self, x: torch.Tensor) -> torch.Tensor: + def _encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = x.shape if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: - return self._temporal_tiled_encode(x) + return self._temporal_tiled_encode(x, causal=causal) if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): - return self.tiled_encode(x) + return self.tiled_encode(x, causal=causal) - enc = self.encoder(x) + enc = self.encoder(x, causal=causal) return enc @apply_forward_hook def encode( - self, x: torch.Tensor, return_dict: bool = True + self, x: torch.Tensor, causal: Optional[bool] = None, return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images into latents. @@ -1150,10 +1189,10 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ if self.use_slicing and x.shape[0] > 1: - encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + encoded_slices = [self._encode(x_slice, causal=causal) for x_slice in x.split(1)] h = torch.cat(encoded_slices) else: - h = self._encode(x) + h = self._encode(x, causal=causal) posterior = DiagonalGaussianDistribution(h) if not return_dict: @@ -1161,7 +1200,11 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig return AutoencoderKLOutput(latent_dist=posterior) def _decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + self, + z: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + return_dict: bool = True, ) -> Union[DecoderOutput, torch.Tensor]: batch_size, num_channels, num_frames, height, width = z.shape tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio @@ -1169,12 +1212,12 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: - return self._temporal_tiled_decode(z, temb, return_dict=return_dict) + return self._temporal_tiled_decode(z, temb, causal=causal, return_dict=return_dict) if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): - return self.tiled_decode(z, temb, return_dict=return_dict) + return self.tiled_decode(z, temb, causal=causal, return_dict=return_dict) - dec = self.decoder(z, temb) + dec = self.decoder(z, temb, causal=causal) if not return_dict: return (dec,) @@ -1183,7 +1226,11 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig @apply_forward_hook def decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + self, + z: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + return_dict: bool = True, ) -> Union[DecoderOutput, torch.Tensor]: """ Decode a batch of images. @@ -1201,13 +1248,13 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig if self.use_slicing and z.shape[0] > 1: if temb is not None: decoded_slices = [ - self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1)) + self._decode(z_slice, t_slice, causal=causal).sample for z_slice, t_slice in (z.split(1), temb.split(1)) ] else: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)] decoded = torch.cat(decoded_slices) else: - decoded = self._decode(z, temb).sample + decoded = self._decode(z, temb, causal=causal).sample if not return_dict: return (decoded,) @@ -1238,7 +1285,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig ) return b - def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + def tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. Args: @@ -1267,7 +1314,8 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig row = [] for j in range(0, width, self.tile_sample_stride_width): time = self.encoder( - x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width], + causal=causal, ) row.append(time) @@ -1290,7 +1338,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig return enc def tiled_decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True ) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images using a tiled decoder. @@ -1324,7 +1372,9 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig for i in range(0, height, tile_latent_stride_height): row = [] for j in range(0, width, tile_latent_stride_width): - time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb) + time = self.decoder( + z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb, causal=causal + ) row.append(time) rows.append(row) @@ -1349,7 +1399,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig return DecoderOutput(sample=dec) - def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + def _temporal_tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> AutoencoderKLOutput: batch_size, num_channels, num_frames, height, width = x.shape latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 @@ -1361,9 +1411,9 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig for i in range(0, num_frames, self.tile_sample_stride_num_frames): tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): - tile = self.tiled_encode(tile) + tile = self.tiled_encode(tile, causal=causal) else: - tile = self.encoder(tile) + tile = self.encoder(tile, causal=causal) if i > 0: tile = tile[:, :, 1:, :, :] row.append(tile) @@ -1380,7 +1430,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig return enc def _temporal_tiled_decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True ) -> Union[DecoderOutput, torch.Tensor]: batch_size, num_channels, num_frames, height, width = z.shape num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 @@ -1395,9 +1445,9 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig for i in range(0, num_frames, tile_latent_stride_num_frames): tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): - decoded = self.tiled_decode(tile, temb, return_dict=True).sample + decoded = self.tiled_decode(tile, temb, causal=causal, return_dict=True).sample else: - decoded = self.decoder(tile, temb) + decoded = self.decoder(tile, temb, causal=causal) if i > 0: decoded = decoded[:, :, :-1, :, :] row.append(decoded) @@ -1422,16 +1472,18 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig sample: torch.Tensor, temb: Optional[torch.Tensor] = None, sample_posterior: bool = False, + encoder_causal: Optional[bool] = None, + decoder_causal: Optional[bool] = None, return_dict: bool = True, generator: Optional[torch.Generator] = None, ) -> Union[torch.Tensor, torch.Tensor]: x = sample - posterior = self.encode(x).latent_dist + posterior = self.encode(x, causal=encoder_causal).latent_dist if sample_posterior: z = posterior.sample(generator=generator) else: z = posterior.mode() - dec = self.decode(z, temb) + dec = self.decode(z, temb, causal=decoder_causal) if not return_dict: return (dec.sample,) return dec diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py index 703ba54f89..25984d621a 100644 --- a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py @@ -55,7 +55,10 @@ class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unit "patch_size": 1, "patch_size_t": 1, "encoder_causal": True, - "decoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + # Full model uses `reflect` but this does not have deterministic backward implementation, so use `zeros` + "decoder_spatial_padding_mode": "zeros", } @property From c6a11a553038e503f5f76f5bb667030a04504277 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 19 Dec 2025 12:17:10 +0100 Subject: [PATCH 6/6] Initial LTX 2.0 vocoder implementation --- scripts/convert_ltx2_to_diffusers.py | 65 ++++++++- src/diffusers/pipelines/ltx2/vocoder.py | 173 ++++++++++++++++++++++++ 2 files changed, 237 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/ltx2/vocoder.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 25a04e7893..f2e879c065 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -10,6 +10,7 @@ from huggingface_hub import hf_hub_download from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel from diffusers.utils.import_utils import is_accelerate_available +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder CTX = init_empty_weights if is_accelerate_available() else nullcontext @@ -61,6 +62,13 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = { "per_channel_statistics.std-of-means": "latents_std", } +LTX_2_0_VOCODER_RENAME_DICT = { + "ups": "upsamplers", + "resblocks": "resnets", + "conv_pre": "conv_in", + "conv_post": "conv_out", +} + def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: state_dict[new_key] = state_dict.pop(old_key) @@ -99,6 +107,8 @@ LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.mean-of-stds": remove_keys_inplace, } +LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} + def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": @@ -315,6 +325,53 @@ def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> return vae +def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "in_channels": 128, + "hidden_channels": 1024, + "out_channels": 2, + "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_factors": [6, 5, 2, 2, 2], + "resnet_kernel_sizes": [3, 7, 11], + "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "leaky_relu_negative_slope": 0.1, + "output_sampling_rate": 24000, + } + } + rename_dict = LTX_2_0_VOCODER_RENAME_DICT + special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vocoder = LTX2Vocoder.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vocoder.load_state_dict(original_state_dict, strict=True, assign=True) + return vocoder + + def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) @@ -468,7 +525,13 @@ def main(args): transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer")) if args.vocoder or args.full_pipeline: - pass + if args.vocoder_filename is not None: + original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename) + elif combined_ckpt is not None: + original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix) + vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version) + if not args.full_pipeline: + vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder")) if args.full_pipeline: pass diff --git a/src/diffusers/pipelines/ltx2/vocoder.py b/src/diffusers/pipelines/ltx2/vocoder.py new file mode 100644 index 0000000000..c3b3c1f367 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/vocoder.py @@ -0,0 +1,173 @@ +import math +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class ResBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int = 3, + stride: int = 1, + dilations: Tuple[int, ...] = (1, 3, 5), + leaky_relu_negative_slope: float = 0.1, + padding_mode: str = "same", + ): + super().__init__() + self.dilations = dilations + self.negative_slope = leaky_relu_negative_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding_mode + ) + for dilation in dilations + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=stride, + dilation=1, + padding=padding_mode + ) + for _ in range(len(dilations)) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, negative_slope=self.negative_slope) + xt = conv1(xt) + xt = F.leaky_relu(xt, negative_slope=self.negative_slope) + xt = conv2(xt) + x = x + xt + return x + + +class LTX2Vocoder(ModelMixin, ConfigMixin): + r""" + LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + hidden_channels: int = 1024, + out_channels: int = 2, + upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4], + upsample_factors: List[int] = [6, 5, 2, 2, 2], + resnet_kernel_sizes: List[int] = [3, 7, 11], + resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + leaky_relu_negative_slope: float = 0.1, + output_sampling_rate: int = 24000, + ): + super().__init__() + self.num_upsample_layers = len(upsample_kernel_sizes) + self.resnets_per_upsample = len(resnet_kernel_sizes) + self.out_channels = out_channels + self.total_upsample_factor = math.prod(upsample_factors) + self.negative_slope = leaky_relu_negative_slope + + if self.num_upsample_layers != len(upsample_factors): + raise ValueError( + f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length" + f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively." + ) + + if self.resnets_per_upsample != len(resnet_dilations): + raise ValueError( + f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length" + f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." + ) + + self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3) + + self.upsamplers = nn.ModuleList() + self.resnets = nn.ModuleList() + input_channels = hidden_channels + for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + output_channels = input_channels // 2 + self.upsamplers.append( + nn.ConvTranspose1d( + input_channels, # hidden_channels // (2 ** i) + output_channels, # hidden_channels // (2 ** (i + 1)) + kernel_size, + stride=stride, + padding=(kernel_size - stride) // 2, + ) + ) + + for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): + self.resnets.append( + ResBlock( + output_channels, + kernel_size, + dilations=dilations, + leaky_relu_negative_slope=leaky_relu_negative_slope, + ) + ) + input_channels = output_channels + + self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) + + def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: + r""" + Forward pass of the vocoder. + + Args: + hidden_states (`torch.Tensor`): + Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last` + is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is + `True`. + time_last (`bool`, *optional*, defaults to `False`): + Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension. + + Returns: + `torch.Tensor`: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + + # Ensure that the time/frame dimension is last + if not time_last: + hidden_states = hidden_states.transpose(2, 3) + # Combine channels and frequency (mel bins) dimensions + hidden_states = hidden_states.flatten(1, 2) + + hidden_states = self.conv_in(hidden_states) + + for i in range(self.num_upsample_layers): + hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) + hidden_states = self.upsamplers[i](hidden_states) + + # Run all resnets in parallel on hidden_states + start = i * self.resnets_per_upsample + end = (i + 1) * self.resnets_per_upsample + resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0) + + hidden_states = torch.mean(resnet_outputs, dim=0) + + # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of + # 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended + hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01) + hidden_states = self.conv_out(hidden_states) + hidden_states = torch.tanh(hidden_states) + + return hidden_states