
Fix LTX 2 transformer shape errors

Author: Daniel Gu
Date:   2025-12-15 06:38:57 +01:00
Parent: 5765759cd3
Commit: aeecc4d712
2 changed files with 10 additions and 10 deletions


@@ -297,7 +297,7 @@ class LTX2VideoTransformerBlock(nn.Module):
             qk_norm=qk_norm,
         )
-        self.audio_norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
         self.audio_attn1 = LTX2Attention(
             query_dim=audio_dim,
             heads=audio_num_attention_heads,
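
Why this class of fix matters (a note, not part of the commit): an RMSNorm's learnable weight is sized at construction, so building the audio-side norms with the video width dim fails at runtime whenever audio_dim != dim. The same one-word fix repeats below for audio_norm2, video_to_audio_norm, and audio_norm3. A minimal sketch of the failure using torch.nn.RMSNorm (PyTorch >= 2.4) and made-up widths:

import torch

dim, audio_dim = 64, 32  # hypothetical widths, for illustration only
audio_hidden_states = torch.randn(2, 10, audio_dim)

bad_norm = torch.nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
try:
    bad_norm(audio_hidden_states)  # weight is (64,), input's last axis is 32
except RuntimeError as err:
    print("shape error:", err)

good_norm = torch.nn.RMSNorm(audio_dim, eps=1e-6, elementwise_affine=True)
print(good_norm(audio_hidden_states).shape)  # torch.Size([2, 10, 32])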
@@ -322,7 +322,7 @@ class LTX2VideoTransformerBlock(nn.Module):
             qk_norm=qk_norm,
         )
-        self.audio_norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
         self.audio_attn2 = LTX2Attention(
             query_dim=audio_dim,
             cross_attention_dim=audio_cross_attention_dim,
@@ -349,7 +349,7 @@ class LTX2VideoTransformerBlock(nn.Module):
         )
         # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
-        self.video_to_audio_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
         self.video_to_audio_attn = LTX2Attention(
             query_dim=audio_dim,
             cross_attention_dim=dim,
@@ -365,13 +365,13 @@ class LTX2VideoTransformerBlock(nn.Module):
         self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
         self.ff = FeedForward(dim, activation_fn=activation_fn)
-        self.audio_norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
         self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)

         # 5. Per-Layer Modulation Parameters
         # Self-Attention / Feedforward AdaLayerNorm-Zero mod params
         self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
-        self.audio_scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+        self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
         # Per-layer a2v, v2a Cross-Attention mod params
         self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
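
Context for the table fix (a sketch of the AdaLayerNorm-Zero pattern with made-up sizes, not the model's exact forward code): the per-layer table is broadcast-added to the timestep embedding and split into six modulation vectors, so its width has to match the hidden states it modulates; here that is the audio width, not the video width.

import torch

batch, audio_dim = 2, 32  # hypothetical sizes
audio_hidden = torch.randn(batch, 10, audio_dim)
audio_scale_shift_table = torch.randn(6, audio_dim) / audio_dim**0.5
temb_audio = torch.randn(batch, 1, 6 * audio_dim)  # per-timestep embedding

# Broadcast-add the learned table, then split into shift/scale/gate vectors.
mods = audio_scale_shift_table[None, None] + temb_audio.reshape(batch, 1, 6, audio_dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods.unbind(dim=2)
print((audio_hidden * (1 + scale_msa) + shift_msa).shape)  # torch.Size([2, 10, 32])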
@@ -459,8 +459,8 @@ class LTX2VideoTransformerBlock(nn.Module):
         # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
         if use_a2v_cross_attn or use_v2a_cross_attn:
-            norm_hidden_states = self.norm3(hidden_states)
-            norm_audio_hidden_states = self.audio_norm3(audio_hidden_states)
+            norm_hidden_states = self.audio_to_video_norm(hidden_states)
+            norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)

             # Combine global and per-layer cross attention modulation parameters
             # Video
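
This forward-pass hunk is a different bug from the norm widths: norm3 and audio_norm3 belong to the feed-forward path, and reusing them before the a2v/v2a cross-attention would tie their affine weights across paths while leaving the dedicated audio_to_video_norm and video_to_audio_norm modules unused. After the fix, each modality is normalized by its own cross-attention norm, presumably RMSNorm(dim) on the video side and RMSNorm(audio_dim) on the audio side (the audio_to_video_norm construction is outside this diff). A toy shape check with made-up sizes:

import torch

dim, audio_dim = 64, 32  # hypothetical widths
hidden_states = torch.randn(2, 10, dim)             # video tokens
audio_hidden_states = torch.randn(2, 5, audio_dim)  # audio tokens

audio_to_video_norm = torch.nn.RMSNorm(dim)        # feeds a2v: Q video, K/V audio
video_to_audio_norm = torch.nn.RMSNorm(audio_dim)  # feeds v2a: Q audio, K/V video
print(audio_to_video_norm(hidden_states).shape)        # torch.Size([2, 10, 64])
print(video_to_audio_norm(audio_hidden_states).shape)  # torch.Size([2, 5, 32])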
@@ -1114,7 +1114,7 @@ class LTX2VideoTransformer3DModel(
             batch_size=batch_size,
             hidden_dtype=audio_hidden_states.dtype,
         )
-        temb_audio = temb.view(batch_size, -1, temb_audio.size(-1))
+        temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
         audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))

         # 3.2. Prepare global modality cross attention modulation parameters
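
The last model hunk is a plain typo rather than a width mismatch: the reshape read from temb (the video timestep embedding) instead of temb_audio, which only happens to go through when the two embeddings' sizes line up. With made-up sizes:

import torch

batch, dim, audio_dim = 2, 64, 32  # hypothetical sizes
temb = torch.randn(batch, 6 * dim)              # video timestep embedding
temb_audio = torch.randn(batch, 6 * audio_dim)  # audio timestep embedding

# Old line: reshapes the *video* embedding with the audio width; no error,
# just silently wrong values whenever the sizes happen to divide.
print(temb.view(batch, -1, temb_audio.size(-1)).shape)        # torch.Size([2, 2, 192])
# Fixed line: reshapes the audio embedding itself.
print(temb_audio.view(batch, -1, temb_audio.size(-1)).shape)  # torch.Size([2, 1, 192])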


@@ -94,8 +94,8 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
             "audio_in_channels": 4,
             "audio_out_channels": 4,
             "audio_num_attention_heads": 2,
-            "audio_attention_head_dim": 8,
-            "audio_cross_attention_dim": 16,
+            "audio_attention_head_dim": 4,
+            "audio_cross_attention_dim": 8,
             "num_layers": 2,
             "qk_norm": "rms_norm_across_heads",
             "caption_channels": 16,