diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py
index c3bb7a00a4..c1ad5f180f 100644
--- a/src/diffusers/models/transformers/transformer_ltx2.py
+++ b/src/diffusers/models/transformers/transformer_ltx2.py
@@ -297,7 +297,7 @@ class LTX2VideoTransformerBlock(nn.Module):
             qk_norm=qk_norm,
         )
 
-        self.audio_norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
         self.audio_attn1 = LTX2Attention(
             query_dim=audio_dim,
             heads=audio_num_attention_heads,
@@ -322,7 +322,7 @@ class LTX2VideoTransformerBlock(nn.Module):
             qk_norm=qk_norm,
         )
 
-        self.audio_norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
         self.audio_attn2 = LTX2Attention(
             query_dim=audio_dim,
             cross_attention_dim=audio_cross_attention_dim,
@@ -349,7 +349,7 @@ class LTX2VideoTransformerBlock(nn.Module):
         )
 
         # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
-        self.video_to_audio_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
         self.video_to_audio_attn = LTX2Attention(
             query_dim=audio_dim,
             cross_attention_dim=dim,
@@ -365,13 +365,13 @@ class LTX2VideoTransformerBlock(nn.Module):
         self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
         self.ff = FeedForward(dim, activation_fn=activation_fn)
 
-        self.audio_norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
         self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
 
         # 5. Per-Layer Modulation Parameters
         # Self-Attention / Feedforward AdaLayerNorm-Zero mod params
         self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
-        self.audio_scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+        self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
 
         # Per-layer a2v, v2a Cross-Attention mod params
         self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
@@ -459,8 +459,8 @@ class LTX2VideoTransformerBlock(nn.Module):
 
         # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
         if use_a2v_cross_attn or use_v2a_cross_attn:
-            norm_hidden_states = self.norm3(hidden_states)
-            norm_audio_hidden_states = self.audio_norm3(audio_hidden_states)
+            norm_hidden_states = self.audio_to_video_norm(hidden_states)
+            norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
 
             # Combine global and per-layer cross attention modulation parameters
             # Video
@@ -1114,7 +1114,7 @@ class LTX2VideoTransformer3DModel(
             batch_size=batch_size,
             hidden_dtype=audio_hidden_states.dtype,
         )
-        temb_audio = temb.view(batch_size, -1, temb_audio.size(-1))
+        temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
         audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))
 
         # 3.2. Prepare global modality cross attention modulation parameters
diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py
index fc089e6190..c382a63eaa 100644
--- a/tests/models/transformers/test_models_transformer_ltx2.py
+++ b/tests/models/transformers/test_models_transformer_ltx2.py
@@ -94,8 +94,8 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
             "audio_in_channels": 4,
             "audio_out_channels": 4,
             "audio_num_attention_heads": 2,
-            "audio_attention_head_dim": 8,
-            "audio_cross_attention_dim": 16,
+            "audio_attention_head_dim": 4,
+            "audio_cross_attention_dim": 8,
             "num_layers": 2,
             "qk_norm": "rms_norm_across_heads",
             "caption_channels": 16,
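For context, a minimal sketch of the shape mismatch the `RMSNorm(audio_dim, ...)` changes address: the audio stream runs at `audio_dim` (`audio_num_attention_heads * audio_attention_head_dim`), so a norm built with the video `dim` carries an elementwise weight that cannot broadcast against the audio hidden states. The snippet below is illustrative only; it uses `torch.nn.RMSNorm` (PyTorch >= 2.4) as a stand-in for the diffusers `RMSNorm`, and the dimensions are made up to loosely mirror the updated test config rather than taken from a real model.

```python
# Illustrative sketch, not part of the diffusers test suite.
import torch
import torch.nn as nn

dim = 32        # hypothetical video inner dim (num_attention_heads * attention_head_dim)
audio_dim = 8   # hypothetical audio inner dim (audio_num_attention_heads * audio_attention_head_dim)

# Audio hidden states: (batch, audio tokens, audio_dim)
audio_hidden_states = torch.randn(1, 16, audio_dim)

# Before the fix: the norm's elementwise weight has shape (dim,), which cannot
# broadcast against the audio hidden states' last dimension of audio_dim.
wrong_norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
try:
    wrong_norm(audio_hidden_states)
except RuntimeError as e:
    print("mismatched norm fails:", e)

# After the fix: sizing the norm with audio_dim matches the audio hidden states.
right_norm = nn.RMSNorm(audio_dim, eps=1e-6, elementwise_affine=True)
print(right_norm(audio_hidden_states).shape)  # torch.Size([1, 16, 8])
```

The same reasoning applies to `audio_scale_shift_table`: the AdaLayerNorm-Zero modulation parameters are added to and multiplied with audio hidden states, so they must have `audio_dim` as their last dimension as well.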