From bda3ff13dbc895365fb6b3fcbb800df5f1844ecf Mon Sep 17 00:00:00 2001
From: Daniel Gu
Date: Tue, 16 Dec 2025 10:53:43 +0100
Subject: [PATCH] Fix LTX 2 transformer bugs so consistency test passes

---
 .../models/transformers/transformer_ltx2.py | 22 +++++++++++++------
 .../test_models_transformer_ltx2.py         |  4 ----
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py
index 2ce6106eec..ea9bca115e 100644
--- a/src/diffusers/models/transformers/transformer_ltx2.py
+++ b/src/diffusers/models/transformers/transformer_ltx2.py
@@ -456,7 +456,7 @@ class LTX2VideoTransformerBlock(nn.Module):
             query_rotary_emb=None,
             attention_mask=audio_encoder_attention_mask,
         )
-        hidden_states = hidden_states + attn_hidden_states
+        audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
 
         # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
         if use_a2v_cross_attn or use_v2a_cross_attn:
@@ -557,7 +557,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         base_width: int = 2048,
         sampling_rate: int = 16000,
         hop_length: int = 160,
-        scale_factors: Tuple[int, ...] = (8, 32 ,32),
+        scale_factors: Tuple[int, ...] = (8, 32, 32),
         theta: float = 10000.0,
         causal_offset: int = 1,
         modality: str = "video",
@@ -594,6 +594,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         height: int,
         width: int,
         device: torch.device,
+        fps: float = 25.0,
     ) -> torch.Tensor:
         """
         Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original
@@ -651,6 +652,9 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         # and clamp to keep the first-frame timestamps causal and non-negative.
         pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0)
 
+        # Scale the temporal coordinates by the video FPS
+        pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps
+
         return pixel_coords
 
     def prepare_audio_coords(
@@ -742,15 +746,15 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
                 height,
                 width,
                 device=device,
+                fps=fps,
             )
-            # Scale the temporal coordinates by the video FPS
-            coords[:, 0, ...] = coords[:, 0, ...] / fps
         elif coords is None and self.modality == "audio":
             coords = self.prepare_audio_coords(
                 batch_size,
                 num_frames,
                 device=device,
                 shift=shift,
+                fps=fps,
             )
         # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn)
         num_pos_dims = coords.shape[1]
@@ -1086,9 +1090,13 @@ class LTX2VideoTransformer3DModel(
 
         # 1. Prepare RoPE positional embeddings
         if video_coords is None:
-            video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, hidden_states.device)
+            video_coords = self.rope.prepare_video_coords(
+                batch_size, num_frames, height, width, hidden_states.device, fps=fps
+            )
         if audio_coords is None:
-            audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames, audio_hidden_states.device)
+            audio_coords = self.audio_rope.prepare_audio_coords(
+                batch_size, audio_num_frames, audio_hidden_states.device, fps=fps
+            )
 
         video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device)
         audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)
@@ -1104,7 +1112,7 @@ class LTX2VideoTransformer3DModel(
         # Scale timestep
         timestep = timestep * timestep_scale_multiplier
         timestep_cross_attn_gate_scale_factor = cross_attn_timestep_scale_multiplier / timestep_scale_multiplier
-        
+
         # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters
         # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer
         # modulation with scale_shift_table (and similarly for audio)
diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py
index 0bf08f161d..6c0b97c589 100644
--- a/tests/models/transformers/test_models_transformer_ltx2.py
+++ b/tests/models/transformers/test_models_transformer_ltx2.py
@@ -207,14 +207,10 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
 
         video_output_flat = video_output.cpu().flatten().float()
         video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
-        print(f"Video Expected Slice: {video_expected_slice}")
-        print(f"Video Generated Slice: {video_generated_slice}")
         self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))
 
         audio_output_flat = audio_output.cpu().flatten().float()
         audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
-        print(f"Audio Expected Slice: {audio_expected_slice}")
-        print(f"Audio Generated Slice: {audio_generated_slice}")
         self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))
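
Note: a minimal sketch of the fps-scaled temporal RoPE coordinates that this patch moves into prepare_video_coords. The helper name, the temporal scale factor of 8, the causal offset of 1, and the 25.0 fps default are illustrative assumptions, not the diffusers API.

    import torch

    def temporal_coords_in_seconds(
        num_latent_frames: int,
        temporal_scale: int = 8,    # assumed temporal scale factor (scale_factors[0])
        causal_offset: int = 1,     # assumed causal offset, as in the patched module
        fps: float = 25.0,          # default fps, matching the new prepare_video_coords argument
    ) -> torch.Tensor:
        # Map latent frame indices to pixel-frame indices.
        frame_idx = torch.arange(num_latent_frames, dtype=torch.float32) * temporal_scale
        # Shift and clamp so the first-frame timestamp stays causal and non-negative,
        # mirroring the clamp already present in prepare_video_coords.
        frame_idx = (frame_idx + causal_offset - temporal_scale).clamp(min=0)
        # Dividing by fps converts frame indices into timestamps in seconds; the patch
        # performs this step inside prepare_video_coords rather than in the caller.
        return frame_idx / fps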