From aa602ac4831d44a3bab2d7d90f62096e5146ed59 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 12 Dec 2025 07:52:33 +0100 Subject: [PATCH 01/19] Initial LTX 2.0 transformer implementation --- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_ltx2.py | 1206 +++++++++++++++++ 4 files changed, 1211 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_ltx2.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e69d334fdb..97ba02e2d0 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -236,6 +236,7 @@ else: "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", "LTXVideoTransformer3DModel", + "LTX2VideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", "MochiTransformer3DModel", @@ -969,6 +970,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: Kandinsky5Transformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, + LTX2VideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 29d8b0b5a5..b387bd817c 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -102,6 +102,7 @@ if is_torch_available(): _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] + _import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] @@ -209,6 +210,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: Kandinsky5Transformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, + LTX2VideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index a42f6b2716..cc8aff8142 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -34,6 +34,7 @@ if is_torch_available(): from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel + from .transformer_ltx2 import LTX2VideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py new file mode 100644 index 0000000000..57d71a3eb6 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -0,0 +1,1206 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, PixArtAlphaCombinedTimestepSizeEmbeddings +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def apply_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + +@dataclass +class AudioVisualModelOutput(BaseOutput): + r""" + Holds the output of an audiovisual model which produces both visual (e.g. video) and audio outputs. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on the `encoder_hidden_states` input, representing the visual output + of the model. This is typically a video (spatiotemporal) output. + audio_sample (`torch.Tensor` of shape `(batch_size, TODO)`): + The audio output of the audiovisual model. + """ + + sample: "torch.Tensor" # noqa: F821 + audio_sample: "torch.Tensor" # noqa: F821 + + +class LTX2AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted by the LTX-2.0 + model. In particular, the number of modulation parameters to be calculated is now configurable. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_mod_params (`int`, *optional*, defaults to `6`): + The number of modulation parameters which will be calculated in the first return argument. The default of 6 + is standard, but sometimes we may want to have a different (usually smaller) number of modulation + parameters. + use_additional_conditions (`bool`, *optional*, defaults to `False`): + Whether to use additional conditions for normalization or not. 
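+
+    Example (a minimal, illustrative sketch; the embedding size, batch size, and number of modulation parameters
+    below are arbitrary values, not taken from a released checkpoint):
+
+    ```python
+    >>> import torch
+
+    >>> adaln = LTX2AdaLayerNormSingle(embedding_dim=32, num_mod_params=4)
+    >>> timestep = torch.rand(2)
+    >>> mod_params, embedded_timestep = adaln(timestep, batch_size=2, hidden_dtype=torch.float32)
+    >>> # mod_params has shape (2, 4 * 32); embedded_timestep has shape (2, 32)
+    ```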
+ """ + + def __init__(self, embedding_dim: int, num_mod_params: int = 6, use_additional_conditions: bool = False): + super().__init__() + self.num_mod_params = num_mod_params + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, self.num_mod_params * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class LTX2AudioVideoAttnProcessor: + r""" + Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model. + Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can + support audio-to-video (a2v) and video-to-audio (v2a) cross attention. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if is_torch_version("<", "2.0"): + raise ValueError( + "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." + ) + + def __call__( + self, + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if query_rotary_emb is not None: + query = apply_rotary_emb(query, query_rotary_emb) + key = apply_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LTX2Attention(torch.nn.Module, AttentionModuleMixin): + r""" + Attention class for all LTX-2.0 attention layers. 
Compared to LTX-1.0, this supports specifying the query and key + RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention. + """ + + _default_processor_cls = LTX2AudioVideoAttnProcessor + _available_processors = [LTX2AudioVideoAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + kv_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = True, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + qk_norm: str = "rms_norm_across_heads", + norm_eps: float = 1e-6, + norm_elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + if qk_norm != "rms_norm_across_heads": + raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.") + + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = query_dim + self.heads = heads + + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + hidden_states = self.processor( + self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs + ) + return hidden_states + + +@maybe_allow_in_graph +class LTX2VideoTransformerBlock(nn.Module): + r""" + Transformer block used in [LTX-2.0](https://huggingface.co/Lightricks/LTX-Video). + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. 
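+
+    Example (an illustrative sketch; the dimensions below are small, arbitrary values mirroring the reduced
+    unit-test configuration rather than a trained checkpoint):
+
+    ```python
+    >>> block = LTX2VideoTransformerBlock(
+    ...     dim=16,
+    ...     num_attention_heads=2,
+    ...     attention_head_dim=8,
+    ...     cross_attention_dim=16,
+    ...     audio_dim=16,
+    ...     audio_num_attention_heads=2,
+    ...     audio_attention_head_dim=8,
+    ...     audio_cross_attention_dim=16,
+    ... )
+    ```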
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_dim: int, + audio_num_attention_heads: int, + audio_attention_head_dim, + audio_cross_attention_dim: int, + qk_norm: str = "rms_norm_across_heads", + activation_fn: str = "gelu-approximate", + attention_bias: bool = True, + attention_out_bias: bool = True, + eps: float = 1e-6, + elementwise_affine: bool = False, + ): + super().__init__() + + # 1. Self-Attention (video and audio) + self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + self.audio_norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn1 = LTX2Attention( + query_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + # 2. Prompt Cross-Attention + self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn2 = LTX2Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + self.audio_norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn2 = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=audio_cross_attention_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + # Audio-to-Video (a2v) Attention --> Q: Video; K,V: Audio + self.audio_to_video_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_to_video_attn = LTX2Attention( + query_dim=dim, + cross_attention_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video + self.video_to_audio_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.video_to_audio_attn = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + # 4. Feedforward layers + self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + self.audio_norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn) + + # 5. 
Per-Layer Modulation Parameters + # Self-Attention / Feedforward AdaLayerNorm-Zero mod params + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # Per-layer a2v, v2a Cross-Attention mod params + self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) + self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim)) + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + temb_audio: torch.Tensor, + temb_ca_scale_shift: torch.Tensor, + temb_ca_audio_scale_shift: torch.Tensor, + temb_ca_gate: torch.Tensor, + temb_ca_audio_gate: torch.Tensor, + video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ca_video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ca_audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + a2v_cross_attention_mask: Optional[torch.Tensor] = None, + v2a_cross_attention_mask: Optional[torch.Tensor] = None, + use_video_self_attn: bool = True, + use_audio_self_attn: bool = True, + use_a2v_cross_attn: bool = True, + use_v2a_cross_attn: bool = True, + ) -> torch.Tensor: + batch_size = hidden_states.size(0) + + # 1. Video and Audio Self-Attention + if use_video_self_attn: + norm_hidden_states = self.norm1(hidden_states) + + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( + batch_size, temb.size(1), num_ada_params, -1 + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=video_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa + + if use_audio_self_attn: + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) + + num_audio_ada_params = self.audio_scale_shift_table.shape[0] + audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( + batch_size, temb_audio.size(1), num_audio_ada_params, -1 + ) + audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = audio_ada_values.unbind(dim=2) + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa + + attn_audio_hidden_states = self.audio_attn1( + hidden_states=norm_audio_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=audio_rotary_emb, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa + + # 2. 
Video and Audio Cross-Attention with the text embeddings + norm_hidden_states = self.norm2(hidden_states) + attn_hidden_states = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_rotary_emb=None, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + + norm_audio_hidden_states = self.audio_norm2(audio_hidden_states) + attn_audio_hidden_states = self.audio_attn2( + norm_audio_hidden_states, + encoder_hidden_states=audio_encoder_hidden_states, + query_rotary_emb=None, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + + # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + if use_a2v_cross_attn or use_v2a_cross_attn: + norm_hidden_states = self.norm3(hidden_states) + norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) + + # Combine global and per-layer cross attention modulation parameters + # Video + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] + + video_ca_scale_shift_table = ( + video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) + + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + video_ca_gate = ( + video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) + + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) + ).unbind(dim=2) + + video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table + a2v_gate = video_ca_gate[0].squeeze(2) + + # Audio + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + + audio_ca_scale_shift_table = ( + audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype) + + temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + audio_ca_gate = ( + audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) + + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) + ).unbind(dim=2) + + audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table + v2a_gate = audio_ca_gate[0].squeeze(2) + + if use_a2v_cross_attn: + # Audio-to-Video Cross Attention: Q: Video; K,V: Audio + mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_a2v_ca_scale.squeeze(2)) + audio_a2v_ca_shift.squeeze(2) + + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + query_rotary_emb=ca_video_rotary_emb, + key_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) + + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + if use_v2a_cross_attn: + # Video-to-Audio Cross Attention: Q: Audio; K,V: Video + mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_v2a_ca_scale.squeeze(2)) + audio_v2a_ca_shift.squeeze(2) + + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + encoder_hidden_states=mod_norm_hidden_states, + 
query_rotary_emb=ca_audio_rotary_emb, + key_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) + + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + + # 4. Feedforward + norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp + + norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) * (1 + audio_scale_mlp) + audio_shift_mlp + audio_ff_output = self.audio_ff(norm_audio_hidden_states) + audio_hidden_states = audio_hidden_states + audio_ff_output * audio_gate_mlp + + return hidden_states, audio_hidden_states + + +class LTX2AudioVideoRotaryPosEmbed(nn.Module): + """ + Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model. + + Args: + causal_offset (`int`, *optional*, defaults to `1`): + Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where + the VAE treats the very first frame differently), but could also be 0 (for non-causal modeling). + """ + def __init__( + self, + dim: int, + patch_size: int = 1, + patch_size_t: int = 1, + base_num_frames: int = 20, + base_height: int = 2048, + base_width: int = 2048, + sampling_rate: int = 16000, + hop_length: int = 160, + scale_factors: Tuple[int, ...] = (8, 32 ,32), + theta: float = 10000.0, + causal_offset: int = 1, + modality: str = "video", + ) -> None: + super().__init__() + + self.dim = dim + self.patch_size = patch_size + self.patch_size_t = patch_size_t + + self.base_num_frames = base_num_frames + + # Video-specific + self.base_height = base_height + self.base_width = base_width + + # Audio-specific + self.sampling_rate = sampling_rate + self.hop_length = hop_length + + self.scale_factors = scale_factors + self.theta = theta + self.causal_offset = causal_offset + + self.modality = modality + if self.modality not in ["video", "audio"]: + raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.") + + def prepare_video_coords( + self, + batch_size: int, + num_frames: int, + height: int, + width: int, + device: torch.device, + ) -> torch.Tensor: + """ + Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original + pixel space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, + num_patches, 2) where + - axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames) + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + batch_size (`int`): + Batch size of the video latents. + num_frames (`int`): + Number of latent frames in the video latents. + height (`int`): + Latent height of the video latents. + width (`int`): + Latent width of the video latents. + device (`torch.device`): + Device on which to create the video grid. + + Returns: + `torch.Tensor`: + Per-dimension patch boundaries tensor of shape [batch_size, 3, num_patches, 2]. + """ + + # 1. 
Generate grid coordinates for each spatiotemporal dimension (frames, height, width) + # Always compute rope in fp32 + grid_f = torch.arange(start=0, end=num_frames, step=self.patch_size_t, dtype=torch.float32, device=device) + grid_h = torch.arange(start=0, end=height, step=self.patch_size, dtype=torch.float32, device=device) + grid_w = torch.arange(start=0, end=width, step=self.patch_size, dtype=torch.float32, device=device) + # indexing='ij' ensures that the dimensions are kept in order as (frames, height, width) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches + + # 2. Get the patch boundaries with respect to the latent video grid + patch_size = (self.patch_size_t, self.patch_size, self.patch_size) + patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device) + patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) + + # Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension + latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2] + # Reshape to (batch_size, 3, num_patches, 2) + latent_coords = latent_coords.flatten(1, 3) + latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) + + # 3. Calculate the pixel space patch boundaries from the latent boundaries. + scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device) + # Broadcast the VAE scale factors such that they are compatible with latent_coords's shape + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 # This is the (frame, height, width) dim + # Apply per-axis scaling to convert latent coordinates to pixel space coordinates + pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape) + + # As the VAE temporal stride for the first frame is 1 instead of self.vae_scale_factors[0], we need to shift + # and clamp to keep the first-frame timestamps causal and non-negative. + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0) + + return pixel_coords + + def prepare_audio_coords( + self, + batch_size: int, + num_frames: int, + device: torch.device, + shift: int = 0, + ) -> torch.Tensor: + """ + Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent + frame. This will ultimately have shape (batch_size, 3, num_patches, 2) where + - axis 1 (size 1) represents the temporal dimension + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + batch_size (`int`): + Batch size of the audio latents. + num_frames (`int`): + Number of latent frames in the audio latents. + device (`torch.device`): + Device on which to create the audio grid. + shift (`int`, *optional*, defaults to `0`): + Offset on the latent indices. Different shift values correspond to different overlapping windows with + respect to the same underlying latent grid. + + Returns: + `torch.Tensor`: + Per-dimension patch boundaries tensor of shape [batch_size, 1, num_patches, 2]. + """ + + # 1. Generate coordinates in the frame (time) dimension. + # Always compute rope in fp32 + grid_f = torch.arange( + start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device + ) + + # 2. 
Calculate start timstamps in seconds with respect to the original spectrogram grid + audio_scale_factor = self.scale_factors[0] + # Scale back to mel spectrogram space + grid_start_mel = grid_f * audio_scale_factor + # Handle first frame causal offset, ensuring non-negative timestamps + grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0) + # Convert mel bins back into seconds + grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate + + # 3. Calculate start timstamps in seconds with respect to the original spectrogram grid + grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor + grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip(min=0) + grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate + + audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1) # [num_patches, 2] + audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1) # [batch_size, num_patches, 2] + audio_coords = audio_coords.unsqueeze(1) # [batch_size, 1, num_patches, 2] + return audio_coords + + def prepare_coords(self, *args, **kwargs): + if self.modality == "video": + return self.prepare_video_coords(*args, **kwargs) + elif self.modality == "audio": + return self.prepare_audio_coords(*args, **kwargs) + + def forward( + self, + coords: Optional[torch.Tensor] = None, + batch_size: Optional[int] = None, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + fps: float = 25.0, + shift: int = 0, + device: Optional[Union[str, torch.device]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if coords is not None: + device = device or coords.device + batch_size = batch_size or coords.size(0) + else: + device = device or "cpu" + batch_size = batch_size or 1 + + # 1. Calculate the coordinate grid with respect to data space for the given modality (video, audio). + if coords is None and self.modality == "video": + coords = self.prepare_video_coords( + batch_size, + num_frames, + height, + width, + device=device, + ) + # Scale the temporal coordinates by the video FPS + coords[:, 0, ...] = coords[:, 0, ...] / fps + elif coords is None and self.modality == "audio": + coords = self.prepare_audio_coords( + batch_size, + num_frames, + device=device, + shift=shift, + ) + # Number of spatiotemporal dimensions (3 for video, 1 for audio) + num_pos_dims = coords.shape[1] + + # 2. If the coords are patch boundaries [start, end), use the midpoint of these boundaries + if coords.ndim == 4: + coords_start, coords_end = coords.chunk(2, dim=-1) + coords = (coords_start + coords_end) / 2.0 + coords = coords.squeeze(-1) # [B, num_pos_dims, num_patches] + + # 3. Get coordinates as a fraction of the base data shape + if self.modality == "video": + max_positions = (self.base_num_frames, self.base_height, self.base_width) + elif self.modality == "audio": + max_positions = (self.base_num_frames,) + grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device) + # Number of spatiotemporal dimensions (3 for video, 1 for audio) times 2 for cos, sin + num_rope_elems = num_pos_dims * 2 + + # 4. Create a 1D grid of frequencies for RoPE + start = 1.0 + end = self.theta + freqs = self.theta ** torch.linspace( + start=math.log(start, self.theta), + end=math.log(end, self.theta), + steps=self.dim // num_rope_elems, + device=device, + dtype=torch.float32, + ) + freqs = freqs * math.pi / 2.0 + + # 5. 
Tensor-vector outer product between pos ids tensor of shape [B, 3, num_patches] and freqs vector of shape + # self.dim // num_elems + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, 3, num_patches, self.dim // num_elems] + freqs = freqs.transpose(1, 2).flatten(2) # [B, num_patches, self.dim // 2] + # freqs = freqs.transpose(-1, -2).flatten(2) # [B, 3, num_patches * self.dim // num_elems]??? + + # 6. Get real, interleaved (cos, sin) frequencies, padded to self.dim + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + return cos_freqs, sin_freqs + + +@maybe_allow_in_graph +class LTX2VideoTransformer3DModel( + ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin +): + r""" + A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). + + Args: + in_channels (`int`, defaults to `128`): + The number of channels in the input. + out_channels (`int`, defaults to `128`): + The number of channels in the output. + patch_size (`int`, defaults to `1`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the tmeporal patches to use in the patch embedding layer. + num_attention_heads (`int`, defaults to `32`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + cross_attention_dim (`int`, defaults to `2048 `): + The number of channels for cross attention heads. + num_layers (`int`, defaults to `28`): + The number of layers of Transformer blocks to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + qk_norm (`str`, defaults to `"rms_norm_across_heads"`): + The normalization layer to use. 
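+
+    Example (a minimal sketch using the reduced configuration from the accompanying unit tests; all sizes are
+    illustrative and far smaller than a real checkpoint):
+
+    ```python
+    >>> import torch
+
+    >>> model = LTX2VideoTransformer3DModel(
+    ...     in_channels=4,
+    ...     out_channels=4,
+    ...     num_attention_heads=2,
+    ...     attention_head_dim=8,
+    ...     cross_attention_dim=16,
+    ...     audio_in_channels=4,
+    ...     audio_out_channels=4,
+    ...     audio_num_attention_heads=2,
+    ...     audio_attention_head_dim=8,
+    ...     audio_cross_attention_dim=16,
+    ...     num_layers=2,
+    ...     caption_channels=16,
+    ... )
+    >>> batch_size, num_frames, height, width = 2, 2, 16, 16
+    >>> hidden_states = torch.randn(batch_size, num_frames * height * width, 4)  # patchified video latents
+    >>> audio_hidden_states = torch.randn(batch_size, num_frames, 4)  # patchified audio latents
+    >>> encoder_hidden_states = torch.randn(batch_size, 16, 16)  # text embeddings of size caption_channels
+    >>> encoder_attention_mask = torch.ones(batch_size, 16, dtype=torch.bool)
+    >>> timestep = torch.randint(0, 1000, (batch_size,))
+    >>> # forward() additionally expects the latent grid size (num_frames, height, width) and the video fps
+    >>> # so that the rotary positional embeddings can be computed
+    ```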
+ """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["LTXVideoTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + "rope": { + 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + 1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } + + @register_to_config + def __init__( + self, + in_channels: int = 128, # Video Arguments + out_channels: Optional[int] = 128, + patch_size: int = 1, + patch_size_t: int = 1, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + cross_attention_dim: int = 4096, + vae_scale_factors: Tuple[int, int, int] = (8, 32, 32), + pos_embed_max_pos: int = 20, + base_height: int = 2048, + base_width: int = 2048, + audio_in_channels: int = 128, # Audio Arguments + audio_out_channels: Optional[int] = 128, + audio_patch_size: int = 1, + audio_patch_size_t: int = 1, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_cross_attention_dim: int = 2048, + audio_scale_factor: int = 4, + audio_pos_embed_max_pos: int = 20, + audio_sampling_rate: int = 16000, + audio_hop_length: int = 160, + num_layers: int = 48, # Shared arguments + activation_fn: str = "gelu-approximate", + qk_norm: str = "rms_norm_across_heads", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 3840, + attention_bias: bool = True, + attention_out_bias: bool = True, + rope_theta: float = 10000.0, + causal_offset: int = 1, + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + audio_out_channels = audio_out_channels or audio_in_channels + inner_dim = num_attention_heads * attention_head_dim + audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + + # 1. Patchification input projections + self.proj_in = nn.Linear(in_channels, inner_dim) + self.audio_proj_in = nn.Linear(audio_in_channels, inner_dim) + + # 2. Prompt embeddings + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=audio_inner_dim + ) + + # 3. Timestep Modulation Params and Embedding + # 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding + # time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters + self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False) + self.audio_time_embed = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=6, use_additional_conditions=False + ) + + # 3.2. Global Cross Attention Modulation Parameters + # Used in the audio-to-video and video-to-audio cross attention layers as a global set of modulation params, + # which are then further modified by per-block modulaton params in each transformer block. 
+ # There are 2 sets of scale/shift parameters for each modality, 1 each for audio-to-video (a2v) and + # video-to-audio (v2a) cross attention + self.av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=4, use_additional_conditions=False + ) + self.av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=4, use_additional_conditions=False + ) + # Gate param for audio-to-video (a2v) cross attn (where the video is the queries (Q) and the audio is the keys + # and values (KV)) + self.av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=1, use_additional_conditions=False + ) + # Gate param for video-to-audio (v2a) cross attn (where the audio is the queries (Q) and the video is the keys + # and values (KV)) + self.av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=1, use_additional_conditions=False + ) + + # 3.3. Output Layer Scale/Shift Modulation parameters + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5) + + # 4. Rotary Positional Embeddings (RoPE) + # Self-Attention + self.rope = LTX2AudioVideoRotaryPosEmbed( + dim=inner_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + scale_factors=vae_scale_factors, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + ) + self.audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_inner_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=audio_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + scale_factors=[audio_scale_factor], + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + ) + + # Audio-to-Video, Video-to-Audio Cross-Attention + cross_attn_pos_embed_max_pos = max(pos_embed_max_pos, audio_pos_embed_max_pos) + self.cross_attn_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + ) + self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + ) + + # 5. Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + LTX2VideoTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + ) + for _ in range(num_layers) + ] + ) + + # 6. 
Output layers + self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels) + + self.audio_norm_out = nn.LayerNorm(audio_inner_dim, eps=1e-6, elementwise_affine=False) + self.audio_proj_out = nn.Linear(audio_inner_dim, audio_out_channels) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + fps: float = 25.0, + video_coords: Optional[torch.Tensor] = None, + audio_coords: Optional[torch.Tensor] = None, + timestep_scale_multiplier: int = 1000, + cross_attn_timestep_scale_multiplier: int = 1, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> torch.Tensor: + """ + Forward pass for LTX-2.0 audiovisual video transformer. + + Args: + hidden_states (`torch.Tensor`): + Input patchified video latents of shape (batch_size, num_video_tokens, in_channels). + audio_hidden_states (`torch.Tensor`): + Input patchified audio latents of shape (batch_size, num_audio_tokens, audio_in_channels). + encoder_hidden_states (`torch.Tensor`): + Input text embeddings of shape TODO. + timesteps (`torch.Tensor`): + Timestep information of shape (batch_size, num_train_timesteps). + + Returns: + `AudioVisualModelOutput` or `tuple`: + If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a + `tuple` is returned where the first element is the denoised video latent patch sequence and the second + element is the denoised audio latent patch sequence. + """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + + # 1. Prepare RoPE positional embeddings + if video_coords is None: + video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, hidden_states.device) + if audio_coords is None: + audio_coords = self.audio_rope.prepare_audio_coords(batch_size, num_frames, audio_hidden_states.device) + + video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) + audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) + + video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) + audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :], device=audio_hidden_states.device) + + # 2. Patchify input projections + hidden_states = self.proj_in(hidden_states) + audio_hidden_states = self.audio_proj_in(audio_hidden_states) + + # 3. 
Prepare timestep embeddings and modulation parameters + # Scale timestep + timestep = timestep * timestep_scale_multiplier + timestep_cross_attn_gate_scale_factor = cross_attn_timestep_scale_multiplier / timestep_scale_multiplier + + # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters + # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer + # modulation with scale_shift_table (and similarly for audio) + temb, embedded_timestep = self.time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + + temb_audio, audio_embedded_timestep = self.audio_time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + temb_audio = temb.view(batch_size, -1, temb_audio.size(-1)) + audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) + + # 3.2. Prepare global modality cross attention modulation parameters + video_cross_attn_scale_shift = self.av_cross_attn_video_scale_shift( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_a2v_gate = self.av_cross_attn_video_a2v_gate( + timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(batch_size, -1, video_cross_attn_scale_shift.shape[-1]) + video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) + + audio_cross_attn_scale_shift = self.av_cross_attn_audio_scale_shift( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_v2a_gate = self.av_cross_attn_audio_a2v_gate( + timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(batch_size, -1, audio_cross_attn_scale_shift.shape[-1]) + audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) + + # 4. Prepare prompt embeddings + # TODO: does the audio prompt embedding start from the same text embeddings as the video one? + audio_encoder_hidden_states = self.audio_caption_projection(encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + # 5. 
Run transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, audio_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + audio_hidden_states, + encoder_hidden_states, + temb, + temb_audio, + video_cross_attn_scale_shift, + audio_cross_attn_scale_shift, + video_cross_attn_a2v_gate, + audio_cross_attn_v2a_gate, + video_rotary_emb, + audio_rotary_emb, + video_cross_attn_rotary_emb, + audio_cross_attn_rotary_emb, + encoder_attention_mask, + ) + else: + hidden_states, audio_hidden_states = block( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + temb_audio=temb_audio, + temb_ca_scale_shift=video_cross_attn_scale_shift, + temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, + temb_ca_gate=video_cross_attn_a2v_gate, + temb_ca_audio_gate=audio_cross_attn_v2a_gate, + video_rotary_emb=video_rotary_emb, + audio_rotary_emb=audio_rotary_emb, + ca_video_rotary_emb=video_cross_attn_rotary_emb, + ca_audio_rotary_emb=audio_cross_attn_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + ) + + # 6. Output layers (including unpatchification) + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + output = self.proj_out(hidden_states) + + audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None] + audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1] + + audio_hidden_states = self.audio_norm_out(audio_hidden_states) + audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift + audio_output = self.audio_proj_out(audio_hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output, audio_output) + return AudioVisualModelOutput(sample=output, audio_sample=audio_output) From b3096c3c9eaf24a9778e3c30162ad22c6ae4af8f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 13 Dec 2025 04:55:41 +0100 Subject: [PATCH 02/19] Add tests for LTX 2 transformer model --- .../test_models_transformer_ltx2.py | 115 ++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 tests/models/transformers/test_models_transformer_ltx2.py diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py new file mode 100644 index 0000000000..d67789acca --- /dev/null +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import LTX2VideoTransformer3DModel + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + +class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = LTX2VideoTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + # Common + batch_size = 2 + # NOTE: at 25 FPS, using the same num_frames for hidden_states and audio_hidden_states will result in video + # and audio of equal duration + num_frames = 2 + + # Video + num_channels = 4 + height = 16 + width = 16 + + # Audio + audio_num_channels = 2 + num_mel_bins = 2 + + # Text + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) + audio_hidden_states = torch.randn( + (batch_size, num_frames, audio_num_channels * num_mel_bins) + ).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "audio_hidden_states": audio_hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "audio_encoder_hidden_states": audio_encoder_hidden_states, + "timestep": timestep, + "encoder_attention_mask": encoder_attention_mask, + "num_frames": num_frames, + "height": height, + "width": width, + "fps": 25.0, + } + + @property + def input_shape(self): + return (512, 4) + + @property + def output_shape(self): + return (512, 4) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "audio_in_channels": 4, + "audio_out_channels": 4, + "audio_num_attention_heads": 2, + "audio_attention_head_dim": 8, + "audio_cross_attention_dim": 16, + "num_layers": 2, + "qk_norm": "rms_norm_across_heads", + "caption_channels": 16, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"LTX2VideoTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = LTX2VideoTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return LTX2TransformerTests().prepare_init_args_and_inputs_for_common() From 980591de53bac7425234ebe8b295bd0375828ee8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 13 Dec 2025 04:57:23 +0100 Subject: [PATCH 03/19] Get LTX 2 transformer tests working --- .../models/transformers/transformer_ltx2.py | 51 ++++++++++++------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 57d71a3eb6..93b59dec51 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -577,6 +577,7 @@ class 
LTX2AudioVideoRotaryPosEmbed(nn.Module): # Audio-specific self.sampling_rate = sampling_rate self.hop_length = hop_length + self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0]) self.scale_factors = scale_factors self.theta = theta @@ -657,6 +658,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): batch_size: int, num_frames: int, device: torch.device, + fps: float = 25.0, shift: int = 0, ) -> torch.Tensor: """ @@ -682,9 +684,11 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): """ # 1. Generate coordinates in the frame (time) dimension. + audio_duration_s = num_frames / fps + latent_frames = int(audio_duration_s * self.audio_latents_per_second) # Always compute rope in fp32 grid_f = torch.arange( - start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device + start=shift, end=latent_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device ) # 2. Calculate start timstamps in seconds with respect to the original spectrogram grid @@ -748,10 +752,11 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): device=device, shift=shift, ) - # Number of spatiotemporal dimensions (3 for video, 1 for audio) + # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn) num_pos_dims = coords.shape[1] - # 2. If the coords are patch boundaries [start, end), use the midpoint of these boundaries + # 2. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch + # position index if coords.ndim == 4: coords_start, coords_end = coords.chunk(2, dim=-1) coords = (coords_start + coords_end) / 2.0 @@ -762,8 +767,9 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): max_positions = (self.base_num_frames, self.base_height, self.base_width) elif self.modality == "audio": max_positions = (self.base_num_frames,) + # [B, num_pos_dims, num_patches] --> [B, num_patches, num_pos_dims] grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device) - # Number of spatiotemporal dimensions (3 for video, 1 for audio) times 2 for cos, sin + # Number of spatiotemporal dimensions (3 for video, 1 for audio and cross attn) times 2 for cos, sin num_rope_elems = num_pos_dims * 2 # 4. Create a 1D grid of frequencies for RoPE @@ -778,11 +784,10 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): ) freqs = freqs * math.pi / 2.0 - # 5. Tensor-vector outer product between pos ids tensor of shape [B, 3, num_patches] and freqs vector of shape - # self.dim // num_elems - freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, 3, num_patches, self.dim // num_elems] - freqs = freqs.transpose(1, 2).flatten(2) # [B, num_patches, self.dim // 2] - # freqs = freqs.transpose(-1, -2).flatten(2) # [B, 3, num_patches * self.dim // num_elems]??? + # 5. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape + # (self.dim // num_elems,) + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, num_patches, num_pos_dims, self.dim // num_elems] + freqs = freqs.transpose(-1, -2).flatten(2) # [B, num_patches, self.dim // 2] # 6. Get real, interleaved (cos, sin) frequencies, padded to self.dim cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) @@ -888,7 +893,7 @@ class LTX2VideoTransformer3DModel( # 1. Patchification input projections self.proj_in = nn.Linear(in_channels, inner_dim) - self.audio_proj_in = nn.Linear(audio_in_channels, inner_dim) + self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim) # 2. 
Prompt embeddings self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) @@ -990,6 +995,10 @@ class LTX2VideoTransformer3DModel( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, cross_attention_dim=cross_attention_dim, + audio_dim=audio_inner_dim, + audio_num_attention_heads=audio_num_attention_heads, + audio_attention_head_dim=audio_attention_head_dim, + audio_cross_attention_dim=audio_cross_attention_dim, qk_norm=qk_norm, activation_fn=activation_fn, attention_bias=attention_bias, @@ -1015,6 +1024,7 @@ class LTX2VideoTransformer3DModel( hidden_states: torch.Tensor, audio_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, num_frames: Optional[int] = None, @@ -1077,9 +1087,13 @@ class LTX2VideoTransformer3DModel( video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) + print(f"Video RoPE cos shape: {video_rotary_emb[0].shape} | sin shape: {video_rotary_emb[1].shape}") + print(f"Audio RoPE cos shape: {audio_rotary_emb[0].shape} | sin shape: {audio_rotary_emb[1].shape}") video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :], device=audio_hidden_states.device) + print(f"Video CA RoPE cos shape: {video_cross_attn_rotary_emb[0].shape} | sin shape: {video_cross_attn_rotary_emb[1].shape}") + print(f"Audio CA RoPE cos shape: {audio_cross_attn_rotary_emb[0].shape} | sin shape: {audio_cross_attn_rotary_emb[1].shape}") # 2. Patchify input projections hidden_states = self.proj_in(hidden_states) @@ -1110,12 +1124,12 @@ class LTX2VideoTransformer3DModel( audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) # 3.2. Prepare global modality cross attention modulation parameters - video_cross_attn_scale_shift = self.av_cross_attn_video_scale_shift( + video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype, ) - video_cross_attn_a2v_gate = self.av_cross_attn_video_a2v_gate( + video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( timestep.flatten() * timestep_cross_attn_gate_scale_factor, batch_size=batch_size, hidden_dtype=hidden_states.dtype, @@ -1123,12 +1137,12 @@ class LTX2VideoTransformer3DModel( video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(batch_size, -1, video_cross_attn_scale_shift.shape[-1]) video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) - audio_cross_attn_scale_shift = self.av_cross_attn_audio_scale_shift( + audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) - audio_cross_attn_v2a_gate = self.av_cross_attn_audio_a2v_gate( + audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( timestep.flatten() * timestep_cross_attn_gate_scale_factor, batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, @@ -1137,13 +1151,12 @@ class LTX2VideoTransformer3DModel( audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) # 4. 
Prepare prompt embeddings - # TODO: does the audio prompt embedding start from the same text embeddings as the video one? - audio_encoder_hidden_states = self.audio_caption_projection(encoder_hidden_states) - audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) - encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) + # 5. Run transformer blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -1152,6 +1165,7 @@ class LTX2VideoTransformer3DModel( hidden_states, audio_hidden_states, encoder_hidden_states, + audio_encoder_hidden_states, temb, temb_audio, video_cross_attn_scale_shift, @@ -1169,6 +1183,7 @@ class LTX2VideoTransformer3DModel( hidden_states=hidden_states, audio_hidden_states=audio_hidden_states, encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, temb=temb, temb_audio=temb_audio, temb_ca_scale_shift=video_cross_attn_scale_shift, From e100b8f2a3f7d7d88ac0ca6c33a47a0dd215d8f1 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 13 Dec 2025 10:34:11 +0100 Subject: [PATCH 04/19] Rename LTX 2 compile test class to have LTX2 --- tests/models/transformers/test_models_transformer_ltx2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index d67789acca..fc089e6190 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -108,7 +108,7 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): super().test_gradient_checkpointing_is_applied(expected_set=expected_set) -class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): +class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = LTX2VideoTransformer3DModel def prepare_init_args_and_inputs_for_common(self): From 780fb61d32a7a664eec978a7d7c98784394386cc Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 13 Dec 2025 10:37:24 +0100 Subject: [PATCH 05/19] Remove RoPE debug print statements --- src/diffusers/models/transformers/transformer_ltx2.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 93b59dec51..f74f608457 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -1087,13 +1087,9 @@ class LTX2VideoTransformer3DModel( video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) - print(f"Video RoPE cos shape: {video_rotary_emb[0].shape} | sin shape: {video_rotary_emb[1].shape}") - print(f"Audio RoPE cos shape: {audio_rotary_emb[0].shape} | sin shape: {audio_rotary_emb[1].shape}") video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :], device=audio_hidden_states.device) - 
print(f"Video CA RoPE cos shape: {video_cross_attn_rotary_emb[0].shape} | sin shape: {video_cross_attn_rotary_emb[1].shape}") - print(f"Audio CA RoPE cos shape: {audio_cross_attn_rotary_emb[0].shape} | sin shape: {audio_cross_attn_rotary_emb[1].shape}") # 2. Patchify input projections hidden_states = self.proj_in(hidden_states) From 5765759cd33c693de60db2a4805990a12002fd61 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 15 Dec 2025 03:38:34 +0100 Subject: [PATCH 06/19] Get LTX 2 transformer compile tests passing --- src/diffusers/models/transformers/transformer_ltx2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index f74f608457..c3bb7a00a4 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -246,7 +246,6 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin): return hidden_states -@maybe_allow_in_graph class LTX2VideoTransformerBlock(nn.Module): r""" Transformer block used in [LTX-2.0](https://huggingface.co/Lightricks/LTX-Video). @@ -802,7 +801,6 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): return cos_freqs, sin_freqs -@maybe_allow_in_graph class LTX2VideoTransformer3DModel( ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin ): @@ -834,7 +832,7 @@ class LTX2VideoTransformer3DModel( _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] - _repeated_blocks = ["LTXVideoTransformerBlock"] + _repeated_blocks = ["LTX2VideoTransformerBlock"] _cp_plan = { "": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), From aeecc4d7125e111a7dba491036dbf1e26f759ecd Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 15 Dec 2025 06:38:57 +0100 Subject: [PATCH 07/19] Fix LTX 2 transformer shape errors --- .../models/transformers/transformer_ltx2.py | 16 ++++++++-------- .../transformers/test_models_transformer_ltx2.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index c3bb7a00a4..c1ad5f180f 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -297,7 +297,7 @@ class LTX2VideoTransformerBlock(nn.Module): qk_norm=qk_norm, ) - self.audio_norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) self.audio_attn1 = LTX2Attention( query_dim=audio_dim, heads=audio_num_attention_heads, @@ -322,7 +322,7 @@ class LTX2VideoTransformerBlock(nn.Module): qk_norm=qk_norm, ) - self.audio_norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) self.audio_attn2 = LTX2Attention( query_dim=audio_dim, cross_attention_dim=audio_cross_attention_dim, @@ -349,7 +349,7 @@ class LTX2VideoTransformerBlock(nn.Module): ) # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video - self.video_to_audio_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) self.video_to_audio_attn = LTX2Attention( query_dim=audio_dim, cross_attention_dim=dim, @@ -365,13 +365,13 @@ class 
LTX2VideoTransformerBlock(nn.Module): self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) self.ff = FeedForward(dim, activation_fn=activation_fn) - self.audio_norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn) # 5. Per-Layer Modulation Parameters # Self-Attention / Feedforward AdaLayerNorm-Zero mod params self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) - self.audio_scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5) # Per-layer a2v, v2a Cross-Attention mod params self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) @@ -459,8 +459,8 @@ class LTX2VideoTransformerBlock(nn.Module): # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention if use_a2v_cross_attn or use_v2a_cross_attn: - norm_hidden_states = self.norm3(hidden_states) - norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) + norm_hidden_states = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) # Combine global and per-layer cross attention modulation parameters # Video @@ -1114,7 +1114,7 @@ class LTX2VideoTransformer3DModel( batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) - temb_audio = temb.view(batch_size, -1, temb_audio.size(-1)) + temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) # 3.2. Prepare global modality cross attention modulation parameters diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index fc089e6190..c382a63eaa 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -94,8 +94,8 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): "audio_in_channels": 4, "audio_out_channels": 4, "audio_num_attention_heads": 2, - "audio_attention_head_dim": 8, - "audio_cross_attention_dim": 16, + "audio_attention_head_dim": 4, + "audio_cross_attention_dim": 8, "num_layers": 2, "qk_norm": "rms_norm_across_heads", "caption_channels": 16, From a5f2d2da6c4131449a9726e01342b20e12ab2110 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 15 Dec 2025 07:09:42 +0100 Subject: [PATCH 08/19] Initial script to convert LTX 2 transformer to diffusers --- scripts/convert_ltx2_to_diffusers.py | 318 +++++++++++++++++++++++++++ 1 file changed, 318 insertions(+) create mode 100644 scripts/convert_ltx2_to_diffusers.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py new file mode 100644 index 0000000000..286e2aed42 --- /dev/null +++ b/scripts/convert_ltx2_to_diffusers.py @@ -0,0 +1,318 @@ +import argparse +import os +from contextlib import nullcontext +from typing import Any, Dict, Optional, Tuple + +import safetensors.torch +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers import LTX2VideoTransformer3DModel +from diffusers.utils.import_utils import is_accelerate_available + + +CTX = init_empty_weights if is_accelerate_available() else nullcontext + + +LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = { + # Input 
Patchify Projections + "patchify_proj": "proj_in", + "audio_patchify_proj": "audio_proj_in", + # Modulation Parameters + # Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are + # substrings of the other modulation parameters below + "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift", + "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate", + "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift", + "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate", + # Transformer Blocks + # Per-Block Cross Attention Modulatin Parameters + "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table", + "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + + +def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + +def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None: + state_dict.pop(key) + + +def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None: + # Skip if not a weight, bias + if ".weight" not in key and ".bias" not in key: + return + + if key.startswith("adaln_single."): + new_key = key.replace("adaln_single.", "time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + if key.startswith("audio_adaln_single."): + new_key = key.replace("audio_adaln_single.", "audio_time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + +LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = { + "video_embeddings_connector": remove_keys_inplace, + "audio_embeddings_connector": remove_keys_inplace, + "adaln_single": convert_ltx2_transformer_adaln_single, +} + + +def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + # Produces a transformer of the same size as used in test_models_transformer_ltx2.py + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 4, + "out_channels": 4, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "vae_scale_factors": (8, 32 ,32), + "pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "audio_in_channels": 4, + "audio_out_channels": 4, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 2, + "audio_attention_head_dim": 4, + "audio_cross_attention_dim": 8, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "num_layers": 2, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 16, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "causal_offset": 1, + }, + } + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 128, + "cross_attention_dim": 4096, + "vae_scale_factors": (8, 32 ,32), + "pos_embed_max_pos": 20, + "base_height": 2048, + 
"base_width": 2048, + "audio_in_channels": 128, + "audio_out_channels": 128, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 32, + "audio_attention_head_dim": 64, + "audio_cross_attention_dim": 2048, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "num_layers": 48, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 3840, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "causal_offset": 1, + }, + } + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True, assign=True) + return transformer + + +def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + + +def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]: + # Ensure that the key prefix ends with a dot (.) + if not prefix.endswith("."): + prefix = prefix + "." + + model_state_dict = {} + for param_name, param in combined_ckpt.items(): + if param_name.startswith(prefix): + model_state_dict[param_name.replace(prefix, "")] = param + return model_state_dict + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--original_state_dict_repo_id", + default="diffusers-internal-dev/new-ltx-model", + type=str, + help="HF Hub repo id with LTX 2.0 checkpoint", + ) + parser.add_argument( + "--checkpoint_path", + default=None, + type=str, + help="Local checkpoint path for LTX 2.0. 
Will be used if `original_state_dict_repo_id` is not specified.", + ) + parser.add_argument( + "--version", + type=str, + default="2.0", + choices=["test", "2.0"], + help="Version of the LTX 2.0 model", + ) + + parser.add_argument( + "--combined_filename", + default="ltx-av-step-1932500-interleaved-new-vae.safetensors", + type=str, + help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)", + ) + parser.add_argument("--vae_prefix", default="vae.", type=str) + parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str) + parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str) + parser.add_argument("--vocoder_prefix", default="vocoder.", type=str) + + parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set") + parser.add_argument( + "--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set" + ) + parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set") + parser.add_argument( + "--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set" + ) + + parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") + parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") + parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") + parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model") + parser.add_argument( + "--full_pipeline", + action="store_true", + help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)", + ) + + parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +VARIANT_MAPPING = { + "fp32": None, + "fp16": "fp16", + "bf16": "bf16", +} + + +def main(args): + vae_dtype = DTYPE_MAPPING[args.vae_dtype] + audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype] + dit_dtype = DTYPE_MAPPING[args.dit_dtype] + vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype] + + combined_ckpt = None + load_combined_models = any([args.vae, args.audio_vae, args.dit, args.vocoder, args.full_pipeline]) + if args.combined_filename is not None and load_combined_models: + combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename) + + if args.vae or args.full_pipeline: + pass + + if args.audio_vae or args.full_pipeline: + pass + + if args.dit or args.full_pipeline: + if args.dit_filename is not None: + original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename) + elif combined_ckpt is not None: + original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix) + transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version) + if not args.full_pipeline: + 
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer")) + + if args.vocoder or args.full_pipeline: + pass + + if args.full_pipeline: + pass + + +if __name__ == '__main__': + args = get_args() + main(args) From d86f89ddea76952279af1da5ff188562f615325f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 16 Dec 2025 07:58:12 +0100 Subject: [PATCH 09/19] Add more LTX 2 transformer audio arguments --- .../models/transformers/transformer_ltx2.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index c1ad5f180f..2ce6106eec 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -394,6 +394,7 @@ class LTX2VideoTransformerBlock(nn.Module): ca_video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ca_audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, a2v_cross_attention_mask: Optional[torch.Tensor] = None, v2a_cross_attention_mask: Optional[torch.Tensor] = None, use_video_self_attn: bool = True, @@ -453,7 +454,7 @@ class LTX2VideoTransformerBlock(nn.Module): norm_audio_hidden_states, encoder_hidden_states=audio_encoder_hidden_states, query_rotary_emb=None, - attention_mask=encoder_attention_mask, + attention_mask=audio_encoder_attention_mask, ) hidden_states = hidden_states + attn_hidden_states @@ -1024,11 +1025,13 @@ class LTX2VideoTransformer3DModel( encoder_hidden_states: torch.Tensor, audio_encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, - encoder_attention_mask: torch.Tensor, + encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, num_frames: Optional[int] = None, height: Optional[int] = None, width: Optional[int] = None, fps: float = 25.0, + audio_num_frames: Optional[int] = None, video_coords: Optional[torch.Tensor] = None, audio_coords: Optional[torch.Tensor] = None, timestep_scale_multiplier: int = 1000, @@ -1075,13 +1078,17 @@ class LTX2VideoTransformer3DModel( encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2: + audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0 + audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) + batch_size = hidden_states.size(0) # 1. 
Prepare RoPE positional embeddings if video_coords is None: video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, hidden_states.device) if audio_coords is None: - audio_coords = self.audio_rope.prepare_audio_coords(batch_size, num_frames, audio_hidden_states.device) + audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames, audio_hidden_states.device) video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) @@ -1171,6 +1178,7 @@ class LTX2VideoTransformer3DModel( video_cross_attn_rotary_emb, audio_cross_attn_rotary_emb, encoder_attention_mask, + audio_encoder_attention_mask, ) else: hidden_states, audio_hidden_states = block( @@ -1189,6 +1197,7 @@ class LTX2VideoTransformer3DModel( ca_video_rotary_emb=video_cross_attn_rotary_emb, ca_audio_rotary_emb=audio_cross_attn_rotary_emb, encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, ) # 6. Output layers (including unpatchification) From 57a8b9c3300201cc9609b882c3229bce3eb5cfeb Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 16 Dec 2025 10:38:03 +0100 Subject: [PATCH 10/19] Allow LTX 2 transformer to be loaded from local path for conversion --- scripts/convert_ltx2_to_diffusers.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 286e2aed42..312559dbee 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -192,6 +192,26 @@ def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: return original_state_dict +def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]: + if repo_id is None and filename is None: + raise ValueError("Please supply at least one of `repo_id` or `filename`") + + if repo_id is not None: + if filename is None: + raise ValueError("If repo_id is specified, filename must also be specified.") + ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) + else: + ckpt_path = filename + + _, ext = os.path.splitext(ckpt_path) + if ext in [".safetensors", ".sft"]: + state_dict = safetensors.torch.load_file(ckpt_path) + else: + state_dict = torch.load(ckpt_path, map_location="cpu") + + return state_dict + + def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]: # Ensure that the key prefix ends with a dot (.) 
if not prefix.endswith("."): @@ -299,7 +319,7 @@ def main(args): if args.dit or args.full_pipeline: if args.dit_filename is not None: - original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename) + original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename) elif combined_ckpt is not None: original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix) transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version) From a7bc052e899936396dfcd08b0a5a88abe2088b5f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 16 Dec 2025 10:44:02 +0100 Subject: [PATCH 11/19] Improve dummy inputs and add test for LTX 2 transformer consistency --- .../test_models_transformer_ltx2.py | 122 +++++++++++++++++- 1 file changed, 116 insertions(+), 6 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index c382a63eaa..0bf08f161d 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -17,7 +17,7 @@ import unittest import torch -from diffusers import LTX2VideoTransformer3DModel +from diffusers import LTX2VideoTransformer3DModel, attention_backend from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -35,16 +35,15 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): def dummy_input(self): # Common batch_size = 2 - # NOTE: at 25 FPS, using the same num_frames for hidden_states and audio_hidden_states will result in video - # and audio of equal duration - num_frames = 2 # Video + num_frames = 2 num_channels = 4 height = 16 width = 16 # Audio + audio_num_frames = 9 audio_num_channels = 2 num_mel_bins = 2 @@ -54,12 +53,12 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) audio_hidden_states = torch.randn( - (batch_size, num_frames, audio_num_channels * num_mel_bins) + (batch_size, audio_num_frames, audio_num_channels * num_mel_bins) ).to(torch_device) encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) - timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + timestep = torch.rand((batch_size,)).to(torch_device) return { "hidden_states": hidden_states, @@ -71,6 +70,7 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): "num_frames": num_frames, "height": height, "width": width, + "audio_num_frames": audio_num_frames, "fps": 25.0, } @@ -107,6 +107,116 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): expected_set = {"LTX2VideoTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + def test_ltx2_consistency(self, seed=0, dtype=torch.float32): + torch.manual_seed(seed) + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + # Calculate dummy inputs in a custom manner to ensure compatibility with original code + batch_size = 2 + num_frames = 9 + latent_frames = 2 + text_embedding_dim = 16 + text_seq_len = 16 + fps = 25.0 + sampling_rate = 16000.0 + hop_length = 160.0 + + sigma = torch.rand((1,), 
generator=torch.manual_seed(seed), dtype=dtype, device="cpu") + timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device) + + num_channels = 4 + latent_height = 4 + latent_width = 4 + hidden_states = torch.randn( + (batch_size, num_channels, latent_frames, latent_height, latent_width), + generator=torch.manual_seed(seed), + dtype=dtype, + device="cpu", + ) + # Patchify video latents (with patch_size (1, 1, 1)) + hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1) + hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + encoder_hidden_states = torch.randn( + (batch_size, text_seq_len, text_embedding_dim), + generator=torch.manual_seed(seed), + dtype=dtype, + device="cpu", + ) + + audio_num_channels = 2 + num_mel_bins = 2 + latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps)) + audio_hidden_states = torch.randn( + (batch_size, audio_num_channels, latent_length, num_mel_bins), + generator=torch.manual_seed(seed), + dtype=dtype, + device="cpu", + ) + # Patchify audio latents + audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3) + audio_encoder_hidden_states = torch.randn( + (batch_size, text_seq_len, text_embedding_dim), + generator=torch.manual_seed(seed), + dtype=dtype, + device="cpu", + ) + + inputs_dict = { + "hidden_states": hidden_states.to(device=torch_device), + "audio_hidden_states": audio_hidden_states.to(device=torch_device), + "encoder_hidden_states": encoder_hidden_states.to(device=torch_device), + "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device), + "timestep": timestep, + "num_frames": latent_frames, + "height": latent_height, + "width": latent_width, + "audio_num_frames": num_frames, + "fps": 25.0, + } + + model = self.model_class.from_pretrained( + "diffusers-internal-dev/dummy-ltx2", + subfolder="transformer", + device_map="cpu", + ) + # torch.manual_seed(seed) + # model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with attention_backend("native"): + with torch.no_grad(): + output = model(**inputs_dict) + + video_output, audio_output = output.to_tuple() + + self.assertIsNotNone(video_output) + self.assertIsNotNone(audio_output) + + # input & output have to have the same shape + video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels) + self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match") + audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins) + self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match") + + # Check against expected slice + # fmt: off + video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676]) + audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692]) + # fmt: on + + video_output_flat = video_output.cpu().flatten().float() + video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]]) + print(f"Video Expected Slice: {video_expected_slice}") + print(f"Video Generated Slice: {video_generated_slice}") + self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4)) + + audio_output_flat = 
audio_output.cpu().flatten().float() + audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]]) + print(f"Audio Expected Slice: {audio_expected_slice}") + print(f"Audio Generated Slice: {audio_generated_slice}") + self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4)) + class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = LTX2VideoTransformer3DModel From bda3ff13dbc895365fb6b3fcbb800df5f1844ecf Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 16 Dec 2025 10:53:43 +0100 Subject: [PATCH 12/19] Fix LTX 2 transformer bugs so consistency test passes --- .../models/transformers/transformer_ltx2.py | 22 +++++++++++++------ .../test_models_transformer_ltx2.py | 4 ---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 2ce6106eec..ea9bca115e 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -456,7 +456,7 @@ class LTX2VideoTransformerBlock(nn.Module): query_rotary_emb=None, attention_mask=audio_encoder_attention_mask, ) - hidden_states = hidden_states + attn_hidden_states + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention if use_a2v_cross_attn or use_v2a_cross_attn: @@ -557,7 +557,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): base_width: int = 2048, sampling_rate: int = 16000, hop_length: int = 160, - scale_factors: Tuple[int, ...] = (8, 32 ,32), + scale_factors: Tuple[int, ...] = (8, 32, 32), theta: float = 10000.0, causal_offset: int = 1, modality: str = "video", @@ -594,6 +594,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): height: int, width: int, device: torch.device, + fps: float = 25.0, ) -> torch.Tensor: """ Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original @@ -651,6 +652,9 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): # and clamp to keep the first-frame timestamps causal and non-negative. pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0) + # Scale the temporal coordinates by the video FPS + pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps + return pixel_coords def prepare_audio_coords( @@ -742,15 +746,15 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): height, width, device=device, + fps=fps, ) - # Scale the temporal coordinates by the video FPS - coords[:, 0, ...] = coords[:, 0, ...] / fps elif coords is None and self.modality == "audio": coords = self.prepare_audio_coords( batch_size, num_frames, device=device, shift=shift, + fps=fps, ) # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn) num_pos_dims = coords.shape[1] @@ -1086,9 +1090,13 @@ class LTX2VideoTransformer3DModel( # 1. 
Prepare RoPE positional embeddings if video_coords is None: - video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, hidden_states.device) + video_coords = self.rope.prepare_video_coords( + batch_size, num_frames, height, width, hidden_states.device, fps=fps + ) if audio_coords is None: - audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames, audio_hidden_states.device) + audio_coords = self.audio_rope.prepare_audio_coords( + batch_size, audio_num_frames, audio_hidden_states.device, fps=fps + ) video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) @@ -1104,7 +1112,7 @@ class LTX2VideoTransformer3DModel( # Scale timestep timestep = timestep * timestep_scale_multiplier timestep_cross_attn_gate_scale_factor = cross_attn_timestep_scale_multiplier / timestep_scale_multiplier - + # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer # modulation with scale_shift_table (and similarly for audio) diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index 0bf08f161d..6c0b97c589 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -207,14 +207,10 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): video_output_flat = video_output.cpu().flatten().float() video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]]) - print(f"Video Expected Slice: {video_expected_slice}") - print(f"Video Generated Slice: {video_generated_slice}") self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4)) audio_output_flat = audio_output.cpu().flatten().float() audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]]) - print(f"Audio Expected Slice: {audio_expected_slice}") - print(f"Audio Generated Slice: {audio_generated_slice}") self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4)) From 269cf7b40d3b5100637990907627b2254bf1897a Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 Dec 2025 10:51:34 +0100 Subject: [PATCH 13/19] Initial implementation of LTX 2.0 video VAE --- scripts/convert_ltx2_to_diffusers.py | 137 +- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoders/autoencoder_kl_ltx2.py | 1437 +++++++++++++++++ 5 files changed, 1577 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 312559dbee..dfec0262de 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -8,7 +8,7 @@ import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers import LTX2VideoTransformer3DModel +from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel from diffusers.utils.import_utils import is_accelerate_available @@ -35,6 +35,32 @@ LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = { "k_norm": "norm_k", } +LTX_2_0_VIDEO_VAE_RENAME_DICT = { + # Encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": 
"down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.1", + "down_blocks.3": "down_blocks.1.downsamplers.0", + "down_blocks.4": "down_blocks.2", + "down_blocks.5": "down_blocks.2.downsamplers.0", + "down_blocks.6": "down_blocks.3", + "down_blocks.7": "down_blocks.3.downsamplers.0", + "down_blocks.8": "mid_block", + # Decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + # Common + # For all 3D ResNets + "res_blocks": "resnets", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} + def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: state_dict[new_key] = state_dict.pop(old_key) @@ -68,6 +94,11 @@ LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = { "adaln_single": convert_ltx2_transformer_adaln_single, } +LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_inplace, + "per_channel_statistics.mean-of-stds": remove_keys_inplace, +} + def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": @@ -180,6 +211,102 @@ def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) return transformer +def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 2048), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": True, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 2048), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + 
"encoder_causal": True, + "decoder_causal": True, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vae = AutoencoderKLLTX2Video.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) @@ -312,7 +439,13 @@ def main(args): combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename) if args.vae or args.full_pipeline: - pass + if args.vae_filename is not None: + original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) + elif combined_ckpt is not None: + original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) + vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version) + if not args.full_pipeline: + vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) if args.audio_vae or args.full_pipeline: pass diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 97ba02e2d0..71cad3425f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -194,6 +194,7 @@ else: "AutoencoderKLHunyuanVideo", "AutoencoderKLHunyuanVideo15", "AutoencoderKLLTXVideo", + "AutoencoderKLLTX2Video", "AutoencoderKLMagvit", "AutoencoderKLMochi", "AutoencoderKLQwenImage", @@ -928,6 +929,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, + AutoencoderKLLTX2Video, AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index b387bd817c..3f4e49015b 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -41,6 +41,7 @@ if is_torch_available(): _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"] _import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] + _import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_qwenimage"] = 
["AutoencoderKLQwenImage"] @@ -153,6 +154,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, + AutoencoderKLLTX2Video, AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 56df27f93c..ca0cac1a57 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -10,6 +10,7 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo +from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py new file mode 100644 index 0000000000..9f65c9980d --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -0,0 +1,1437 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoCausalConv3d +class LTXVideoCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + padding_mode: str = "zeros", + is_causal: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.is_causal = is_causal + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + + dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + height_pad = self.kernel_size[1] // 2 + width_pad = self.kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + self.kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + padding=padding, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + time_kernel_size = self.kernel_size[0] + + if self.is_causal: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) + else: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +# Like LTXVideoResnetBlock3d, but uses a normal Conv3d instead of a causal Conv3d for the conv_shortcut +class LTX2VideoResnetBlock3d(nn.Module): + r""" + A 3D ResNet block used in the LTX 2.0 audiovisual model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + elementwise_affine (`bool`, defaults to `False`): + Whether to enable elementwise affinity in the normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. 
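+        is_causal (`bool`, defaults to `True`):
+            Whether the temporal convolutions are causal, i.e. each output frame depends only on the current and
+            previous frames.
+        inject_noise (`bool`, defaults to `False`):
+            Whether to add learned, per-channel-scaled spatial noise after each convolution.
+        timestep_conditioning (`bool`, defaults to `False`):
+            Whether to condition the block on timestep embeddings through a learned `scale_shift_table`.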
+ """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + eps: float = 1e-6, + elementwise_affine: bool = False, + non_linearity: str = "swish", + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.conv1 = LTXVideoCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) + + self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.dropout = nn.Dropout(dropout) + self.conv2 = LTXVideoCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) + + self.norm3 = None + self.conv_shortcut = None + if in_channels != out_channels: + self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) + # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d + self.conv_shortcut = nn.Conv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1 + ) + # self.conv_shortcut = LTXVideoCausalConv3d( + # in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal + # ) + + self.per_channel_scale1 = None + self.per_channel_scale2 = None + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + + self.scale_shift_table = None + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) + + def forward( + self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None + ) -> torch.Tensor: + hidden_states = inputs + + # Normalize over the channels dimension (dim 1), which is not the last dim + hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] + shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale_1) + shift_1 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if self.per_channel_scale1 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] + + hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + hidden_states = hidden_states * (1 + scale_2) + shift_2 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.per_channel_scale2 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...] 
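+        # Shortcut branch: when `in_channels != out_channels`, the input is layer-normalized and projected with a
+        # plain (non-causal) 1x1 nn.Conv3d so that it matches the main branch before the residual addition below.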
+ + if self.norm3 is not None: + inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) + + if self.conv_shortcut is not None: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoDownsampler3d +class LTXVideoDownsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + is_causal: bool = True, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels + + out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) + + self.conv = LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) + + residual = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + residual = residual.unflatten(1, (-1, self.group_size)) + residual = residual.mean(dim=2) + + hidden_states = self.conv(hidden_states) + hidden_states = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + hidden_states = hidden_states + residual + + return hidden_states + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoUpsampler3d +class LTXVideoUpsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + is_causal: bool = True, + residual: bool = False, + upscale_factor: int = 1, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.residual = residual + self.upscale_factor = upscale_factor + + out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor + + self.conv = LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.residual: + residual = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor + residual = residual.repeat(1, repeats, 1, 1, 1) + residual = residual[:, :, self.stride[0] - 1 :] + + hidden_states = self.conv(hidden_states) + hidden_states = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + + if self.residual: + 
hidden_states = hidden_states + residual + + return hidden_states + + +# Like LTX 1.0 LTXVideo095DownBlock3D, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoDownBlock3D(nn.Module): + r""" + Down block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + downsample_type: str = "conv", + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList() + + if downsample_type == "conv": + self.downsamplers.append( + LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + is_causal=is_causal, + ) + ) + elif downsample_type == "spatial": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal + ) + ) + elif downsample_type == "temporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal + ) + ) + elif downsample_type == "spatiotemporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal + ) + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d +# Like LTX 1.0 LTXVideoMidBlock3d, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoMidBlock3d(nn.Module): + r""" + A middle block used in the LTXVideo 
model. + + Args: + in_channels (`int`): + Number of input channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + ) -> None: + super().__init__() + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + r"""Forward method of the `LTXMidBlock3D` class.""" + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + return hidden_states + + +# Like LTXVideoUpBlock3d but with no conv_in and the updated LTX2VideoResnetBlock3d +class LTX2VideoUpBlock3d(nn.Module): + r""" + Up block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. 
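+        inject_noise (`bool`, defaults to `False`):
+            Whether the resnet blocks inject learnable per-channel spatial noise after each convolution.
+        timestep_conditioning (`bool`, defaults to `False`):
+            Whether to condition the resnet blocks on timesteps via a learned embedding.
+        upsample_residual (`bool`, defaults to `False`):
+            Whether the upsampler adds a depth-to-space rearrangement of its input as a residual connection.
+        upscale_factor (`int`, defaults to `1`):
+            Channel reduction factor of the upsampler; it expects `out_channels * upscale_factor` input channels.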
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + upsample_residual: bool = False, + upscale_factor: int = 1, + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + self.conv_in = None + if in_channels != out_channels: + self.conv_in = LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + + self.upsamplers = None + if spatio_temporal_scale: + self.upsamplers = nn.ModuleList( + [ + LTXVideoUpsampler3d( + out_channels * upscale_factor, + stride=(2, 2, 2), + is_causal=is_causal, + residual=upsample_residual, + upscale_factor=upscale_factor, + ) + ] + ) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=out_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + if self.conv_in is not None: + hidden_states = self.conv_in(hidden_states, temb, generator) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + return hidden_states + + +# Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is +# different, as is the layers_per_block (the 2.0 VAE is bigger) +class LTXVideoEncoder3d(nn.Module): + r""" + The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + in_channels (`int`, defaults to 3): + Number of input channels. + out_channels (`int`, defaults to 128): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, True)`: + Whether a block should contain spatio-temporal downscaling layers or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`): + The number of layers per block. 
+ downsample_type (`Tuple[str, ...]`, defaults to `("spatial", "temporal", "spatiotemporal", "spatiotemporal")`): + The spatiotemporal downsampling pattern per block. Per-layer values can be + - `"spatial"` (downsample spatial dims by 2x) + - `"temporal"` (downsample temporal dim by 2x) + - `"spatiotemporal"` (downsample both spatial and temporal dims by 2x) + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 128, + block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: Tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True), + layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2), + downsample_type: Tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + ): + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.in_channels = in_channels * patch_size**2 + + output_channel = out_channels + + self.conv_in = LTXVideoCausalConv3d( + in_channels=self.in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + is_causal=is_causal, + ) + + # down blocks + num_block_out_channels = len(block_out_channels) + self.down_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + output_channel = block_out_channels[i] + + if down_block_types[i] == "LTX2VideoDownBlock3D": + down_block = LTX2VideoDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + downsample_type=downsample_type[i], + ) + else: + raise ValueError(f"Unknown down block type: {down_block_types[i]}") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[-1], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + ) + + # out + self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.conv_act = nn.SiLU() + self.conv_out = LTXVideoCausalConv3d( + in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `LTXVideoEncoder3d` class.""" + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = hidden_states.reshape( + batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p + ) + # Thanks for driving me insane with the weird patching order :( + hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) + hidden_states = 
self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + last_channel = hidden_states[:, -1:] + last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) + hidden_states = torch.cat([hidden_states, last_channel], dim=1) + + return hidden_states + + +# Like LTX 1.0 LTXVideoDecoder3d, but has only 3 symmetric up blocks which are causal and residual with upsample_factor 2 +class LTXVideoDecoder3d(nn.Module): + r""" + The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, defaults to 128): + Number of latent channels. + out_channels (`int`, defaults to 3): + Number of output channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal upscaling layers or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `False`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + timestep_conditioning (`bool`, defaults to `False`): + Whether to condition the model on timesteps. + """ + + def __init__( + self, + in_channels: int = 128, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (256, 512, 1024), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True), + layers_per_block: Tuple[int, ...] = (5, 5, 5, 5), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + inject_noise: Tuple[bool, ...] = (False, False, False), + timestep_conditioning: bool = False, + upsample_residual: Tuple[bool, ...] = (True, True, True), + upsample_factor: Tuple[bool, ...] 
= (2, 2, 2), + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.out_channels = out_channels * patch_size**2 + + block_out_channels = tuple(reversed(block_out_channels)) + spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) + layers_per_block = tuple(reversed(layers_per_block)) + inject_noise = tuple(reversed(inject_noise)) + upsample_residual = tuple(reversed(upsample_residual)) + upsample_factor = tuple(reversed(upsample_factor)) + output_channel = block_out_channels[0] + + self.conv_in = LTXVideoCausalConv3d( + in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[0], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + inject_noise=inject_noise[0], + timestep_conditioning=timestep_conditioning, + ) + + # up blocks + num_block_out_channels = len(block_out_channels) + self.up_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel // upsample_factor[i] + output_channel = block_out_channels[i] // upsample_factor[i] + + up_block = LTX2VideoUpBlock3d( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i + 1], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + inject_noise=inject_noise[i + 1], + timestep_conditioning=timestep_conditioning, + upsample_residual=upsample_residual[i], + upscale_factor=upsample_factor[i], + ) + + self.up_blocks.append(up_block) + + # out + self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.conv_act = nn.SiLU() + self.conv_out = LTXVideoCausalConv3d( + in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal + ) + + # timestep embedding + self.time_embedder = None + self.scale_shift_table = None + self.timestep_scale_multiplier = None + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if self.timestep_scale_multiplier is not None: + temb = temb * self.timestep_scale_multiplier + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb) + else: + hidden_states = self.mid_block(hidden_states, temb) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1)) + temb = temb + self.scale_shift_table[None, ..., None, None, None] + shift, scale = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + 
scale) + shift + + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return hidden_states + + +class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [LTX](https://huggingface.co/Lightricks/LTX-Video). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Args: + in_channels (`int`, defaults to `3`): + Number of input channels. + out_channels (`int`, defaults to `3`): + Number of output channels. + latent_channels (`int`, defaults to `128`): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal downscaling or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + scaling_factor (`float`, *optional*, defaults to `1.0`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + encoder_causal (`bool`, defaults to `True`): + Whether the encoder should behave causally (future frames depend only on past frames) or not. + decoder_causal (`bool`, defaults to `False`): + Whether the decoder should behave causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 128, + block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: Tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + decoder_block_out_channels: Tuple[int, ...] = (256, 512, 1024), + layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2), + decoder_layers_per_block: Tuple[int, ...] = (5, 5, 5, 5), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True), + decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True), + decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False), + downsample_type: Tuple[str, ...] 
= ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + upsample_residual: Tuple[bool, ...] = (True, True, True), + upsample_factor: Tuple[int, ...] = (2, 2, 2), + timestep_conditioning: bool = False, + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + scaling_factor: float = 1.0, + encoder_causal: bool = True, + decoder_causal: bool = True, + spatial_compression_ratio: int = None, + temporal_compression_ratio: int = None, + ) -> None: + super().__init__() + + self.encoder = LTXVideoEncoder3d( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=block_out_channels, + down_block_types=down_block_types, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + downsample_type=downsample_type, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=encoder_causal, + ) + self.decoder = LTXVideoDecoder3d( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=decoder_block_out_channels, + spatio_temporal_scaling=decoder_spatio_temporal_scaling, + layers_per_block=decoder_layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=decoder_causal, + timestep_conditioning=timestep_conditioning, + inject_noise=decoder_inject_noise, + upsample_residual=upsample_residual, + upsample_factor=upsample_factor, + ) + + latents_mean = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.ones((latent_channels,), requires_grad=False) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # This can be configured based on the amount of GPU memory available. + # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. + # Setting it to higher values results in higher memory usage. 
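+        # With the 8x temporal compression used by the LTX 2.0 video VAE, 16 sample frames correspond to roughly
+        # 2 latent frames, which is why the two defaults below mirror each other.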
+ self.num_sample_frames_batch_size = 16 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + self.tile_sample_min_num_frames = 16 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + self.tile_sample_stride_num_frames = 8 + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_min_num_frames: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + tile_sample_stride_num_frames: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + enc = self.encoder(x) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
+        """
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+        posterior = DiagonalGaussianDistribution(h)
+
+        if not return_dict:
+            return (posterior,)
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def _decode(
+        self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+
+        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+            return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, temb, return_dict=return_dict)
+
+        dec = self.decoder(z, temb)
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    @apply_forward_hook
+    def decode(
+        self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        """
+        Decode a batch of images.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            temb (`torch.Tensor`, *optional*): Timestep conditioning passed through to the decoder when
+                `timestep_conditioning` is enabled.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        if self.use_slicing and z.shape[0] > 1:
+            if temb is not None:
+                decoded_slices = [
+                    self._decode(z_slice, t_slice).sample for z_slice, t_slice in zip(z.split(1), temb.split(1))
+                ]
+            else:
+                decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z, temb).sample
+
+        if not return_dict:
+            return (decoded,)
+
+        return DecoderOutput(sample=decoded)
+
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+ """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + time = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
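+        # For example, with the default tile sizes (512 min, 448 stride) and a 32x spatially compressed latent,
+        # 16x16 latent tiles are decoded with a stride of 14 latent pixels and the decoded samples are blended
+        # over a 64-pixel overlap in each spatial dimension.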
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile) + else: + tile = self.encoder(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, temb, return_dict=True).sample + else: + decoded = self.decoder(tile, temb) + if i > 0: + decoded = decoded[:, :, :-1, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + tile = tile[:, :, : 
self.tile_sample_stride_num_frames, :, :] + result_row.append(tile) + else: + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + temb: Optional[torch.Tensor] = None, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, temb) + if not return_dict: + return (dec.sample,) + return dec From baf23e2da3f0816d1ebe870ccd66249fa3e5ceaa Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 Dec 2025 11:14:45 +0100 Subject: [PATCH 14/19] Explicitly specify temporal and spatial VAE scale factors when converting --- scripts/convert_ltx2_to_diffusers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index dfec0262de..85fa169af3 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -241,6 +241,8 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "resnet_norm_eps": 1e-6, "encoder_causal": True, "decoder_causal": True, + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, }, } rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT @@ -274,6 +276,8 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "resnet_norm_eps": 1e-6, "encoder_causal": True, "decoder_causal": True, + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, }, } rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT From 5b950d6fefae4035d835e539c7b2676008ba43fc Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 Dec 2025 11:30:15 +0100 Subject: [PATCH 15/19] Add initial LTX 2.0 video VAE tests --- src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index 9f65c9980d..755b92c10a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -633,7 +633,7 @@ class LTX2VideoUpBlock3d(nn.Module): # Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is # different, as is the layers_per_block (the 2.0 VAE is bigger) -class LTXVideoEncoder3d(nn.Module): +class LTX2VideoEncoder3d(nn.Module): r""" The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent representation. @@ -779,7 +779,7 @@ class LTXVideoEncoder3d(nn.Module): # Like LTX 1.0 LTXVideoDecoder3d, but has only 3 symmetric up blocks which are causal and residual with upsample_factor 2 -class LTXVideoDecoder3d(nn.Module): +class LTX2VideoDecoder3d(nn.Module): r""" The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample. 
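A rough sketch of the latent-shape arithmetic implied by the conversion config above (spatial_compression_ratio=32, temporal_compression_ratio=8, latent_channels=128); the concrete input sizes are illustrative assumptions, not values taken from these patches:

frames, height, width = 121, 512, 768  # example pixel-space video with 1 + 8*k frames
latent_frames = (frames - 1) // 8 + 1  # temporal_compression_ratio = 8  -> 16
latent_height, latent_width = height // 32, width // 32  # spatial_compression_ratio = 32 -> 16, 24
latent_channels = 128
print((latent_channels, latent_frames, latent_height, latent_width))  # (128, 16, 16, 24)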
@@ -1011,7 +1011,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig ) -> None: super().__init__() - self.encoder = LTXVideoEncoder3d( + self.encoder = LTX2VideoEncoder3d( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, @@ -1024,7 +1024,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig resnet_norm_eps=resnet_norm_eps, is_causal=encoder_causal, ) - self.decoder = LTXVideoDecoder3d( + self.decoder = LTX2VideoDecoder3d( in_channels=latent_channels, out_channels=out_channels, block_out_channels=decoder_block_out_channels, From 491aae08d84d66a3db73f2fdeca96f109f28c4a7 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 Dec 2025 11:39:09 +0100 Subject: [PATCH 16/19] Add initial LTX 2.0 video VAE tests (part 2) --- .../test_models_autoencoder_ltx2_video.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/models/autoencoders/test_models_autoencoder_ltx2_video.py diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py new file mode 100644 index 0000000000..703ba54f89 --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py @@ -0,0 +1,102 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
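+
+# Note: the tiny config below (patch_size 1, four down blocks using the spatial/temporal/spatiotemporal/
+# spatiotemporal downsampling pattern) compresses the dummy inputs by roughly 8x spatially and 8x temporally,
+# so a 9x16x16 video should encode to latents of about 2 frames at 2x2 resolution.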
+ +import unittest + +import torch + +from diffusers import AutoencoderKLLTX2Video + +from ...testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): + model_class = AutoencoderKLLTX2Video + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_ltx_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 8, + "block_out_channels": (8, 8, 8, 8), + "decoder_block_out_channels": (16, 32, 64), + "layers_per_block": (1, 1, 1, 1, 1), + "decoder_layers_per_block": (1, 1, 1, 1), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 1, + "patch_size_t": 1, + "encoder_causal": True, + "decoder_causal": True, + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + input_dict = {"sample": image} + return input_dict + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_ltx_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "LTX2VideoEncoder3d", + "LTX2VideoDecoder3d", + "LTX2VideoDownBlock3D", + "LTX2VideoMidBlock3d", + "LTX2VideoUpBlock3d", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.") + def test_forward_with_norm_groups(self): + pass From a748975a7c9a658b218694e10df6f9694e48078a Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 19 Dec 2025 07:02:38 +0100 Subject: [PATCH 17/19] Get diffusers implementation on par with official LTX 2.0 video VAE implementation --- scripts/convert_ltx2_to_diffusers.py | 8 +- .../autoencoders/autoencoder_kl_ltx2.py | 276 +++++++++++------- .../test_models_autoencoder_ltx2_video.py | 5 +- 3 files changed, 174 insertions(+), 115 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 85fa169af3..25a04e7893 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -240,7 +240,9 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "patch_size_t": 1, "resnet_norm_eps": 1e-6, "encoder_causal": True, - "decoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", "spatial_compression_ratio": 32, "temporal_compression_ratio": 8, }, @@ -275,7 +277,9 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "patch_size_t": 1, "resnet_norm_eps": 1e-6, "encoder_causal": True, - 
"decoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", "spatial_compression_ratio": 32, "temporal_compression_ratio": 8, }, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index 755b92c10a..6e7b4d324f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -29,8 +29,8 @@ from ..normalization import RMSNorm from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution -# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoCausalConv3d -class LTXVideoCausalConv3d(nn.Module): +# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime +class LTX2VideoCausalConv3d(nn.Module): def __init__( self, in_channels: int, @@ -39,14 +39,12 @@ class LTXVideoCausalConv3d(nn.Module): stride: Union[int, Tuple[int, int, int]] = 1, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, - padding_mode: str = "zeros", - is_causal: bool = True, + spatial_padding_mode: str = "zeros", ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels - self.is_causal = is_causal self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) @@ -63,13 +61,13 @@ class LTXVideoCausalConv3d(nn.Module): dilation=dilation, groups=groups, padding=padding, - padding_mode=padding_mode, + padding_mode=spatial_padding_mode, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: time_kernel_size = self.kernel_size[0] - if self.is_causal: + if causal: pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) else: @@ -81,7 +79,8 @@ class LTXVideoCausalConv3d(nn.Module): return hidden_states -# Like LTXVideoResnetBlock3d, but uses a normal Conv3d instead of a causal Conv3d for the conv_shortcut +# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding +# mode is configurable class LTX2VideoResnetBlock3d(nn.Module): r""" A 3D ResNet block used in the LTX 2.0 audiovisual model. 
@@ -111,9 +110,9 @@ class LTX2VideoResnetBlock3d(nn.Module): eps: float = 1e-6, elementwise_affine: bool = False, non_linearity: str = "swish", - is_causal: bool = True, inject_noise: bool = False, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -122,14 +121,20 @@ class LTX2VideoResnetBlock3d(nn.Module): self.nonlinearity = get_activation(non_linearity) self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) - self.conv1 = LTXVideoCausalConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + self.conv1 = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, ) self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) self.dropout = nn.Dropout(dropout) - self.conv2 = LTXVideoCausalConv3d( - in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + self.conv2 = LTX2VideoCausalConv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, ) self.norm3 = None @@ -140,9 +145,6 @@ class LTX2VideoResnetBlock3d(nn.Module): self.conv_shortcut = nn.Conv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1 ) - # self.conv_shortcut = LTXVideoCausalConv3d( - # in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal - # ) self.per_channel_scale1 = None self.per_channel_scale2 = None @@ -155,7 +157,11 @@ class LTX2VideoResnetBlock3d(nn.Module): self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) def forward( - self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None + self, + inputs: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: hidden_states = inputs @@ -168,7 +174,7 @@ class LTX2VideoResnetBlock3d(nn.Module): hidden_states = hidden_states * (1 + scale_1) + shift_1 hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states, causal=causal) if self.per_channel_scale1 is not None: spatial_shape = hidden_states.shape[-2:] @@ -184,7 +190,7 @@ class LTX2VideoResnetBlock3d(nn.Module): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states, causal=causal) if self.per_channel_scale2 is not None: spatial_shape = hidden_states.shape[-2:] @@ -203,15 +209,14 @@ class LTX2VideoResnetBlock3d(nn.Module): return hidden_states -# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoDownsampler3d +# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d class LTXVideoDownsampler3d(nn.Module): def __init__( self, in_channels: int, out_channels: int, stride: Union[int, Tuple[int, int, int]] = 1, - is_causal: bool = True, - padding_mode: str = "zeros", + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -220,16 +225,15 @@ class LTXVideoDownsampler3d(nn.Module): out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) - self.conv = LTXVideoCausalConv3d( + self.conv = LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, - 
is_causal=is_causal, - padding_mode=padding_mode, + spatial_padding_mode=spatial_padding_mode, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) residual = ( @@ -241,7 +245,7 @@ class LTXVideoDownsampler3d(nn.Module): residual = residual.unflatten(1, (-1, self.group_size)) residual = residual.mean(dim=2) - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states, causal=causal) hidden_states = ( hidden_states.unflatten(4, (-1, self.stride[2])) .unflatten(3, (-1, self.stride[1])) @@ -253,16 +257,15 @@ class LTXVideoDownsampler3d(nn.Module): return hidden_states -# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoUpsampler3d +# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d class LTXVideoUpsampler3d(nn.Module): def __init__( self, in_channels: int, stride: Union[int, Tuple[int, int, int]] = 1, - is_causal: bool = True, residual: bool = False, upscale_factor: int = 1, - padding_mode: str = "zeros", + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -272,16 +275,15 @@ class LTXVideoUpsampler3d(nn.Module): out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor - self.conv = LTXVideoCausalConv3d( + self.conv = LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, - is_causal=is_causal, - padding_mode=padding_mode, + spatial_padding_mode=spatial_padding_mode, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape if self.residual: @@ -293,7 +295,7 @@ class LTXVideoUpsampler3d(nn.Module): residual = residual.repeat(1, repeats, 1, 1, 1) residual = residual[:, :, self.stride[0] - 1 :] - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states, causal=causal) hidden_states = hidden_states.reshape( batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width ) @@ -342,8 +344,8 @@ class LTX2VideoDownBlock3D(nn.Module): resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, - is_causal: bool = True, downsample_type: str = "conv", + spatial_padding_mode: str = "zeros", ): super().__init__() @@ -358,7 +360,7 @@ class LTX2VideoDownBlock3D(nn.Module): dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -369,30 +371,39 @@ class LTX2VideoDownBlock3D(nn.Module): if downsample_type == "conv": self.downsamplers.append( - LTXVideoCausalConv3d( + LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=(2, 2, 2), - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) ) elif downsample_type == "spatial": self.downsamplers.append( LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, ) ) elif downsample_type == "temporal": self.downsamplers.append( LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), 
is_causal=is_causal + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, ) ) elif downsample_type == "spatiotemporal": self.downsamplers.append( LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, ) ) @@ -403,18 +414,19 @@ class LTX2VideoDownBlock3D(nn.Module): hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: r"""Forward method of the `LTXDownBlock3D` class.""" for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) else: - hidden_states = resnet(hidden_states, temb, generator) + hidden_states = resnet(hidden_states, temb, generator, causal=causal) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, causal=causal) return hidden_states @@ -449,9 +461,9 @@ class LTX2VideoMidBlock3d(nn.Module): dropout: float = 0.0, resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", - is_causal: bool = True, inject_noise: bool = False, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -468,9 +480,9 @@ class LTX2VideoMidBlock3d(nn.Module): dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -482,6 +494,7 @@ class LTX2VideoMidBlock3d(nn.Module): hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: r"""Forward method of the `LTXMidBlock3D` class.""" @@ -497,9 +510,9 @@ class LTX2VideoMidBlock3d(nn.Module): for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) else: - hidden_states = resnet(hidden_states, temb, generator) + hidden_states = resnet(hidden_states, temb, generator, causal=causal) return hidden_states @@ -540,11 +553,11 @@ class LTX2VideoUpBlock3d(nn.Module): resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, - is_causal: bool = True, inject_noise: bool = False, timestep_conditioning: bool = False, upsample_residual: bool = False, upscale_factor: int = 1, + spatial_padding_mode: str = "zeros", ): super().__init__() @@ -562,9 +575,9 @@ class LTX2VideoUpBlock3d(nn.Module): dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) self.upsamplers = None @@ -574,9 +587,9 @@ class LTX2VideoUpBlock3d(nn.Module): LTXVideoUpsampler3d( out_channels * upscale_factor, stride=(2, 2, 2), - is_causal=is_causal, residual=upsample_residual, 
upscale_factor=upscale_factor, + spatial_padding_mode=spatial_padding_mode, ) ] ) @@ -590,9 +603,9 @@ class LTX2VideoUpBlock3d(nn.Module): dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -604,9 +617,10 @@ class LTX2VideoUpBlock3d(nn.Module): hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: if self.conv_in is not None: - hidden_states = self.conv_in(hidden_states, temb, generator) + hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal) if self.time_embedder is not None: temb = self.time_embedder( @@ -620,13 +634,13 @@ class LTX2VideoUpBlock3d(nn.Module): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, causal=causal) for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) else: - hidden_states = resnet(hidden_states, temb, generator) + hidden_states = resnet(hidden_states, temb, generator, causal=causal) return hidden_states @@ -682,21 +696,23 @@ class LTX2VideoEncoder3d(nn.Module): patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, is_causal: bool = True, + spatial_padding_mode: str = "zeros", ): super().__init__() self.patch_size = patch_size self.patch_size_t = patch_size_t self.in_channels = in_channels * patch_size**2 + self.is_causal = is_causal output_channel = out_channels - self.conv_in = LTXVideoCausalConv3d( + self.conv_in = LTX2VideoCausalConv3d( in_channels=self.in_channels, out_channels=output_channel, kernel_size=3, stride=1, - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) # down blocks @@ -713,8 +729,8 @@ class LTX2VideoEncoder3d(nn.Module): num_layers=layers_per_block[i], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal, downsample_type=downsample_type[i], + spatial_padding_mode=spatial_padding_mode, ) else: raise ValueError(f"Unknown down block type: {down_block_types[i]}") @@ -726,19 +742,23 @@ class LTX2VideoEncoder3d(nn.Module): in_channels=output_channel, num_layers=layers_per_block[-1], resnet_eps=resnet_norm_eps, - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) # out self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() - self.conv_out = LTXVideoCausalConv3d( - in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=out_channels + 1, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, ) self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: r"""The forward method of the `LTXVideoEncoder3d` class.""" p = self.patch_size @@ -748,28 +768,29 @@ class LTX2VideoEncoder3d(nn.Module): post_patch_num_frames = num_frames // p_t post_patch_height = height // p post_patch_width = width // p 
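+        # A falsy `causal` (None or False) falls back to the encoder's configured `is_causal` default.
+        # The reshape/permute below folds the (p_t, p, p) patch factors into the channel dimension before `conv_in`.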
+ causal = causal or self.is_causal hidden_states = hidden_states.reshape( batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p ) # Thanks for driving me insane with the weird patching order :( hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) - hidden_states = self.conv_in(hidden_states) + hidden_states = self.conv_in(hidden_states, causal=causal) if torch.is_grad_enabled() and self.gradient_checkpointing: for down_block in self.down_blocks: - hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states, None, None, causal) - hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, None, None, causal) else: for down_block in self.down_blocks: - hidden_states = down_block(hidden_states) + hidden_states = down_block(hidden_states, causal=causal) - hidden_states = self.mid_block(hidden_states) + hidden_states = self.mid_block(hidden_states, causal=causal) hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) hidden_states = self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) last_channel = hidden_states[:, -1:] last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) @@ -817,17 +838,19 @@ class LTX2VideoDecoder3d(nn.Module): patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, - is_causal: bool = True, + is_causal: bool = False, inject_noise: Tuple[bool, ...] = (False, False, False), timestep_conditioning: bool = False, upsample_residual: Tuple[bool, ...] = (True, True, True), upsample_factor: Tuple[bool, ...] 
= (2, 2, 2), + spatial_padding_mode: str = "reflect", ) -> None: super().__init__() self.patch_size = patch_size self.patch_size_t = patch_size_t self.out_channels = out_channels * patch_size**2 + self.is_causal = is_causal block_out_channels = tuple(reversed(block_out_channels)) spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) @@ -837,17 +860,21 @@ class LTX2VideoDecoder3d(nn.Module): upsample_factor = tuple(reversed(upsample_factor)) output_channel = block_out_channels[0] - self.conv_in = LTXVideoCausalConv3d( - in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal + self.conv_in = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, ) self.mid_block = LTX2VideoMidBlock3d( in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, - is_causal=is_causal, inject_noise=inject_noise[0], timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) # up blocks @@ -863,11 +890,11 @@ class LTX2VideoDecoder3d(nn.Module): num_layers=layers_per_block[i + 1], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal, inject_noise=inject_noise[i + 1], timestep_conditioning=timestep_conditioning, upsample_residual=upsample_residual[i], upscale_factor=upsample_factor[i], + spatial_padding_mode=spatial_padding_mode, ) self.up_blocks.append(up_block) @@ -875,8 +902,12 @@ class LTX2VideoDecoder3d(nn.Module): # out self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() - self.conv_out = LTXVideoCausalConv3d( - in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, ) # timestep embedding @@ -890,22 +921,26 @@ class LTX2VideoDecoder3d(nn.Module): self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: - hidden_states = self.conv_in(hidden_states) + def forward( + self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, causal: Optional[bool] = None, + ) -> torch.Tensor: + causal = causal or self.is_causal + + hidden_states = self.conv_in(hidden_states, causal=causal) if self.timestep_scale_multiplier is not None: temb = temb * self.timestep_scale_multiplier if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb, None, causal) for up_block in self.up_blocks: - hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb, None, causal) else: - hidden_states = self.mid_block(hidden_states, temb) + hidden_states = self.mid_block(hidden_states, temb, causal=causal) for up_block in self.up_blocks: - hidden_states = up_block(hidden_states, temb) + hidden_states = up_block(hidden_states, temb, causal=causal) hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) @@ -923,7 +958,7 @@ class LTX2VideoDecoder3d(nn.Module): hidden_states = hidden_states * (1 + scale) + shift hidden_states = 
self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) p = self.patch_size p_t = self.patch_size_t @@ -1006,6 +1041,8 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig scaling_factor: float = 1.0, encoder_causal: bool = True, decoder_causal: bool = True, + encoder_spatial_padding_mode: str = "zeros", + decoder_spatial_padding_mode: str = "reflect", spatial_compression_ratio: int = None, temporal_compression_ratio: int = None, ) -> None: @@ -1023,6 +1060,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, is_causal=encoder_causal, + spatial_padding_mode=encoder_spatial_padding_mode, ) self.decoder = LTX2VideoDecoder3d( in_channels=latent_channels, @@ -1038,6 +1076,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig inject_noise=decoder_inject_noise, upsample_residual=upsample_residual, upsample_factor=upsample_factor, + spatial_padding_mode=decoder_spatial_padding_mode, ) latents_mean = torch.zeros((latent_channels,), requires_grad=False) @@ -1120,22 +1159,22 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames - def _encode(self, x: torch.Tensor) -> torch.Tensor: + def _encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = x.shape if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: - return self._temporal_tiled_encode(x) + return self._temporal_tiled_encode(x, causal=causal) if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): - return self.tiled_encode(x) + return self.tiled_encode(x, causal=causal) - enc = self.encoder(x) + enc = self.encoder(x, causal=causal) return enc @apply_forward_hook def encode( - self, x: torch.Tensor, return_dict: bool = True + self, x: torch.Tensor, causal: Optional[bool] = None, return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images into latents. @@ -1150,10 +1189,10 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
""" if self.use_slicing and x.shape[0] > 1: - encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + encoded_slices = [self._encode(x_slice, causal=causal) for x_slice in x.split(1)] h = torch.cat(encoded_slices) else: - h = self._encode(x) + h = self._encode(x, causal=causal) posterior = DiagonalGaussianDistribution(h) if not return_dict: @@ -1161,7 +1200,11 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig return AutoencoderKLOutput(latent_dist=posterior) def _decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + self, + z: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + return_dict: bool = True, ) -> Union[DecoderOutput, torch.Tensor]: batch_size, num_channels, num_frames, height, width = z.shape tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio @@ -1169,12 +1212,12 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: - return self._temporal_tiled_decode(z, temb, return_dict=return_dict) + return self._temporal_tiled_decode(z, temb, causal=causal, return_dict=return_dict) if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): - return self.tiled_decode(z, temb, return_dict=return_dict) + return self.tiled_decode(z, temb, causal=causal, return_dict=return_dict) - dec = self.decoder(z, temb) + dec = self.decoder(z, temb, causal=causal) if not return_dict: return (dec,) @@ -1183,7 +1226,11 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig @apply_forward_hook def decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + self, + z: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + return_dict: bool = True, ) -> Union[DecoderOutput, torch.Tensor]: """ Decode a batch of images. @@ -1201,13 +1248,13 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig if self.use_slicing and z.shape[0] > 1: if temb is not None: decoded_slices = [ - self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1)) + self._decode(z_slice, t_slice, causal=causal).sample for z_slice, t_slice in (z.split(1), temb.split(1)) ] else: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)] decoded = torch.cat(decoded_slices) else: - decoded = self._decode(z, temb).sample + decoded = self._decode(z, temb, causal=causal).sample if not return_dict: return (decoded,) @@ -1238,7 +1285,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig ) return b - def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + def tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. 
Args: @@ -1267,7 +1314,8 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig row = [] for j in range(0, width, self.tile_sample_stride_width): time = self.encoder( - x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width], + causal=causal, ) row.append(time) @@ -1290,7 +1338,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig return enc def tiled_decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True ) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images using a tiled decoder. @@ -1324,7 +1372,9 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig for i in range(0, height, tile_latent_stride_height): row = [] for j in range(0, width, tile_latent_stride_width): - time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb) + time = self.decoder( + z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb, causal=causal + ) row.append(time) rows.append(row) @@ -1349,7 +1399,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig return DecoderOutput(sample=dec) - def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + def _temporal_tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> AutoencoderKLOutput: batch_size, num_channels, num_frames, height, width = x.shape latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 @@ -1361,9 +1411,9 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig for i in range(0, num_frames, self.tile_sample_stride_num_frames): tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): - tile = self.tiled_encode(tile) + tile = self.tiled_encode(tile, causal=causal) else: - tile = self.encoder(tile) + tile = self.encoder(tile, causal=causal) if i > 0: tile = tile[:, :, 1:, :, :] row.append(tile) @@ -1380,7 +1430,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig return enc def _temporal_tiled_decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True ) -> Union[DecoderOutput, torch.Tensor]: batch_size, num_channels, num_frames, height, width = z.shape num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 @@ -1395,9 +1445,9 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig for i in range(0, num_frames, tile_latent_stride_num_frames): tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): - decoded = self.tiled_decode(tile, temb, return_dict=True).sample + decoded = self.tiled_decode(tile, temb, causal=causal, return_dict=True).sample else: - decoded = self.decoder(tile, temb) + decoded = self.decoder(tile, temb, causal=causal) if i > 0: decoded = decoded[:, :, :-1, :, :] row.append(decoded) @@ -1422,16 +1472,18 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, 
ConfigMixin, FromOrig sample: torch.Tensor, temb: Optional[torch.Tensor] = None, sample_posterior: bool = False, + encoder_causal: Optional[bool] = None, + decoder_causal: Optional[bool] = None, return_dict: bool = True, generator: Optional[torch.Generator] = None, ) -> Union[torch.Tensor, torch.Tensor]: x = sample - posterior = self.encode(x).latent_dist + posterior = self.encode(x, causal=encoder_causal).latent_dist if sample_posterior: z = posterior.sample(generator=generator) else: z = posterior.mode() - dec = self.decode(z, temb) + dec = self.decode(z, temb, causal=decoder_causal) if not return_dict: return (dec.sample,) return dec diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py index 703ba54f89..25984d621a 100644 --- a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py @@ -55,7 +55,10 @@ class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unit "patch_size": 1, "patch_size_t": 1, "encoder_causal": True, - "decoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + # Full model uses `reflect` but this does not have deterministic backward implementation, so use `zeros` + "decoder_spatial_padding_mode": "zeros", } @property From c6a11a553038e503f5f76f5bb667030a04504277 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 19 Dec 2025 12:17:10 +0100 Subject: [PATCH 18/19] Initial LTX 2.0 vocoder implementation --- scripts/convert_ltx2_to_diffusers.py | 65 ++++++++- src/diffusers/pipelines/ltx2/vocoder.py | 173 ++++++++++++++++++++++++ 2 files changed, 237 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/ltx2/vocoder.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 25a04e7893..f2e879c065 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -10,6 +10,7 @@ from huggingface_hub import hf_hub_download from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel from diffusers.utils.import_utils import is_accelerate_available +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder CTX = init_empty_weights if is_accelerate_available() else nullcontext @@ -61,6 +62,13 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = { "per_channel_statistics.std-of-means": "latents_std", } +LTX_2_0_VOCODER_RENAME_DICT = { + "ups": "upsamplers", + "resblocks": "resnets", + "conv_pre": "conv_in", + "conv_post": "conv_out", +} + def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: state_dict[new_key] = state_dict.pop(old_key) @@ -99,6 +107,8 @@ LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.mean-of-stds": remove_keys_inplace, } +LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} + def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": @@ -315,6 +325,53 @@ def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> return vae +def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "in_channels": 128, + "hidden_channels": 1024, + "out_channels": 2, + "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_factors": [6, 5, 2, 2, 2], + "resnet_kernel_sizes": [3, 7, 11], + 
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "leaky_relu_negative_slope": 0.1, + "output_sampling_rate": 24000, + } + } + rename_dict = LTX_2_0_VOCODER_RENAME_DICT + special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vocoder = LTX2Vocoder.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vocoder.load_state_dict(original_state_dict, strict=True, assign=True) + return vocoder + + def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) @@ -468,7 +525,13 @@ def main(args): transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer")) if args.vocoder or args.full_pipeline: - pass + if args.vocoder_filename is not None: + original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename) + elif combined_ckpt is not None: + original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix) + vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version) + if not args.full_pipeline: + vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder")) if args.full_pipeline: pass diff --git a/src/diffusers/pipelines/ltx2/vocoder.py b/src/diffusers/pipelines/ltx2/vocoder.py new file mode 100644 index 0000000000..c3b3c1f367 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/vocoder.py @@ -0,0 +1,173 @@ +import math +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class ResBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int = 3, + stride: int = 1, + dilations: Tuple[int, ...] 
= (1, 3, 5), + leaky_relu_negative_slope: float = 0.1, + padding_mode: str = "same", + ): + super().__init__() + self.dilations = dilations + self.negative_slope = leaky_relu_negative_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding_mode + ) + for dilation in dilations + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=stride, + dilation=1, + padding=padding_mode + ) + for _ in range(len(dilations)) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, negative_slope=self.negative_slope) + xt = conv1(xt) + xt = F.leaky_relu(xt, negative_slope=self.negative_slope) + xt = conv2(xt) + x = x + xt + return x + + +class LTX2Vocoder(ModelMixin, ConfigMixin): + r""" + LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + hidden_channels: int = 1024, + out_channels: int = 2, + upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4], + upsample_factors: List[int] = [6, 5, 2, 2, 2], + resnet_kernel_sizes: List[int] = [3, 7, 11], + resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + leaky_relu_negative_slope: float = 0.1, + output_sampling_rate: int = 24000, + ): + super().__init__() + self.num_upsample_layers = len(upsample_kernel_sizes) + self.resnets_per_upsample = len(resnet_kernel_sizes) + self.out_channels = out_channels + self.total_upsample_factor = math.prod(upsample_factors) + self.negative_slope = leaky_relu_negative_slope + + if self.num_upsample_layers != len(upsample_factors): + raise ValueError( + f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length" + f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively." + ) + + if self.resnets_per_upsample != len(resnet_dilations): + raise ValueError( + f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length" + f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." + ) + + self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3) + + self.upsamplers = nn.ModuleList() + self.resnets = nn.ModuleList() + input_channels = hidden_channels + for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + output_channels = input_channels // 2 + self.upsamplers.append( + nn.ConvTranspose1d( + input_channels, # hidden_channels // (2 ** i) + output_channels, # hidden_channels // (2 ** (i + 1)) + kernel_size, + stride=stride, + padding=(kernel_size - stride) // 2, + ) + ) + + for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): + self.resnets.append( + ResBlock( + output_channels, + kernel_size, + dilations=dilations, + leaky_relu_negative_slope=leaky_relu_negative_slope, + ) + ) + input_channels = output_channels + + self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) + + def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: + r""" + Forward pass of the vocoder. + + Args: + hidden_states (`torch.Tensor`): + Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last` + is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is + `True`. 
+            time_last (`bool`, *optional*, defaults to `False`):
+                Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.
+
+        Returns:
+            `torch.Tensor`:
+                Audio waveform tensor of shape `(batch_size, out_channels, audio_length)`.
+        """
+
+        # Ensure that the time/frame dimension is last
+        if not time_last:
+            hidden_states = hidden_states.transpose(2, 3)
+        # Combine channels and frequency (mel bins) dimensions
+        hidden_states = hidden_states.flatten(1, 2)
+
+        hidden_states = self.conv_in(hidden_states)
+
+        for i in range(self.num_upsample_layers):
+            hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
+            hidden_states = self.upsamplers[i](hidden_states)
+
+            # Run all resnets in parallel on hidden_states
+            start = i * self.resnets_per_upsample
+            end = (i + 1) * self.resnets_per_upsample
+            resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
+
+            hidden_states = torch.mean(resnet_outputs, dim=0)
+
+        # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
+        # 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
+        hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
+        hidden_states = self.conv_out(hidden_states)
+        hidden_states = torch.tanh(hidden_states)
+
+        return hidden_states

From 6c56954fa876cd0aef5054d1eb0dc3ad684ebaa3 Mon Sep 17 00:00:00 2001
From: Daniel Gu
Date: Sat, 20 Dec 2025 02:40:38 +0100
Subject: [PATCH 19/19] Use RMSNorm implementation closer to original for LTX 2.0 video VAE

---
 .../autoencoders/autoencoder_kl_ltx2.py | 49 +++++++++++++++----
 1 file changed, 40 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py
index 6e7b4d324f..df59e2d748 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py
@@ -29,6 +29,38 @@ from ..normalization import RMSNorm
 from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
 
 
+class PerChannelRMSNorm(nn.Module):
+    """
+    Per-pixel (per-location) RMS normalization layer.
+
+    At every location, the tensor is normalized by the root-mean-square of its values across the chosen
+    dimension (typically channels):
+
+        y = x / sqrt(mean(x^2, dim=channel_dim, keepdim=True) + eps)
+    """
+
+    def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None:
+        """
+        Args:
+            channel_dim: Dimension along which to compute the RMS (typically channels).
+            eps: Small constant added for numerical stability.
+        """
+        super().__init__()
+        self.channel_dim = channel_dim
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor, channel_dim: Optional[int] = None) -> torch.Tensor:
+        """
+        Apply RMS normalization along `channel_dim`, falling back to the configured dimension when it is not given.
+        """
+        channel_dim = channel_dim or self.channel_dim
+        # Compute mean of squared values along `channel_dim`, keep dimensions for broadcasting.
+        mean_sq = torch.mean(x**2, dim=channel_dim, keepdim=True)
+        # Normalize by the root-mean-square (RMS).
+ rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + # Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime class LTX2VideoCausalConv3d(nn.Module): def __init__( @@ -120,7 +152,7 @@ class LTX2VideoResnetBlock3d(nn.Module): self.nonlinearity = get_activation(non_linearity) - self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.norm1 = PerChannelRMSNorm() self.conv1 = LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, @@ -128,7 +160,7 @@ class LTX2VideoResnetBlock3d(nn.Module): spatial_padding_mode=spatial_padding_mode, ) - self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.norm2 = PerChannelRMSNorm() self.dropout = nn.Dropout(dropout) self.conv2 = LTX2VideoCausalConv3d( in_channels=out_channels, @@ -165,8 +197,7 @@ class LTX2VideoResnetBlock3d(nn.Module): ) -> torch.Tensor: hidden_states = inputs - # Normalize over the channels dimension (dim 1), which is not the last dim - hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.norm1(hidden_states) if self.scale_shift_table is not None: temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] @@ -183,7 +214,7 @@ class LTX2VideoResnetBlock3d(nn.Module): )[None] hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] - hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.norm2(hidden_states) if self.scale_shift_table is not None: hidden_states = hidden_states * (1 + scale_2) + shift_2 @@ -746,7 +777,7 @@ class LTX2VideoEncoder3d(nn.Module): ) # out - self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.norm_out = PerChannelRMSNorm() self.conv_act = nn.SiLU() self.conv_out = LTX2VideoCausalConv3d( in_channels=output_channel, @@ -788,7 +819,7 @@ class LTX2VideoEncoder3d(nn.Module): hidden_states = self.mid_block(hidden_states, causal=causal) - hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.norm_out(hidden_states) hidden_states = self.conv_act(hidden_states) hidden_states = self.conv_out(hidden_states, causal=causal) @@ -900,7 +931,7 @@ class LTX2VideoDecoder3d(nn.Module): self.up_blocks.append(up_block) # out - self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.norm_out = PerChannelRMSNorm() self.conv_act = nn.SiLU() self.conv_out = LTX2VideoCausalConv3d( in_channels=output_channel, @@ -942,7 +973,7 @@ class LTX2VideoDecoder3d(nn.Module): for up_block in self.up_blocks: hidden_states = up_block(hidden_states, temb, causal=causal) - hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.norm_out(hidden_states) if self.time_embedder is not None: temb = self.time_embedder(