mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix LTX 2 transformer bugs so consistency test passes

Author: Daniel Gu
Date:   2025-12-16 10:53:43 +01:00
Commit: bda3ff13db (parent a7bc052e89)

2 changed files with 15 additions and 11 deletions


@@ -456,7 +456,7 @@ class LTX2VideoTransformerBlock(nn.Module):
             query_rotary_emb=None,
             attention_mask=audio_encoder_attention_mask,
         )
-        hidden_states = hidden_states + attn_hidden_states
+        audio_hidden_states = audio_hidden_states + attn_audio_hidden_states

         # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
         if use_a2v_cross_attn or use_v2a_cross_attn:
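
In the hunk above, the audio self-attention result was being lost: the old line added attn_hidden_states into the video stream instead of adding attn_audio_hidden_states into the audio stream. A minimal sketch of the intended two-stream residual pattern (shapes invented for illustration):

import torch

hidden_states = torch.randn(1, 128, 32)            # video token stream
audio_hidden_states = torch.randn(1, 64, 32)       # audio token stream
attn_audio_hidden_states = torch.randn(1, 64, 32)  # audio self-attention output

# Each modality's attention output must feed back into its own residual stream.
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states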
@@ -557,7 +557,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         base_width: int = 2048,
         sampling_rate: int = 16000,
         hop_length: int = 160,
-        scale_factors: Tuple[int, ...] = (8, 32 ,32),
+        scale_factors: Tuple[int, ...] = (8, 32, 32),
         theta: float = 10000.0,
         causal_offset: int = 1,
         modality: str = "video",
@@ -594,6 +594,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         height: int,
         width: int,
         device: torch.device,
+        fps: float = 25.0,
     ) -> torch.Tensor:
         """
         Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original
@@ -651,6 +652,9 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         # and clamp to keep the first-frame timestamps causal and non-negative.
         pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0)

+        # Scale the temporal coordinates by the video FPS
+        pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps
+
         return pixel_coords

     def prepare_audio_coords(
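
The two added lines convert the temporal coordinates from pixel-frame indices to seconds. A toy walk-through using the defaults visible in this diff (temporal scale factor 8, causal_offset 1, fps 25.0):

import torch

temporal_scale, causal_offset, fps = 8, 1, 25.0

frame_coords = torch.arange(4) * temporal_scale                               # tensor([ 0,  8, 16, 24])
frame_coords = (frame_coords + causal_offset - temporal_scale).clamp(min=0)   # tensor([ 0,  1,  9, 17])
time_coords = frame_coords / fps                                              # frame index -> seconds
print(time_coords)                                                            # tensor([0.0000, 0.0400, 0.3600, 0.6800])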
@@ -742,15 +746,15 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
                 height,
                 width,
                 device=device,
+                fps=fps,
             )
-            # Scale the temporal coordinates by the video FPS
-            coords[:, 0, ...] = coords[:, 0, ...] / fps
         elif coords is None and self.modality == "audio":
             coords = self.prepare_audio_coords(
                 batch_size,
                 num_frames,
                 device=device,
                 shift=shift,
+                fps=fps,
             )
         # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn)
         num_pos_dims = coords.shape[1]
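
Moving the division by fps inside prepare_video_coords, and passing fps to prepare_audio_coords as well, puts both temporal axes in the same unit (seconds). A sketch of why that alignment matters, using the defaults from this file (sampling_rate=16000, hop_length=160, fps=25.0):

sampling_rate, hop_length, fps = 16000, 160, 25.0

audio_dt = hop_length / sampling_rate  # 0.01 s per audio latent frame
video_dt = 1.0 / fps                   # 0.04 s per video frame
print(video_dt / audio_dt)             # 4.0 audio frames per video frame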
@@ -1086,9 +1090,13 @@ class LTX2VideoTransformer3DModel(
         # 1. Prepare RoPE positional embeddings
         if video_coords is None:
-            video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, hidden_states.device)
+            video_coords = self.rope.prepare_video_coords(
+                batch_size, num_frames, height, width, hidden_states.device, fps=fps
+            )

         if audio_coords is None:
-            audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames, audio_hidden_states.device)
+            audio_coords = self.audio_rope.prepare_audio_coords(
+                batch_size, audio_num_frames, audio_hidden_states.device, fps=fps
+            )

         video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device)
         audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)
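
With the coordinates prepared, self.rope and self.audio_rope turn them into rotary embeddings. As a generic illustration (not the exact LTX 2 code), positional values are typically multiplied by inverse frequencies derived from theta (10000.0 here) to produce the cos/sin rotation tables:

import torch

def rope_angles(coords: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # One angle per coordinate and frequency; callers take .cos() and .sin() of the result.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    return coords.float().unsqueeze(-1) * inv_freq

angles = rope_angles(torch.arange(8), dim=16)
cos, sin = angles.cos(), angles.sin()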
@@ -1104,7 +1112,7 @@ class LTX2VideoTransformer3DModel(
         # Scale timestep
         timestep = timestep * timestep_scale_multiplier
+        timestep_cross_attn_gate_scale_factor = cross_attn_timestep_scale_multiplier / timestep_scale_multiplier

         # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters
         # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer
         # modulation with scale_shift_table (and similarly for audio)
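
The gate scale factor divides out the multiplier already applied to timestep, which would let the cross-attention gates recover the timestep at their own scale. Spelled out with made-up values:

timestep = 0.5
timestep_scale_multiplier = 1000.0
cross_attn_timestep_scale_multiplier = 250.0

scaled_timestep = timestep * timestep_scale_multiplier                          # 500.0
gate_scale = cross_attn_timestep_scale_multiplier / timestep_scale_multiplier   # 0.25

# Applying the gate scale to the already-scaled timestep recovers the cross-attn scale:
assert scaled_timestep * gate_scale == timestep * cross_attn_timestep_scale_multiplier  # both 125.0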


@@ -207,14 +207,10 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
         video_output_flat = video_output.cpu().flatten().float()
         video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])

-        print(f"Video Expected Slice: {video_expected_slice}")
-        print(f"Video Generated Slice: {video_generated_slice}")
         self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))

         audio_output_flat = audio_output.cpu().flatten().float()
         audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])

-        print(f"Audio Expected Slice: {audio_expected_slice}")
-        print(f"Audio Generated Slice: {audio_generated_slice}")
         self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))
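
For reference, the slice-comparison pattern this consistency test relies on, reduced to a standalone sketch (the tensors here are placeholders, not real model outputs):

import torch

output = torch.randn(2, 4, 16)

output_flat = output.cpu().flatten().float()
generated_slice = torch.cat([output_flat[:8], output_flat[-8:]])

# In the real test, expected_slice holds 16 recorded reference values.
expected_slice = generated_slice.clone()
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)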