mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Fix LTX 2 transformer bugs so consistency test passes
@@ -456,7 +456,7 @@ class LTX2VideoTransformerBlock(nn.Module):
             query_rotary_emb=None,
             attention_mask=audio_encoder_attention_mask,
         )
         hidden_states = hidden_states + attn_hidden_states
         audio_hidden_states = audio_hidden_states + attn_audio_hidden_states

         # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
         if use_a2v_cross_attn or use_v2a_cross_attn:
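For context, the block updates both streams with the usual residual pattern: each attention output is added back onto its input stream rather than replacing it. A minimal runnable sketch of that update, with stand-in tensors and a stub attention function (all names here are hypothetical, not the module's real API):

    import torch

    def stub_attn(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        # Stand-in for a joint attention module; returns a tensor shaped like q.
        return torch.zeros_like(q)

    hidden_states = torch.randn(1, 128, 64)        # video tokens (batch, seq, dim)
    audio_hidden_states = torch.randn(1, 32, 64)   # audio tokens

    # Residual updates, mirroring the two lines in the hunk above:
    hidden_states = hidden_states + stub_attn(hidden_states, audio_hidden_states)
    audio_hidden_states = audio_hidden_states + stub_attn(audio_hidden_states, hidden_states)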
@@ -557,7 +557,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         base_width: int = 2048,
         sampling_rate: int = 16000,
         hop_length: int = 160,
-        scale_factors: Tuple[int, ...] = (8, 32 ,32),
+        scale_factors: Tuple[int, ...] = (8, 32, 32),
         theta: float = 10000.0,
         causal_offset: int = 1,
         modality: str = "video",
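The `theta: float = 10000.0` default is the standard RoPE base. As a reminder of the underlying math (a generic sketch, not this class's actual implementation): each position coordinate is multiplied by a bank of inverse frequencies, and the resulting angles drive the cosine/sine rotation applied to query/key channel pairs.

    import torch

    def rope_angles(coords: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
        # One inverse frequency per channel pair; angles shape: (num_positions, dim // 2).
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        return coords.float()[:, None] * inv_freq[None, :]

    angles = rope_angles(torch.arange(8), dim=16)
    cos, sin = angles.cos(), angles.sin()  # used to rotate query/key channel pairs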
@@ -594,6 +594,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         height: int,
         width: int,
         device: torch.device,
+        fps: float = 25.0,
     ) -> torch.Tensor:
         """
         Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original
@@ -651,6 +652,9 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         # and clamp to keep the first-frame timestamps causal and non-negative.
         pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0)

+        # Scale the temporal coordinates by the video FPS
+        pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps
+
         return pixel_coords

     def prepare_audio_coords(
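The two hunks above move FPS scaling into `prepare_video_coords` itself: temporal coordinates are shifted by `causal_offset`, clamped to stay non-negative, and then divided by `fps` so they become timestamps in seconds. A small numeric sketch of that arithmetic, assuming the temporal coordinates are latent frame indices times the temporal scale factor:

    import torch

    latent_t = torch.arange(4)          # hypothetical latent frame indices
    scale_t, causal_offset, fps = 8, 1, 25.0

    pixel_t = (latent_t * scale_t + causal_offset - scale_t).clamp(min=0)
    time_t = pixel_t / fps              # frames -> seconds
    print(time_t)                       # tensor([0.0000, 0.0400, 0.3600, 0.6800])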
@@ -742,15 +746,15 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
                 height,
                 width,
                 device=device,
+                fps=fps,
             )
-            # Scale the temporal coordinates by the video FPS
-            coords[:, 0, ...] = coords[:, 0, ...] / fps
         elif coords is None and self.modality == "audio":
             coords = self.prepare_audio_coords(
                 batch_size,
                 num_frames,
                 device=device,
                 shift=shift,
+                fps=fps,
             )
         # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn)
         num_pos_dims = coords.shape[1]
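With the scaling now inside `prepare_video_coords`, the post-hoc division at this call site had to go: keeping both would divide the temporal row by `fps` twice. A toy sketch of the hazard (hypothetical helper, not the real method):

    import torch

    def prepare_video_coords_sketch(num_frames: int, fps: float) -> torch.Tensor:
        # Scaling lives inside the helper now.
        return torch.arange(num_frames, dtype=torch.float32) / fps

    coords = prepare_video_coords_sketch(8, fps=25.0)
    # coords = coords / fps  # <- the removed call-site line; would shrink timestamps by fps**2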
@@ -1086,9 +1090,13 @@ class LTX2VideoTransformer3DModel(

         # 1. Prepare RoPE positional embeddings
         if video_coords is None:
-            video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, hidden_states.device)
+            video_coords = self.rope.prepare_video_coords(
+                batch_size, num_frames, height, width, hidden_states.device, fps=fps
+            )
         if audio_coords is None:
-            audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames, audio_hidden_states.device)
+            audio_coords = self.audio_rope.prepare_audio_coords(
+                batch_size, audio_num_frames, audio_hidden_states.device, fps=fps
+            )

         video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device)
         audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)
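At the model level the same change simply threads `fps` through both coordinate helpers instead of relying on a later division. A self-contained stub showing the call shape (the class below is a stand-in, not the real module):

    import torch

    class RopeStub:
        # Hypothetical stand-in for the rope module's coordinate helper.
        def prepare_video_coords(self, batch_size, num_frames, height, width, device, fps=25.0):
            t = torch.arange(num_frames, device=device, dtype=torch.float32) / fps
            return t.expand(batch_size, 1, num_frames)

    rope = RopeStub()
    video_coords = rope.prepare_video_coords(2, 8, 32, 32, torch.device("cpu"), fps=24.0)
    print(video_coords.shape)  # torch.Size([2, 1, 8])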
@@ -1104,7 +1112,7 @@ class LTX2VideoTransformer3DModel(
         # Scale timestep
         timestep = timestep * timestep_scale_multiplier
-
+        timestep_cross_attn_gate_scale_factor = cross_attn_timestep_scale_multiplier / timestep_scale_multiplier

         # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters
         # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer
         # modulation with scale_shift_table (and similarly for audio)
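The added line precomputes the ratio between the cross-attention timestep multiplier and the global one, so cross-attention gates can be modulated on their own timestep scale. The arithmetic, with made-up multiplier values:

    timestep = 0.5
    timestep_scale_multiplier = 1000.0
    cross_attn_timestep_scale_multiplier = 250.0  # hypothetical value

    timestep = timestep * timestep_scale_multiplier                      # 500.0
    timestep_cross_attn_gate_scale_factor = (
        cross_attn_timestep_scale_multiplier / timestep_scale_multiplier
    )                                                                    # 0.25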
@@ -207,14 +207,10 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):

         video_output_flat = video_output.cpu().flatten().float()
         video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
-        print(f"Video Expected Slice: {video_expected_slice}")
-        print(f"Video Generated Slice: {video_generated_slice}")
         self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))

         audio_output_flat = audio_output.cpu().flatten().float()
         audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
-        print(f"Audio Expected Slice: {audio_expected_slice}")
-        print(f"Audio Generated Slice: {audio_generated_slice}")
         self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))
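The consistency test compares only the first and last eight values of each flattened output against stored reference slices, with the debug prints now dropped. The same pattern as a reusable helper (names are illustrative):

    import torch

    def first_last_slice(t: torch.Tensor, k: int = 8) -> torch.Tensor:
        # Flatten, then keep only the first and last k values for comparison.
        flat = t.cpu().flatten().float()
        return torch.cat([flat[:k], flat[-k:]])

    output = torch.linspace(0, 1, 100)
    expected = first_last_slice(output.clone())
    assert torch.allclose(first_last_slice(output), expected, atol=1e-4)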