
Tests for T2V and I2V (#6)

* add ltx2 pipeline tests.

* up

* up

* up

* up

* remove content

* style

* Denormalize audio latents in I2V pipeline (analogous to T2V change)

* Initial refactor to put video and audio text encoder connectors in transformer

* Get LTX 2 transformer tests working after connector refactor

* up

* up

* i2v tests.

* up

* Address review comments

* Calculate RoPE double-precision freqs using torch instead of np

* Further simplify LTX 2 RoPE freq calc

* revert unneeded changes.

* up

* up

* update to split-style RoPE.

* up

---------

Co-authored-by: Daniel Gu <dgu8957@gmail.com>
Authored by Sayak Paul on 2026-01-06 08:05:30 +05:30, committed by GitHub.
Parent ce9da5d472, commit 93a417f24a.
23 changed files with 725 additions and 134 deletions
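Several of the bullets above reference switching LTX 2's rotary embeddings to a "split" layout (see apply_split_rotary_emb and ROTARY_FN_MAP in the transformer diff below). As a rough, generic sketch of how the two RoPE conventions differ, not the code from this commit (sign and ordering conventions in diffusers may vary):

import torch

def rotate_interleaved(x, cos, sin):
    # pairs (x[..., 2i], x[..., 2i+1]) are rotated together
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)

def rotate_split(x, cos, sin):
    # the first and second halves of the feature dimension are rotated together
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

d = 8
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, d, 2).float() / d))
angles = 3.0 * inv_freq                  # a single token at position 3, shape (d // 2,)
cos, sin = angles.cos(), angles.sin()
x = torch.randn(d)
assert rotate_interleaved(x, cos, sin).shape == rotate_split(x, cos, sin).shape == x.shape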

View File

@@ -1,5 +1,4 @@
import argparse
import math
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional, Tuple
@@ -8,9 +7,15 @@ import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer, Gemma3ForConditionalGeneration
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler, LTX2Pipeline, LTX2VideoTransformer3DModel
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available
@@ -186,7 +191,7 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
"num_attention_heads": 2,
"attention_head_dim": 8,
"cross_attention_dim": 16,
"vae_scale_factors": (8, 32 ,32),
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
@@ -229,7 +234,7 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
"num_attention_heads": 32,
"attention_head_dim": 128,
"cross_attention_dim": 4096,
"vae_scale_factors": (8, 32 ,32),
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
@@ -257,7 +262,7 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1000,
"rope_type": "split"
"rope_type": "split",
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
@@ -307,7 +312,7 @@ def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
special_keys_remap = {}
@@ -541,7 +546,7 @@ def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"leaky_relu_negative_slope": 0.1,
"output_sampling_rate": 24000,
}
},
}
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
@@ -574,7 +579,6 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D
return vocoder
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
@@ -757,7 +761,7 @@ def main(args):
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
if not args.full_pipeline:
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
if args.connectors or args.full_pipeline:
if args.dit_filename is not None:
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
@@ -810,6 +814,6 @@ def main(args):
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if __name__ == '__main__':
if __name__ == "__main__":
args = get_args()
main(args)

View File

@@ -1,5 +1,4 @@
import argparse
import math
import os
from fractions import Fraction
from typing import Optional
@@ -211,6 +210,6 @@ def main(args):
)
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -1,12 +1,11 @@
import argparse
import os
from fractions import Fraction
from typing import Optional
from PIL import Image
import av # Needs to be installed separately (`pip install av`)
import torch
from PIL import Image
from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline
@@ -131,7 +130,7 @@ def parse_args():
parser.add_argument(
"--negative_prompt",
type=str,
default="shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
default="shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static.",
)
parser.add_argument("--num_inference_steps", type=int, default=40)
@@ -166,7 +165,9 @@ def parse_args():
def main(args):
pipeline = LTX2ImageToVideoPipeline.from_pretrained(
args.model_id, revision=args.revision, torch_dtype=args.dtype,
args.model_id,
revision=args.revision,
torch_dtype=args.dtype,
)
if args.cpu_offload:
pipeline.enable_model_cpu_offload()
@@ -201,6 +202,6 @@ def main(args):
)
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -90,7 +90,7 @@ def main() -> None:
latent_width,
device=device,
dtype=dtype,
generator=torch.Generator(device).manual_seed(42)
generator=torch.Generator(device).manual_seed(42),
)
original_out = original_decoder(dummy)

View File

@@ -193,9 +193,9 @@ else:
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLLTXVideo",
"AutoencoderKLLTX2Audio",
"AutoencoderKLLTX2Video",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
"AutoencoderKLQwenImage",
@@ -237,8 +237,8 @@ else:
"Kandinsky3UNet",
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
"LTX2VideoTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
"MochiTransformer3DModel",
@@ -533,12 +533,13 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LTX2ImageToVideoPipeline",
"LTX2Pipeline",
"LTX2Pipeline",
"LTXConditionPipeline",
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
"LTX2Pipeline",
"LTX2ImageToVideoPipeline",
"LucyEditPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
@@ -931,9 +932,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTXVideo,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLQwenImage,
@@ -975,8 +976,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky3UNet,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
@@ -1241,12 +1242,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LTX2ImageToVideoPipeline,
LTX2Pipeline,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
LTX2Pipeline,
LTX2ImageToVideoPipeline,
LucyEditPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,

View File

@@ -154,9 +154,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTXVideo,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLQwenImage,
@@ -213,8 +213,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanVideoTransformer3DModel,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,

View File

@@ -10,8 +10,8 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage

View File

@@ -25,7 +25,6 @@ from ..activations import get_activation
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
@@ -33,8 +32,8 @@ class PerChannelRMSNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
For each element along the chosen dimension, this layer normalizes the tensor
by the root-mean-square of its values across that dimension:
For each element along the chosen dimension, this layer normalizes the tensor by the root-mean-square of its values
across that dimension:
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
"""
@@ -174,9 +173,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
if in_channels != out_channels:
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
# LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d
self.conv_shortcut = nn.Conv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1
)
self.conv_shortcut = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
self.per_channel_scale1 = None
self.per_channel_scale2 = None
@@ -953,7 +950,10 @@ class LTX2VideoDecoder3d(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, causal: Optional[bool] = None,
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
causal: Optional[bool] = None,
) -> torch.Tensor:
causal = causal or self.is_causal
@@ -1279,7 +1279,8 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
if self.use_slicing and z.shape[0] > 1:
if temb is not None:
decoded_slices = [
self._decode(z_slice, t_slice, causal=causal).sample for z_slice, t_slice in (z.split(1), temb.split(1))
self._decode(z_slice, t_slice, causal=causal).sample
for z_slice, t_slice in (z.split(1), temb.split(1))
]
else:
decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)]

View File

@@ -249,6 +249,7 @@ class LTX2AudioUpsample(nn.Module):
return x
class LTX2AudioAudioPatchifier:
"""
Patchifier for spectrogram/audio latents.
@@ -405,9 +406,7 @@ class LTX2AudioDecoder(nn.Module):
final_block_channels = block_in
if self.norm_type == "group":
self.norm_out = nn.GroupNorm(
num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True
)
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
elif self.norm_type == "pixel":
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
@@ -538,8 +537,8 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
# Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over
# the entire dataset and stored in model's checkpoint under AudioVAE state_dict
latents_std = torch.zeros((base_channels, ))
latents_mean = torch.ones((base_channels, ))
latents_std = torch.zeros((base_channels,))
latents_mean = torch.ones((base_channels,))
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)

View File

@@ -22,16 +22,21 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ...utils import (
USE_PEFT_BACKEND,
BaseOutput,
is_torch_version,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, PixArtAlphaCombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
from ..normalization import RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -44,6 +49,7 @@ def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, tor
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
cos, sin = freqs
@@ -65,7 +71,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten
# (..., 2, r)
split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float
first_x = split_x[..., :1, :] # (..., 1, r)
first_x = split_x[..., :1, :] # (..., 1, r)
second_x = split_x[..., 1:, :] # (..., 1, r)
cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r)
@@ -89,6 +95,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten
ROTARY_FN_MAP = {"interleaved": apply_interleaved_rotary_emb, "split": apply_split_rotary_emb}
@dataclass
class AudioVisualModelOutput(BaseOutput):
r"""
@@ -192,7 +199,9 @@ class LTX2AudioVideoAttnProcessor:
if query_rotary_emb is not None:
query = ROTARY_FN_MAP[attn.rope_type](query, query_rotary_emb)
key = ROTARY_FN_MAP[attn.rope_type](key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb)
key = ROTARY_FN_MAP[attn.rope_type](
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
@@ -368,7 +377,7 @@ class LTX2VideoTransformerBlock(nn.Module):
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type
rope_type=rope_type,
)
self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -381,7 +390,7 @@ class LTX2VideoTransformerBlock(nn.Module):
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type
rope_type=rope_type,
)
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -396,7 +405,7 @@ class LTX2VideoTransformerBlock(nn.Module):
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type
rope_type=rope_type,
)
# Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
@@ -481,7 +490,9 @@ class LTX2VideoTransformerBlock(nn.Module):
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
batch_size, temb_audio.size(1), num_audio_ada_params, -1
)
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = audio_ada_values.unbind(dim=2)
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
audio_ada_values.unbind(dim=2)
)
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
attn_audio_hidden_states = self.audio_attn1(
@@ -550,8 +561,12 @@ class LTX2VideoTransformerBlock(nn.Module):
if use_a2v_cross_attn:
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(2)
mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_a2v_ca_scale.squeeze(2)) + audio_a2v_ca_shift.squeeze(2)
mod_norm_hidden_states = norm_hidden_states * (
1 + video_a2v_ca_scale.squeeze(2)
) + video_a2v_ca_shift.squeeze(2)
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
1 + audio_a2v_ca_scale.squeeze(2)
) + audio_a2v_ca_shift.squeeze(2)
a2v_attn_hidden_states = self.audio_to_video_attn(
mod_norm_hidden_states,
@@ -565,8 +580,12 @@ class LTX2VideoTransformerBlock(nn.Module):
if use_v2a_cross_attn:
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(2)
mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_v2a_ca_scale.squeeze(2)) + audio_v2a_ca_shift.squeeze(2)
mod_norm_hidden_states = norm_hidden_states * (
1 + video_v2a_ca_scale.squeeze(2)
) + video_v2a_ca_shift.squeeze(2)
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
1 + audio_v2a_ca_scale.squeeze(2)
) + audio_v2a_ca_shift.squeeze(2)
v2a_attn_hidden_states = self.video_to_audio_attn(
mod_norm_audio_hidden_states,
@@ -596,9 +615,10 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
Args:
causal_offset (`int`, *optional*, defaults to `1`):
Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where
the VAE treats the very first frame differently), but could also be 0 (for non-causal modeling).
Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE
treats the very first frame differently), but could also be 0 (for non-causal modeling).
"""
def __init__(
self,
dim: int,
@@ -658,9 +678,9 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
fps: float = 25.0,
) -> torch.Tensor:
"""
Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original
pixel space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3,
num_patches, 2) where
Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel
space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, num_patches, 2)
where
- axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames)
- axis 3 (size 2) stores `[start, end)` indices within each dimension
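For intuition, a toy construction of the bounds layout described above, assuming patch size 1 along every axis and the (8, 32, 32) VAE scale factors listed in the configs earlier (an illustration, not the LTX 2 implementation):

import torch

latent_sizes = (2, 3, 3)        # toy latent grid: (frames, height, width)
scale_factors = (8, 32, 32)     # temporal and spatial VAE scale factors

starts = torch.stack(
    torch.meshgrid(*[torch.arange(n) for n in latent_sizes], indexing="ij"), dim=0
).flatten(1)                                    # (3, num_patches) latent-space start indices
scales = torch.tensor(scale_factors).unsqueeze(-1)
bounds = torch.stack([starts * scales, (starts + 1) * scales], dim=-1).unsqueeze(0)
print(bounds.shape)                             # torch.Size([1, 3, 18, 2])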
@@ -727,8 +747,8 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
shift: int = 0,
) -> torch.Tensor:
"""
Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent
frame. This will ultimately have shape (batch_size, 3, num_patches, 2) where
Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent frame.
This will ultimately have shape (batch_size, 3, num_patches, 2) where
- axis 1 (size 1) represents the temporal dimension
- axis 3 (size 2) stores `[start, end)` indices within each dimension
@@ -763,7 +783,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
# Handle first frame causal offset, ensuring non-negative timestamps
grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0)
# Convert mel bins back into seconds
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
# 3. Calculate start timstamps in seconds with respect to the original spectrogram grid
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
@@ -862,7 +882,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems])
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
elif self.rope_type == "split":
expected_freqs = self.dim // 2
current_freqs = freqs.shape[-1]
@@ -1087,7 +1107,7 @@ class LTX2VideoTransformer3DModel(
modality="audio",
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=audio_num_attention_heads
num_attention_heads=audio_num_attention_heads,
)
# 5. Transformer Blocks
@@ -1154,7 +1174,7 @@ class LTX2VideoTransformer3DModel(
encoder_hidden_states (`torch.Tensor`):
Input text embeddings of shape TODO.
TODO for the rest.
Returns:
`AudioVisualModelOutput` or `tuple`:
If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a
@@ -1204,14 +1224,18 @@ class LTX2VideoTransformer3DModel(
audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)
video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device)
audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :], device=audio_hidden_states.device)
audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(
audio_coords[:, 0:1, :], device=audio_hidden_states.device
)
# 2. Patchify input projections
hidden_states = self.proj_in(hidden_states)
audio_hidden_states = self.audio_proj_in(audio_hidden_states)
# 3. Prepare timestep embeddings and modulation parameters
timestep_cross_attn_gate_scale_factor = self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier
timestep_cross_attn_gate_scale_factor = (
self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier
)
# 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters
# temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer
@@ -1243,7 +1267,9 @@ class LTX2VideoTransformer3DModel(
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(batch_size, -1, video_cross_attn_scale_shift.shape[-1])
video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(
batch_size, -1, video_cross_attn_scale_shift.shape[-1]
)
video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
@@ -1256,7 +1282,9 @@ class LTX2VideoTransformer3DModel(
batch_size=batch_size,
hidden_dtype=audio_hidden_states.dtype,
)
audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(batch_size, -1, audio_cross_attn_scale_shift.shape[-1])
audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(
batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
)
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
# 4. Prepare prompt embeddings

View File

@@ -720,7 +720,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .ltx2 import LTX2Pipeline, LTX2ImageToVideoPipeline
from .ltx2 import LTX2ImageToVideoPipeline, LTX2Pipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline

View File

@@ -22,9 +22,9 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["connectors"] = ["LTX2TextConnectors"]
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
_import_structure["connectors"] = ["LTX2TextConnectors"]
_import_structure["vocoder"] = ["LTX2Vocoder"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -35,9 +35,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .connectors import LTX2TextConnectors
from .pipeline_ltx2 import LTX2Pipeline
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
from .connectors import LTX2TextConnectors
from .vocoder import LTX2Vocoder
else:

View File

@@ -9,6 +9,7 @@ from ...models.attention import FeedForward
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
class LTX2RotaryPosEmbed1d(nn.Module):
"""
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
@@ -21,12 +22,12 @@ class LTX2RotaryPosEmbed1d(nn.Module):
theta: float = 10000.0,
double_precision: bool = True,
rope_type: str = "interleaved",
num_attention_heads: int = 32
num_attention_heads: int = 32,
):
super().__init__()
if rope_type not in ["interleaved", "split"]:
raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")
self.dim = dim
self.base_seq_len = base_seq_len
self.theta = theta
@@ -69,7 +70,7 @@ class LTX2RotaryPosEmbed1d(nn.Module):
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
elif self.rope_type == "split":
expected_freqs = self.dim // 2
current_freqs = freqs.shape[-1]
@@ -116,7 +117,7 @@ class LTX2TransformerBlock1d(nn.Module):
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
processor=LTX2AudioVideoAttnProcessor(),
rope_type=rope_type
rope_type=rope_type,
)
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
@@ -159,7 +160,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
rope_double_precision: bool = True,
eps: float = 1e-6,
causal_temporal_positioning: bool = False,
rope_type: str = "interleaved"
rope_type: str = "interleaved",
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -173,12 +174,12 @@ class LTX2ConnectorTransformer1d(nn.Module):
self.learnable_registers = torch.nn.Parameter(init_registers)
self.rope = LTX2RotaryPosEmbed1d(
self.inner_dim,
base_seq_len=rope_base_seq_len,
theta=rope_theta,
self.inner_dim,
base_seq_len=rope_base_seq_len,
theta=rope_theta,
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=num_attention_heads
num_attention_heads=num_attention_heads,
)
self.transformer_blocks = torch.nn.ModuleList(
@@ -187,7 +188,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
rope_type=rope_type
rope_type=rope_type,
)
for _ in range(num_layers)
]
@@ -253,8 +254,8 @@ class LTX2ConnectorTransformer1d(nn.Module):
class LTX2TextConnectors(ModelMixin, ConfigMixin):
"""
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and
audio streams.
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
streams.
"""
@register_to_config
@@ -287,7 +288,7 @@ class LTX2TextConnectors(ModelMixin, ConfigMixin):
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
rope_type=rope_type
rope_type=rope_type,
)
self.audio_connector = LTX2ConnectorTransformer1d(
num_attention_heads=audio_connector_num_attention_heads,
@@ -298,7 +299,7 @@ class LTX2TextConnectors(ModelMixin, ConfigMixin):
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
rope_type=rope_type
rope_type=rope_type,
)
def forward(

View File

@@ -674,7 +674,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
duration_s = num_frames / frame_rate
latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
latents_per_second = (
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
)
latent_length = int(duration_s * latents_per_second)
if latents is not None:
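As a worked example of the audio latent-length arithmetic shown in the hunk above (values assumed purely for illustration, not taken from any real checkpoint):

num_frames, frame_rate = 125, 25.0
sampling_rate, hop_length, temporal_compression = 16000, 160, 4

duration_s = num_frames / frame_rate                                      # 5.0 s
latents_per_second = sampling_rate / hop_length / temporal_compression    # 25.0
latent_length = int(duration_s * latents_per_second)                      # 125 audio latent frames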
@@ -995,7 +997,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
audio_latent_model_input = torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
audio_latent_model_input = (
torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
)
audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -1026,10 +1030,14 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
if self.do_classifier_free_guidance:
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (noise_pred_video_text - noise_pred_video_uncond)
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
noise_pred_video_text - noise_pred_video_uncond
)
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
noise_pred_audio_text - noise_pred_audio_uncond
)
if self.guidance_rescale > 0:
# Based on 3.4. in https://huggingface.co/papers/2305.08891

View File

@@ -13,25 +13,26 @@
# limitations under the License.
import copy
from typing import Any, Callable, Dict, List, Optional, Union
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .connectors import LTX2TextConnectors
from .pipeline_output import LTX2PipelineOutput
from ..pipeline_utils import DiffusionPipeline
from .vocoder import LTX2Vocoder
from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
from ...models.transformers import LTX2VideoTransformer3DModel
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .connectors import LTX2TextConnectors
from .pipeline_output import LTX2PipelineOutput
from .vocoder import LTX2Vocoder
if is_torch_xla_available():
@@ -86,6 +87,7 @@ def retrieve_latents(
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
@@ -665,7 +667,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
shape = (batch_size, num_channels_latents, num_frames, height, width)
mask_shape = (batch_size, 1, num_frames, height, width)
if latents is not None:
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
@@ -697,7 +699,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
init_latents = torch.cat(init_latents, dim=0).to(dtype)
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)
# First condition is image latents and those should be kept clean.
conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
conditioning_mask[:, :, 0] = 1.0
@@ -731,7 +733,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
duration_s = num_frames / frame_rate
latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
latents_per_second = (
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
)
latent_length = int(duration_s * latents_per_second)
if latents is not None:
@@ -982,7 +986,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
)
if self.do_classifier_free_guidance:
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
@@ -1063,12 +1067,14 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
audio_latent_model_input = torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
audio_latent_model_input = (
torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
)
audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
timestep = t.expand(latent_model_input.shape[0])
video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
with self.transformer.cache_context("cond_uncond"):
noise_pred_video, noise_pred_audio = self.transformer(
hidden_states=latent_model_input,
@@ -1095,10 +1101,14 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
if self.do_classifier_free_guidance:
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (noise_pred_video_text - noise_pred_video_uncond)
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
noise_pred_video_text - noise_pred_video_uncond
)
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
noise_pred_audio_text - noise_pred_audio_uncond
)
if self.guidance_rescale > 0:
# Based on 3.4. in https://huggingface.co/papers/2305.08891

View File

@@ -25,32 +25,18 @@ class ResBlock(nn.Module):
self.convs1 = nn.ModuleList(
[
nn.Conv1d(
channels,
channels,
kernel_size,
stride=stride,
dilation=dilation,
padding=padding_mode
)
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode)
for dilation in dilations
]
)
self.convs2 = nn.ModuleList(
[
nn.Conv1d(
channels,
channels,
kernel_size,
stride=stride,
dilation=1,
padding=padding_mode
)
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode)
for _ in range(len(dilations))
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for conv1, conv2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
@@ -127,7 +113,7 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
input_channels = output_channels
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
r"""
Forward pass of the vocoder.

View File

@@ -502,6 +502,36 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLLTX2Video(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLLTXVideo(metaclass=DummyObject):
_backends = ["torch"]
@@ -1132,6 +1162,21 @@ class LatteTransformer3DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class LTX2VideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class LTXVideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1802,6 +1802,36 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LTX2ImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTX2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTXConditionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -15,8 +15,6 @@
import unittest
import torch
from diffusers import AutoencoderKLLTX2Video
from ...testing_utils import (

View File

@@ -52,9 +52,9 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
sequence_length = 16
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
audio_hidden_states = torch.randn(
(batch_size, audio_num_frames, audio_num_channels * num_mel_bins)
).to(torch_device)
audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to(
torch_device
)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)

View File

View File

@@ -0,0 +1,239 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from ...testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = LTX2Pipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"audio_latents",
"output_type",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_attention_slicing = False
test_xformers_attention = False
supports_dduf = False
base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
def get_dummy_components(self):
tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)
torch.manual_seed(0)
transformer = LTX2VideoTransformer3DModel(
in_channels=4,
out_channels=4,
patch_size=1,
patch_size_t=1,
num_attention_heads=2,
attention_head_dim=8,
cross_attention_dim=16,
audio_in_channels=4,
audio_out_channels=4,
audio_num_attention_heads=2,
audio_attention_head_dim=4,
audio_cross_attention_dim=8,
num_layers=2,
qk_norm="rms_norm_across_heads",
caption_channels=text_encoder.config.text_config.hidden_size,
rope_double_precision=False,
rope_type="split",
)
torch.manual_seed(0)
connectors = LTX2TextConnectors(
caption_channels=text_encoder.config.text_config.hidden_size,
text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
video_connector_num_attention_heads=4,
video_connector_attention_head_dim=8,
video_connector_num_layers=1,
video_connector_num_learnable_registers=None,
audio_connector_num_attention_heads=4,
audio_connector_attention_head_dim=8,
audio_connector_num_layers=1,
audio_connector_num_learnable_registers=None,
connector_rope_base_seq_len=32,
rope_theta=10000.0,
rope_double_precision=False,
causal_temporal_positioning=False,
rope_type="split",
)
torch.manual_seed(0)
vae = AutoencoderKLLTX2Video(
in_channels=3,
out_channels=3,
latent_channels=4,
block_out_channels=(8,),
decoder_block_out_channels=(8,),
layers_per_block=(1,),
decoder_layers_per_block=(1, 1),
spatio_temporal_scaling=(True,),
decoder_spatio_temporal_scaling=(True,),
decoder_inject_noise=(False, False),
downsample_type=("spatial",),
upsample_residual=(False,),
upsample_factor=(1,),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
decoder_causal=False,
)
vae.use_framewise_encoding = False
vae.use_framewise_decoding = False
torch.manual_seed(0)
audio_vae = AutoencoderKLLTX2Audio(
base_channels=4,
output_channels=2,
ch_mult=(1,),
num_res_blocks=1,
attn_resolutions=None,
in_channels=2,
resolution=32,
latent_channels=2,
norm_type="pixel",
causality_axis="height",
dropout=0.0,
mid_block_add_attention=False,
sample_rate=16000,
mel_hop_length=160,
is_causal=True,
mel_bins=8,
)
torch.manual_seed(0)
vocoder = LTX2Vocoder(
in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
hidden_channels=32,
out_channels=2,
upsample_kernel_sizes=[4, 4],
upsample_factors=[2, 2],
resnet_kernel_sizes=[3],
resnet_dilations=[[1, 3, 5]],
leaky_relu_negative_slope=0.1,
output_sampling_rate=16000,
)
scheduler = FlowMatchEulerDiscreteScheduler()
components = {
"transformer": transformer,
"vae": vae,
"audio_vae": audio_vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"connectors": connectors,
"vocoder": vocoder,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "a robot dancing",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.0,
"height": 32,
"width": 32,
"num_frames": 5,
"frame_rate": 25.0,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = pipe(**inputs)
video = output.frames
audio = output.audio
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.4331, 0.6203, 0.3245, 0.7294, 0.4822, 0.5703, 0.2999, 0.7700, 0.4961, 0.4242, 0.4581, 0.4351, 0.1137, 0.4437, 0.6304, 0.3184
]
)
expected_audio_slice = torch.tensor(
[
0.0229, 0.0503, 0.1220, 0.1083, 0.1745, 0.1075, 0.1779, 0.0974, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)

View File

@@ -0,0 +1,241 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2ImageToVideoPipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from ...testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = LTX2ImageToVideoPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"audio_latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_attention_slicing = False
test_xformers_attention = False
supports_dduf = False
base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
def get_dummy_components(self):
tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)
torch.manual_seed(0)
transformer = LTX2VideoTransformer3DModel(
in_channels=4,
out_channels=4,
patch_size=1,
patch_size_t=1,
num_attention_heads=2,
attention_head_dim=8,
cross_attention_dim=16,
audio_in_channels=4,
audio_out_channels=4,
audio_num_attention_heads=2,
audio_attention_head_dim=4,
audio_cross_attention_dim=8,
num_layers=2,
qk_norm="rms_norm_across_heads",
caption_channels=text_encoder.config.text_config.hidden_size,
rope_double_precision=False,
rope_type="split",
)
torch.manual_seed(0)
connectors = LTX2TextConnectors(
caption_channels=text_encoder.config.text_config.hidden_size,
text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
video_connector_num_attention_heads=4,
video_connector_attention_head_dim=8,
video_connector_num_layers=1,
video_connector_num_learnable_registers=None,
audio_connector_num_attention_heads=4,
audio_connector_attention_head_dim=8,
audio_connector_num_layers=1,
audio_connector_num_learnable_registers=None,
connector_rope_base_seq_len=32,
rope_theta=10000.0,
rope_double_precision=False,
causal_temporal_positioning=False,
rope_type="split",
)
torch.manual_seed(0)
vae = AutoencoderKLLTX2Video(
in_channels=3,
out_channels=3,
latent_channels=4,
block_out_channels=(8,),
decoder_block_out_channels=(8,),
layers_per_block=(1,),
decoder_layers_per_block=(1, 1),
spatio_temporal_scaling=(True,),
decoder_spatio_temporal_scaling=(True,),
decoder_inject_noise=(False, False),
downsample_type=("spatial",),
upsample_residual=(False,),
upsample_factor=(1,),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
decoder_causal=False,
)
vae.use_framewise_encoding = False
vae.use_framewise_decoding = False
torch.manual_seed(0)
audio_vae = AutoencoderKLLTX2Audio(
base_channels=4,
output_channels=2,
ch_mult=(1,),
num_res_blocks=1,
attn_resolutions=None,
in_channels=2,
resolution=32,
latent_channels=2,
norm_type="pixel",
causality_axis="height",
dropout=0.0,
mid_block_add_attention=False,
sample_rate=16000,
mel_hop_length=160,
is_causal=True,
mel_bins=8,
)
torch.manual_seed(0)
vocoder = LTX2Vocoder(
in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
hidden_channels=32,
out_channels=2,
upsample_kernel_sizes=[4, 4],
upsample_factors=[2, 2],
resnet_kernel_sizes=[3],
resnet_dilations=[[1, 3, 5]],
leaky_relu_negative_slope=0.1,
output_sampling_rate=16000,
)
scheduler = FlowMatchEulerDiscreteScheduler()
components = {
"transformer": transformer,
"vae": vae,
"audio_vae": audio_vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"connectors": connectors,
"vocoder": vocoder,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image = torch.rand((1, 3, 32, 32), generator=generator, device=device)
inputs = {
"image": image,
"prompt": "a robot dancing",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.0,
"height": 32,
"width": 32,
"num_frames": 5,
"frame_rate": 25.0,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = pipe(**inputs)
video = output.frames
audio = output.audio
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.3573, 0.8382, 0.3581, 0.6114, 0.3682, 0.7969, 0.2552, 0.6399, 0.3113, 0.1497, 0.3249, 0.5395, 0.3498, 0.4526, 0.4536, 0.4555
]
)
expected_audio_slice = torch.tensor(
[
0.0229, 0.0503, 0.1220, 0.1083, 0.1745, 0.1075, 0.1779, 0.0974, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)
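Once this lands, both fast test classes could be selected with standard pytest filtering by class name, for example (the on-disk test file paths are not shown in this diff):

pytest -k "LTX2PipelineFastTests or LTX2ImageToVideoPipelineFastTests"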