diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 3b8c9598b5..fa46197978 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -1,5 +1,4 @@ import argparse -import math import os from contextlib import nullcontext from typing import Any, Dict, Optional, Tuple @@ -8,9 +7,15 @@ import safetensors.torch import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from transformers import AutoModel, AutoTokenizer, Gemma3ForConditionalGeneration +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration -from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler, LTX2Pipeline, LTX2VideoTransformer3DModel +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2Pipeline, + LTX2VideoTransformer3DModel, +) from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder from diffusers.utils.import_utils import is_accelerate_available @@ -186,7 +191,7 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, "num_attention_heads": 2, "attention_head_dim": 8, "cross_attention_dim": 16, - "vae_scale_factors": (8, 32 ,32), + "vae_scale_factors": (8, 32, 32), "pos_embed_max_pos": 20, "base_height": 2048, "base_width": 2048, @@ -229,7 +234,7 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, "num_attention_heads": 32, "attention_head_dim": 128, "cross_attention_dim": 4096, - "vae_scale_factors": (8, 32 ,32), + "vae_scale_factors": (8, 32, 32), "pos_embed_max_pos": 20, "base_height": 2048, "base_width": 2048, @@ -257,7 +262,7 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, "causal_offset": 1, "timestep_scale_multiplier": 1000, "cross_attn_timestep_scale_multiplier": 1000, - "rope_type": "split" + "rope_type": "split", }, } rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT @@ -307,7 +312,7 @@ def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, "rope_type": "split", }, } - + rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT special_keys_remap = {} @@ -541,7 +546,7 @@ def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "leaky_relu_negative_slope": 0.1, "output_sampling_rate": 24000, - } + }, } rename_dict = LTX_2_0_VOCODER_RENAME_DICT special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP @@ -574,7 +579,6 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D return vocoder - def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) @@ -757,7 +761,7 @@ def main(args): transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version) if not args.full_pipeline: transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer")) - + if args.connectors or args.full_pipeline: if args.dit_filename is not None: original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename) @@ -810,6 +814,6 @@ def main(args): pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") -if __name__ == '__main__': +if __name__ == "__main__": args = get_args() main(args) diff --git a/scripts/ltx2_test_full_pipeline.py 
b/scripts/ltx2_test_full_pipeline.py index 5f0f366e71..16ea9f8040 100644 --- a/scripts/ltx2_test_full_pipeline.py +++ b/scripts/ltx2_test_full_pipeline.py @@ -1,5 +1,4 @@ import argparse -import math import os from fractions import Fraction from typing import Optional @@ -211,6 +210,6 @@ def main(args): ) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() main(args) diff --git a/scripts/ltx2_test_full_pipeline_i2v.py b/scripts/ltx2_test_full_pipeline_i2v.py index 01b18e5eb8..8c39647eae 100644 --- a/scripts/ltx2_test_full_pipeline_i2v.py +++ b/scripts/ltx2_test_full_pipeline_i2v.py @@ -1,12 +1,11 @@ - import argparse import os from fractions import Fraction from typing import Optional -from PIL import Image import av # Needs to be installed separately (`pip install av`) import torch +from PIL import Image from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline @@ -131,7 +130,7 @@ def parse_args(): parser.add_argument( "--negative_prompt", type=str, - default="shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." + default="shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static.", ) parser.add_argument("--num_inference_steps", type=int, default=40) @@ -166,7 +165,9 @@ def parse_args(): def main(args): pipeline = LTX2ImageToVideoPipeline.from_pretrained( - args.model_id, revision=args.revision, torch_dtype=args.dtype, + args.model_id, + revision=args.revision, + torch_dtype=args.dtype, ) if args.cpu_offload: pipeline.enable_model_cpu_offload() @@ -201,6 +202,6 @@ def main(args): ) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() main(args) diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py index a6ba16ed9e..3aa2a65d3f 100644 --- a/scripts/test_ltx2_audio_conversion.py +++ b/scripts/test_ltx2_audio_conversion.py @@ -90,7 +90,7 @@ def main() -> None: latent_width, device=device, dtype=dtype, - generator=torch.Generator(device).manual_seed(42) + generator=torch.Generator(device).manual_seed(42), ) original_out = original_decoder(dummy) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2e99ea8063..9c9ade9154 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -193,9 +193,9 @@ else: "AutoencoderKLHunyuanImageRefiner", "AutoencoderKLHunyuanVideo", "AutoencoderKLHunyuanVideo15", - "AutoencoderKLLTXVideo", "AutoencoderKLLTX2Audio", "AutoencoderKLLTX2Video", + "AutoencoderKLLTXVideo", "AutoencoderKLMagvit", "AutoencoderKLMochi", "AutoencoderKLQwenImage", @@ -237,8 +237,8 @@ else: "Kandinsky3UNet", "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", - "LTXVideoTransformer3DModel", "LTX2VideoTransformer3DModel", + "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", "MochiTransformer3DModel", @@ -533,12 +533,12 @@ else: "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LTX2ImageToVideoPipeline", + "LTX2Pipeline", "LTXConditionPipeline", "LTXImageToVideoPipeline", "LTXLatentUpsamplePipeline", "LTXPipeline", - "LTX2Pipeline", - "LTX2ImageToVideoPipeline", "LucyEditPipeline", "Lumina2Pipeline", "Lumina2Text2ImgPipeline", @@ -931,9 +932,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, - AutoencoderKLLTXVideo, AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, + AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, @@ -975,8 +976,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: Kandinsky3UNet, Kandinsky5Transformer3DModel, LatteTransformer3DModel, - LTXVideoTransformer3DModel, LTX2VideoTransformer3DModel, + LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, @@ -1241,12 +1242,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LTX2ImageToVideoPipeline, + LTX2Pipeline, LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline, - LTX2Pipeline, - LTX2ImageToVideoPipeline, LucyEditPipeline, Lumina2Pipeline, Lumina2Text2ImgPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d3bcb3bcee..4d372e1112 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -154,9 +154,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, - AutoencoderKLLTXVideo, AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, + AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, @@ -213,8 +213,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: HunyuanVideoTransformer3DModel, Kandinsky5Transformer3DModel, LatteTransformer3DModel, - LTXVideoTransformer3DModel, LTX2VideoTransformer3DModel, + LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 38d52f0eb5..8e7a9c81d2 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -10,8 +10,8 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo -from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video +from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index df59e2d748..2d55f166c6 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -25,7 +25,6 @@ from ..activations import get_activation from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from ..normalization import RMSNorm from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution @@ -33,8 +32,8 @@ class PerChannelRMSNorm(nn.Module): """ Per-pixel (per-location) RMS normalization layer. 
- For each element along the chosen dimension, this layer normalizes the tensor - by the root-mean-square of its values across that dimension: + For each element along the chosen dimension, this layer normalizes the tensor by the root-mean-square of its values + across that dimension: y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) """ @@ -174,9 +173,7 @@ class LTX2VideoResnetBlock3d(nn.Module): if in_channels != out_channels: self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d - self.conv_shortcut = nn.Conv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1 - ) + self.conv_shortcut = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1) self.per_channel_scale1 = None self.per_channel_scale2 = None @@ -953,7 +950,10 @@ class LTX2VideoDecoder3d(nn.Module): self.gradient_checkpointing = False def forward( - self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, causal: Optional[bool] = None, + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, ) -> torch.Tensor: causal = causal or self.is_causal @@ -1279,7 +1279,8 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig if self.use_slicing and z.shape[0] > 1: if temb is not None: decoded_slices = [ - self._decode(z_slice, t_slice, causal=causal).sample for z_slice, t_slice in (z.split(1), temb.split(1)) + self._decode(z_slice, t_slice, causal=causal).sample + for z_slice, t_slice in (z.split(1), temb.split(1)) ] else: decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)] diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 8cdcfa1a74..091d55645a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -249,6 +249,7 @@ class LTX2AudioUpsample(nn.Module): return x + class LTX2AudioAudioPatchifier: """ Patchifier for spectrogram/audio latents. @@ -405,9 +406,7 @@ class LTX2AudioDecoder(nn.Module): final_block_channels = block_in if self.norm_type == "group": - self.norm_out = nn.GroupNorm( - num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True - ) + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True) elif self.norm_type == "pixel": self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) else: @@ -538,8 +537,8 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): # Per-channel statistics for normalizing and denormalizing the latent representation. 
This statics is computed over # the entire dataset and stored in model's checkpoint under AudioVAE state_dict - latents_std = torch.zeros((base_channels, )) - latents_mean = torch.ones((base_channels, )) + latents_std = torch.zeros((base_channels,)) + latents_mean = torch.ones((base_channels,)) self.register_buffer("latents_mean", latents_mean, persistent=True) self.register_buffer("latents_std", latents_std, persistent=True) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 4e3cd84ec7..2182a59cd0 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -22,16 +22,21 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + is_torch_version, + logging, + scale_lora_layers, + unscale_lora_layers, +) from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin -from ..embeddings import PixArtAlphaTextProjection, PixArtAlphaCombinedTimestepSizeEmbeddings -from ..modeling_outputs import Transformer2DModelOutput +from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle, RMSNorm +from ..normalization import RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -44,6 +49,7 @@ def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, tor out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out + def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: cos, sin = freqs @@ -65,7 +71,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten # (..., 2, r) split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float - first_x = split_x[..., :1, :] # (..., 1, r) + first_x = split_x[..., :1, :] # (..., 1, r) second_x = split_x[..., 1:, :] # (..., 1, r) cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r) @@ -89,6 +95,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten ROTARY_FN_MAP = {"interleaved": apply_interleaved_rotary_emb, "split": apply_split_rotary_emb} + @dataclass class AudioVisualModelOutput(BaseOutput): r""" @@ -192,7 +199,9 @@ class LTX2AudioVideoAttnProcessor: if query_rotary_emb is not None: query = ROTARY_FN_MAP[attn.rope_type](query, query_rotary_emb) - key = ROTARY_FN_MAP[attn.rope_type](key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) + key = ROTARY_FN_MAP[attn.rope_type]( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) query = query.unflatten(2, (attn.heads, -1)) key = key.unflatten(2, (attn.heads, -1)) @@ -368,7 +377,7 @@ class LTX2VideoTransformerBlock(nn.Module): bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, - rope_type=rope_type + rope_type=rope_type, ) self.audio_norm2 = RMSNorm(audio_dim, eps=eps, 
elementwise_affine=elementwise_affine) @@ -381,7 +390,7 @@ class LTX2VideoTransformerBlock(nn.Module): bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, - rope_type=rope_type + rope_type=rope_type, ) # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention @@ -396,7 +405,7 @@ class LTX2VideoTransformerBlock(nn.Module): bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, - rope_type=rope_type + rope_type=rope_type, ) # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video @@ -481,7 +490,9 @@ class LTX2VideoTransformerBlock(nn.Module): audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( batch_size, temb_audio.size(1), num_audio_ada_params, -1 ) - audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = audio_ada_values.unbind(dim=2) + audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( + audio_ada_values.unbind(dim=2) + ) norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa attn_audio_hidden_states = self.audio_attn1( @@ -550,8 +561,12 @@ class LTX2VideoTransformerBlock(nn.Module): if use_a2v_cross_attn: # Audio-to-Video Cross Attention: Q: Video; K,V: Audio - mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(2) - mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_a2v_ca_scale.squeeze(2)) + audio_a2v_ca_shift.squeeze(2) + mod_norm_hidden_states = norm_hidden_states * ( + 1 + video_a2v_ca_scale.squeeze(2) + ) + video_a2v_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_a2v_ca_scale.squeeze(2) + ) + audio_a2v_ca_shift.squeeze(2) a2v_attn_hidden_states = self.audio_to_video_attn( mod_norm_hidden_states, @@ -565,8 +580,12 @@ class LTX2VideoTransformerBlock(nn.Module): if use_v2a_cross_attn: # Video-to-Audio Cross Attention: Q: Audio; K,V: Video - mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(2) - mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_v2a_ca_scale.squeeze(2)) + audio_v2a_ca_shift.squeeze(2) + mod_norm_hidden_states = norm_hidden_states * ( + 1 + video_v2a_ca_scale.squeeze(2) + ) + video_v2a_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_v2a_ca_scale.squeeze(2) + ) + audio_v2a_ca_shift.squeeze(2) v2a_attn_hidden_states = self.video_to_audio_attn( mod_norm_audio_hidden_states, @@ -596,9 +615,10 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): Args: causal_offset (`int`, *optional*, defaults to `1`): - Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where - the VAE treats the very first frame differently), but could also be 0 (for non-causal modeling). + Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE + treats the very first frame differently), but could also be 0 (for non-causal modeling). """ + def __init__( self, dim: int, @@ -658,9 +678,9 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): fps: float = 25.0, ) -> torch.Tensor: """ - Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original - pixel space video grid (num_frames, height, width). 
This will ultimately have shape (batch_size, 3, - num_patches, 2) where + Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel + space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, num_patches, 2) + where - axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames) - axis 3 (size 2) stores `[start, end)` indices within each dimension @@ -727,8 +747,8 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): shift: int = 0, ) -> torch.Tensor: """ - Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent - frame. This will ultimately have shape (batch_size, 3, num_patches, 2) where + Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent frame. + This will ultimately have shape (batch_size, 3, num_patches, 2) where - axis 1 (size 1) represents the temporal dimension - axis 3 (size 2) stores `[start, end)` indices within each dimension @@ -763,7 +783,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): # Handle first frame causal offset, ensuring non-negative timestamps grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0) # Convert mel bins back into seconds - grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate + grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate # 3. Calculate start timstamps in seconds with respect to the original spectrogram grid grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor @@ -862,7 +882,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems]) cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) - + elif self.rope_type == "split": expected_freqs = self.dim // 2 current_freqs = freqs.shape[-1] @@ -1087,7 +1107,7 @@ class LTX2VideoTransformer3DModel( modality="audio", double_precision=rope_double_precision, rope_type=rope_type, - num_attention_heads=audio_num_attention_heads + num_attention_heads=audio_num_attention_heads, ) # 5. Transformer Blocks @@ -1154,7 +1174,7 @@ class LTX2VideoTransformer3DModel( encoder_hidden_states (`torch.Tensor`): Input text embeddings of shape TODO. TODO for the rest. - + Returns: `AudioVisualModelOutput` or `tuple`: If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a @@ -1204,14 +1224,18 @@ class LTX2VideoTransformer3DModel( audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) - audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :], device=audio_hidden_states.device) + audio_cross_attn_rotary_emb = self.cross_attn_audio_rope( + audio_coords[:, 0:1, :], device=audio_hidden_states.device + ) # 2. Patchify input projections hidden_states = self.proj_in(hidden_states) audio_hidden_states = self.audio_proj_in(audio_hidden_states) # 3. Prepare timestep embeddings and modulation parameters - timestep_cross_attn_gate_scale_factor = self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier + timestep_cross_attn_gate_scale_factor = ( + self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier + ) # 3.1. 
Prepare global modality (video and audio) timestep embedding and modulation parameters # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer @@ -1243,7 +1267,9 @@ class LTX2VideoTransformer3DModel( batch_size=batch_size, hidden_dtype=hidden_states.dtype, ) - video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(batch_size, -1, video_cross_attn_scale_shift.shape[-1]) + video_cross_attn_scale_shift = video_cross_attn_scale_shift.view( + batch_size, -1, video_cross_attn_scale_shift.shape[-1] + ) video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( @@ -1256,7 +1282,9 @@ class LTX2VideoTransformer3DModel( batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) - audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(batch_size, -1, audio_cross_attn_scale_shift.shape[-1]) + audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view( + batch_size, -1, audio_cross_attn_scale_shift.shape[-1] + ) audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) # 4. Prepare prompt embeddings diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index eaf444d5ec..39c8ce6623 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -720,7 +720,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: LEditsPPPipelineStableDiffusionXL, ) from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline - from .ltx2 import LTX2Pipeline, LTX2ImageToVideoPipeline + from .ltx2 import LTX2ImageToVideoPipeline, LTX2Pipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index 95d5f8d4a4..2760f8f7fe 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -22,9 +22,9 @@ except OptionalDependencyNotAvailable: _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: + _import_structure["connectors"] = ["LTX2TextConnectors"] _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] - _import_structure["connectors"] = ["LTX2TextConnectors"] _import_structure["vocoder"] = ["LTX2Vocoder"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -35,9 +35,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: + from .connectors import LTX2TextConnectors from .pipeline_ltx2 import LTX2Pipeline from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline - from .connectors import LTX2TextConnectors from .vocoder import LTX2Vocoder else: diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py index c146c9833e..2608c2783f 100644 --- a/src/diffusers/pipelines/ltx2/connectors.py +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -9,6 +9,7 @@ from ...models.attention import FeedForward from ...models.modeling_utils import ModelMixin from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor + class LTX2RotaryPosEmbed1d(nn.Module): """ 1D 
rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. @@ -21,12 +22,12 @@ class LTX2RotaryPosEmbed1d(nn.Module): theta: float = 10000.0, double_precision: bool = True, rope_type: str = "interleaved", - num_attention_heads: int = 32 + num_attention_heads: int = 32, ): super().__init__() if rope_type not in ["interleaved", "split"]: raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.") - + self.dim = dim self.base_seq_len = base_seq_len self.theta = theta @@ -69,7 +70,7 @@ class LTX2RotaryPosEmbed1d(nn.Module): sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) - + elif self.rope_type == "split": expected_freqs = self.dim // 2 current_freqs = freqs.shape[-1] @@ -116,7 +117,7 @@ class LTX2TransformerBlock1d(nn.Module): kv_heads=num_attention_heads, dim_head=attention_head_dim, processor=LTX2AudioVideoAttnProcessor(), - rope_type=rope_type + rope_type=rope_type, ) self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) @@ -159,7 +160,7 @@ class LTX2ConnectorTransformer1d(nn.Module): rope_double_precision: bool = True, eps: float = 1e-6, causal_temporal_positioning: bool = False, - rope_type: str = "interleaved" + rope_type: str = "interleaved", ): super().__init__() self.num_attention_heads = num_attention_heads @@ -173,12 +174,12 @@ class LTX2ConnectorTransformer1d(nn.Module): self.learnable_registers = torch.nn.Parameter(init_registers) self.rope = LTX2RotaryPosEmbed1d( - self.inner_dim, - base_seq_len=rope_base_seq_len, - theta=rope_theta, + self.inner_dim, + base_seq_len=rope_base_seq_len, + theta=rope_theta, double_precision=rope_double_precision, rope_type=rope_type, - num_attention_heads=num_attention_heads + num_attention_heads=num_attention_heads, ) self.transformer_blocks = torch.nn.ModuleList( @@ -187,7 +188,7 @@ class LTX2ConnectorTransformer1d(nn.Module): dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, - rope_type=rope_type + rope_type=rope_type, ) for _ in range(num_layers) ] @@ -253,8 +254,8 @@ class LTX2ConnectorTransformer1d(nn.Module): class LTX2TextConnectors(ModelMixin, ConfigMixin): """ - Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and - audio streams. + Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio + streams. 
""" @register_to_config @@ -287,7 +288,7 @@ class LTX2TextConnectors(ModelMixin, ConfigMixin): rope_theta=rope_theta, rope_double_precision=rope_double_precision, causal_temporal_positioning=causal_temporal_positioning, - rope_type=rope_type + rope_type=rope_type, ) self.audio_connector = LTX2ConnectorTransformer1d( num_attention_heads=audio_connector_num_attention_heads, @@ -298,7 +299,7 @@ class LTX2TextConnectors(ModelMixin, ConfigMixin): rope_theta=rope_theta, rope_double_precision=rope_double_precision, causal_temporal_positioning=causal_temporal_positioning, - rope_type=rope_type + rope_type=rope_type, ) def forward( diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 08fad91c41..7cbcca67d2 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -674,7 +674,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: duration_s = num_frames / frame_rate - latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + latents_per_second = ( + float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + ) latent_length = int(duration_s * latents_per_second) if latents is not None: @@ -995,7 +997,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = latent_model_input.to(prompt_embeds.dtype) - audio_latent_model_input = torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -1026,10 +1030,14 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix if self.do_classifier_free_guidance: noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) - noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (noise_pred_video_text - noise_pred_video_uncond) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) - noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) if self.guidance_rescale > 0: # Based on 3.4. in https://huggingface.co/papers/2305.08891 diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index caad9a1767..0a707806ce 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -13,25 +13,26 @@ # limitations under the License. 
import copy -from typing import Any, Callable, Dict, List, Optional, Union import inspect +from typing import Any, Callable, Dict, List, Optional, Union + import numpy as np import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast -from ...schedulers import FlowMatchEulerDiscreteScheduler from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput -from ...utils import is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import randn_tensor -from .connectors import LTX2TextConnectors -from .pipeline_output import LTX2PipelineOutput -from ..pipeline_utils import DiffusionPipeline -from .vocoder import LTX2Vocoder from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video from ...models.transformers import LTX2VideoTransformer3DModel -from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .connectors import LTX2TextConnectors +from .pipeline_output import LTX2PipelineOutput +from .vocoder import LTX2Vocoder if is_torch_xla_available(): @@ -86,6 +87,7 @@ def retrieve_latents( else: raise AttributeError("Could not access latents of provided encoder_output") + # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, @@ -665,7 +667,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL shape = (batch_size, num_channels_latents, num_frames, height, width) mask_shape = (batch_size, 1, num_frames, height, width) - + if latents is not None: conditioning_mask = latents.new_zeros(mask_shape) conditioning_mask[:, :, 0] = 1.0 @@ -697,7 +699,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL init_latents = torch.cat(init_latents, dim=0).to(dtype) init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) - + # First condition is image latents and those should be kept clean. 
conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) conditioning_mask[:, :, 0] = 1.0 @@ -731,7 +733,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: duration_s = num_frames / frame_rate - latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + latents_per_second = ( + float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + ) latent_length = int(duration_s * latents_per_second) if latents is not None: @@ -982,7 +986,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL ) if self.do_classifier_free_guidance: conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) - + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio @@ -1063,12 +1067,14 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = latent_model_input.to(prompt_embeds.dtype) - audio_latent_model_input = torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) timestep = t.expand(latent_model_input.shape[0]) video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) - + with self.transformer.cache_context("cond_uncond"): noise_pred_video, noise_pred_audio = self.transformer( hidden_states=latent_model_input, @@ -1095,10 +1101,14 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL if self.do_classifier_free_guidance: noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) - noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (noise_pred_video_text - noise_pred_video_uncond) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) - noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) if self.guidance_rescale > 0: # Based on 3.4. 
in https://huggingface.co/papers/2305.08891 diff --git a/src/diffusers/pipelines/ltx2/vocoder.py b/src/diffusers/pipelines/ltx2/vocoder.py index c3b3c1f367..217c68103e 100644 --- a/src/diffusers/pipelines/ltx2/vocoder.py +++ b/src/diffusers/pipelines/ltx2/vocoder.py @@ -25,32 +25,18 @@ class ResBlock(nn.Module): self.convs1 = nn.ModuleList( [ - nn.Conv1d( - channels, - channels, - kernel_size, - stride=stride, - dilation=dilation, - padding=padding_mode - ) + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode) for dilation in dilations ] ) self.convs2 = nn.ModuleList( [ - nn.Conv1d( - channels, - channels, - kernel_size, - stride=stride, - dilation=1, - padding=padding_mode - ) + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode) for _ in range(len(dilations)) ] ) - + def forward(self, x: torch.Tensor) -> torch.Tensor: for conv1, conv2 in zip(self.convs1, self.convs2): xt = F.leaky_relu(x, negative_slope=self.negative_slope) @@ -127,7 +113,7 @@ class LTX2Vocoder(ModelMixin, ConfigMixin): input_channels = output_channels self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) - + def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: r""" Forward pass of the vocoder. diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8628893200..54746ecb58 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -502,6 +502,36 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class AutoencoderKLLTX2Audio(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoencoderKLLTX2Video(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLLTXVideo(metaclass=DummyObject): _backends = ["torch"] @@ -1132,6 +1162,21 @@ class LatteTransformer3DModel(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class LTX2VideoTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LTXVideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index da64742518..50a88afbb2 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1802,6 +1802,36 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class LTX2ImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, 
*args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LTX2Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXConditionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py index 25984d621a..146241361a 100644 --- a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py @@ -15,8 +15,6 @@ import unittest -import torch - from diffusers import AutoencoderKLLTX2Video from ...testing_utils import ( diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index 1b0a7dd28f..8a6b50b55e 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -52,9 +52,9 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): sequence_length = 16 hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) - audio_hidden_states = torch.randn( - (batch_size, audio_num_frames, audio_num_channels * num_mel_bins) - ).to(torch_device) + audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to( + torch_device + ) encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) diff --git a/tests/pipelines/ltx2/__init__.py b/tests/pipelines/ltx2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/ltx2/test_ltx2.py b/tests/pipelines/ltx2/test_ltx2.py new file mode 100644 index 0000000000..73d08e6b1a --- /dev/null +++ b/tests/pipelines/ltx2/test_ltx2.py @@ -0,0 +1,239 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2Pipeline, + LTX2VideoTransformer3DModel, +) +from diffusers.pipelines.ltx2 import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder + +from ...testing_utils import enable_full_determinism +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTX2Pipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "audio_latents", + "output_type", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_attention_slicing = False + test_xformers_attention = False + supports_dduf = False + + base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3" + + def get_dummy_components(self): + tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id) + + torch.manual_seed(0) + transformer = LTX2VideoTransformer3DModel( + in_channels=4, + out_channels=4, + patch_size=1, + patch_size_t=1, + num_attention_heads=2, + attention_head_dim=8, + cross_attention_dim=16, + audio_in_channels=4, + audio_out_channels=4, + audio_num_attention_heads=2, + audio_attention_head_dim=4, + audio_cross_attention_dim=8, + num_layers=2, + qk_norm="rms_norm_across_heads", + caption_channels=text_encoder.config.text_config.hidden_size, + rope_double_precision=False, + rope_type="split", + ) + + torch.manual_seed(0) + connectors = LTX2TextConnectors( + caption_channels=text_encoder.config.text_config.hidden_size, + text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1, + video_connector_num_attention_heads=4, + video_connector_attention_head_dim=8, + video_connector_num_layers=1, + video_connector_num_learnable_registers=None, + audio_connector_num_attention_heads=4, + audio_connector_attention_head_dim=8, + audio_connector_num_layers=1, + audio_connector_num_learnable_registers=None, + connector_rope_base_seq_len=32, + rope_theta=10000.0, + rope_double_precision=False, + causal_temporal_positioning=False, + rope_type="split", + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTX2Video( + in_channels=3, + out_channels=3, + latent_channels=4, + block_out_channels=(8,), + decoder_block_out_channels=(8,), + layers_per_block=(1,), + decoder_layers_per_block=(1, 1), + spatio_temporal_scaling=(True,), + decoder_spatio_temporal_scaling=(True,), + decoder_inject_noise=(False, False), + downsample_type=("spatial",), + upsample_residual=(False,), + upsample_factor=(1,), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + audio_vae = AutoencoderKLLTX2Audio( + base_channels=4, + output_channels=2, + ch_mult=(1,), + num_res_blocks=1, + attn_resolutions=None, + in_channels=2, + 
resolution=32, + latent_channels=2, + norm_type="pixel", + causality_axis="height", + dropout=0.0, + mid_block_add_attention=False, + sample_rate=16000, + mel_hop_length=160, + is_causal=True, + mel_bins=8, + ) + + torch.manual_seed(0) + vocoder = LTX2Vocoder( + in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins, + hidden_channels=32, + out_channels=2, + upsample_kernel_sizes=[4, 4], + upsample_factors=[2, 2], + resnet_kernel_sizes=[3], + resnet_dilations=[[1, 3, 5]], + leaky_relu_negative_slope=0.1, + output_sampling_rate=16000, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + components = { + "transformer": transformer, + "vae": vae, + "audio_vae": audio_vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "connectors": connectors, + "vocoder": vocoder, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "a robot dancing", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "num_frames": 5, + "frame_rate": 25.0, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = pipe(**inputs) + video = output.frames + audio = output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.4331, 0.6203, 0.3245, 0.7294, 0.4822, 0.5703, 0.2999, 0.7700, 0.4961, 0.4242, 0.4581, 0.4351, 0.1137, 0.4437, 0.6304, 0.3184 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0229, 0.0503, 0.1220, 0.1083, 0.1745, 0.1075, 0.1779, 0.0974, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2) diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py new file mode 100644 index 0000000000..9c58b4fc41 --- /dev/null +++ b/tests/pipelines/ltx2/test_ltx2_image2video.py @@ -0,0 +1,241 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2ImageToVideoPipeline, + LTX2VideoTransformer3DModel, +) +from diffusers.pipelines.ltx2 import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder + +from ...testing_utils import enable_full_determinism +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTX2ImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "audio_latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_attention_slicing = False + test_xformers_attention = False + supports_dduf = False + + base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3" + + def get_dummy_components(self): + tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id) + + torch.manual_seed(0) + transformer = LTX2VideoTransformer3DModel( + in_channels=4, + out_channels=4, + patch_size=1, + patch_size_t=1, + num_attention_heads=2, + attention_head_dim=8, + cross_attention_dim=16, + audio_in_channels=4, + audio_out_channels=4, + audio_num_attention_heads=2, + audio_attention_head_dim=4, + audio_cross_attention_dim=8, + num_layers=2, + qk_norm="rms_norm_across_heads", + caption_channels=text_encoder.config.text_config.hidden_size, + rope_double_precision=False, + rope_type="split", + ) + + torch.manual_seed(0) + connectors = LTX2TextConnectors( + caption_channels=text_encoder.config.text_config.hidden_size, + text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1, + video_connector_num_attention_heads=4, + video_connector_attention_head_dim=8, + video_connector_num_layers=1, + video_connector_num_learnable_registers=None, + audio_connector_num_attention_heads=4, + audio_connector_attention_head_dim=8, + audio_connector_num_layers=1, + audio_connector_num_learnable_registers=None, + connector_rope_base_seq_len=32, + rope_theta=10000.0, + rope_double_precision=False, + causal_temporal_positioning=False, + rope_type="split", + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTX2Video( + in_channels=3, + out_channels=3, + latent_channels=4, + block_out_channels=(8,), + decoder_block_out_channels=(8,), + layers_per_block=(1,), + decoder_layers_per_block=(1, 1), + spatio_temporal_scaling=(True,), + decoder_spatio_temporal_scaling=(True,), + decoder_inject_noise=(False, False), + downsample_type=("spatial",), + upsample_residual=(False,), + upsample_factor=(1,), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + audio_vae = AutoencoderKLLTX2Audio( + base_channels=4, + output_channels=2, + ch_mult=(1,), + num_res_blocks=1, + 
attn_resolutions=None, + in_channels=2, + resolution=32, + latent_channels=2, + norm_type="pixel", + causality_axis="height", + dropout=0.0, + mid_block_add_attention=False, + sample_rate=16000, + mel_hop_length=160, + is_causal=True, + mel_bins=8, + ) + + torch.manual_seed(0) + vocoder = LTX2Vocoder( + in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins, + hidden_channels=32, + out_channels=2, + upsample_kernel_sizes=[4, 4], + upsample_factors=[2, 2], + resnet_kernel_sizes=[3], + resnet_dilations=[[1, 3, 5]], + leaky_relu_negative_slope=0.1, + output_sampling_rate=16000, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + components = { + "transformer": transformer, + "vae": vae, + "audio_vae": audio_vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "connectors": connectors, + "vocoder": vocoder, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = torch.rand((1, 3, 32, 32), generator=generator, device=device) + + inputs = { + "image": image, + "prompt": "a robot dancing", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "num_frames": 5, + "frame_rate": 25.0, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = pipe(**inputs) + video = output.frames + audio = output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.3573, 0.8382, 0.3581, 0.6114, 0.3682, 0.7969, 0.2552, 0.6399, 0.3113, 0.1497, 0.3249, 0.5395, 0.3498, 0.4526, 0.4536, 0.4555 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0229, 0.0503, 0.1220, 0.1083, 0.1745, 0.1075, 0.1779, 0.0974, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)
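
As a quick end-to-end sanity check for the new text-to-video path added in this diff, below is a minimal usage sketch assembled from the calling conventions exercised by the fast tests and example scripts above. The checkpoint path, dtype, resolution, and frame count are illustrative placeholders, not values shipped with this change.

import torch

from diffusers import LTX2Pipeline

# "path/to/converted-ltx2" is a placeholder for a directory produced by
# scripts/convert_ltx2_to_diffusers.py when run with --full_pipeline.
pipe = LTX2Pipeline.from_pretrained("path/to/converted-ltx2", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # optional; mirrors the --cpu_offload flag in the test scripts

output = pipe(
    prompt="a robot dancing",
    negative_prompt="shaky, glitchy, low quality, worst quality",
    num_inference_steps=40,  # default used by the example scripts
    num_frames=121,          # illustrative frame count
    frame_rate=25.0,
    height=512,              # illustrative resolution
    width=768,
    output_type="pt",
    generator=torch.Generator("cpu").manual_seed(42),
)
video = output.frames  # video tensor of shape (batch, frames, channels, height, width), per the fast tests
audio = output.audio   # waveform whose leading dims (batch, vocoder out_channels) are asserted in the fast tests

The image-to-video variant follows the same pattern with an additional `image` input, as shown in scripts/ltx2_test_full_pipeline_i2v.py.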