Fix LTX 2 transformer shape errors
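The audio stream in `LTX2VideoTransformerBlock` was sized with the video width `dim` in several places instead of the audio width `audio_dim`, which fails as soon as the two differ. A minimal sketch of the failure mode (made-up widths, with `torch.nn.RMSNorm` standing in for diffusers' `RMSNorm`):

```python
import torch
from torch import nn

# Minimal sketch of the bug class fixed here (illustrative widths, and
# torch.nn.RMSNorm standing in for diffusers' RMSNorm): a norm built with
# the video width cannot normalize an audio tensor of a different width.
dim, audio_dim = 64, 32
audio_hidden_states = torch.randn(1, 10, audio_dim)

good_norm = nn.RMSNorm(audio_dim)  # sized to the audio stream (post-fix)
print(good_norm(audio_hidden_states).shape)  # torch.Size([1, 10, 32])

bad_norm = nn.RMSNorm(dim)  # sized to the video stream (pre-fix)
try:
    bad_norm(audio_hidden_states)
except RuntimeError as exc:
    print(f"shape error: {exc}")
```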
@@ -297,7 +297,7 @@ class LTX2VideoTransformerBlock(nn.Module):
             qk_norm=qk_norm,
         )
 
-        self.audio_norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
         self.audio_attn1 = LTX2Attention(
             query_dim=audio_dim,
             heads=audio_num_attention_heads,
@@ -322,7 +322,7 @@ class LTX2VideoTransformerBlock(nn.Module):
             qk_norm=qk_norm,
         )
 
-        self.audio_norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
         self.audio_attn2 = LTX2Attention(
             query_dim=audio_dim,
             cross_attention_dim=audio_cross_attention_dim,
@@ -349,7 +349,7 @@ class LTX2VideoTransformerBlock(nn.Module):
         )
 
         # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
-        self.video_to_audio_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
         self.video_to_audio_attn = LTX2Attention(
             query_dim=audio_dim,
             cross_attention_dim=dim,
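For the v2a path the queries are audio tokens, so the pre-attention norm must be `audio_dim` wide even though keys and values come from the `dim`-wide video stream. A rough sketch of that wiring (stand-in projections, not the real `LTX2Attention` internals; `inner`, `to_q`, etc. are illustrative names):

```python
import torch
from torch import nn

# v2a wiring sketch: queries from the audio stream, keys/values projected
# from the video stream, so the query norm is audio_dim wide.
dim, audio_dim, inner = 64, 32, 32
video_to_audio_norm = nn.RMSNorm(audio_dim)
to_q = nn.Linear(audio_dim, inner)
to_k = nn.Linear(dim, inner)
to_v = nn.Linear(dim, inner)

audio = torch.randn(1, 5, audio_dim)  # queries: audio tokens
video = torch.randn(1, 10, dim)       # keys/values: video tokens
q, k, v = to_q(video_to_audio_norm(audio)), to_k(video), to_v(video)
attn = torch.softmax(q @ k.transpose(-1, -2) / inner**0.5, dim=-1)
print((attn @ v).shape)  # torch.Size([1, 5, 32])
```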
@@ -365,13 +365,13 @@ class LTX2VideoTransformerBlock(nn.Module):
         self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
         self.ff = FeedForward(dim, activation_fn=activation_fn)
 
-        self.audio_norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
         self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
 
         # 5. Per-Layer Modulation Parameters
         # Self-Attention / Feedforward AdaLayerNorm-Zero mod params
         self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
-        self.audio_scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+        self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
 
         # Per-layer a2v, v2a Cross-Attention mod params
         self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
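The `audio_scale_shift_table` rows are combined with chunks of the audio timestep embedding and broadcast against the audio hidden states, so the table must also be `audio_dim` wide. A shape-only sketch of that AdaLayerNorm-Zero-style modulation (illustrative layout; the real LTX 2 forward pass may differ):

```python
import torch

# Shape-only modulation sketch: table rows act as shift/scale/gate params
# added to chunks of the per-timestep conditioning.
batch, seq, audio_dim = 1, 10, 32
audio_scale_shift_table = torch.randn(6, audio_dim) / audio_dim**0.5

temb_audio = torch.randn(batch, 6 * audio_dim)  # per-timestep conditioning
mods = audio_scale_shift_table[None] + temb_audio.view(batch, 6, audio_dim)
shift, scale = mods[:, 0:1], mods[:, 1:2]

audio_hidden_states = torch.randn(batch, seq, audio_dim)
modulated = audio_hidden_states * (1 + scale) + shift  # broadcasts over seq
print(modulated.shape)  # torch.Size([1, 10, 32])
# A (6, dim)-shaped table would make the addition above fail whenever
# dim != audio_dim.
```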
@@ -459,8 +459,8 @@ class LTX2VideoTransformerBlock(nn.Module):
 
         # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
         if use_a2v_cross_attn or use_v2a_cross_attn:
-            norm_hidden_states = self.norm3(hidden_states)
-            norm_audio_hidden_states = self.audio_norm3(audio_hidden_states)
+            norm_hidden_states = self.audio_to_video_norm(hidden_states)
+            norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
 
             # Combine global and per-layer cross attention modulation parameters
             # Video
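In the forward pass, the cross-attention branch now uses the dedicated `audio_to_video_norm` / `video_to_audio_norm` modules rather than reusing the feedforward norms `norm3` / `audio_norm3`, so each attention site normalizes its own query stream. A stand-in sketch of the corrected wiring (widths made up):

```python
import torch
from torch import nn

dim, audio_dim = 64, 32
# Stand-ins for the dedicated cross-attention norms: a2v queries are video
# tokens (width dim), v2a queries are audio tokens (width audio_dim).
audio_to_video_norm = nn.RMSNorm(dim)
video_to_audio_norm = nn.RMSNorm(audio_dim)

hidden_states = torch.randn(1, 10, dim)
audio_hidden_states = torch.randn(1, 5, audio_dim)

norm_hidden_states = audio_to_video_norm(hidden_states)
norm_audio_hidden_states = video_to_audio_norm(audio_hidden_states)
print(norm_hidden_states.shape, norm_audio_hidden_states.shape)
```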
@@ -1114,7 +1114,7 @@ class LTX2VideoTransformer3DModel(
             batch_size=batch_size,
             hidden_dtype=audio_hidden_states.dtype,
         )
-        temb_audio = temb.view(batch_size, -1, temb_audio.size(-1))
+        temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
         audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))
 
         # 3.2. Prepare global modality cross attention modulation parameters
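The typo here reshaped the video embedding `temb` into `temb_audio`. A sketch of the difference with made-up sizes (pretending both embeddings arrive with a flattened batch-times-frames leading axis):

```python
import torch

batch_size, width, audio_width = 1, 384, 192
temb = torch.randn(batch_size * 2, width)              # video embeddings
temb_audio = torch.randn(batch_size * 2, audio_width)  # audio embeddings

# Fixed line: fold the flattened leading axis back out of temb_audio itself.
temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
print(temb_audio.shape)  # torch.Size([1, 2, 192])

# The old line reshaped `temb` instead: video data in the wrong layout here,
# and a hard view() error whenever temb's element count isn't divisible by
# the audio width.
print(temb.view(batch_size, -1, audio_width).shape)  # torch.Size([1, 4, 192])
```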
@@ -94,8 +94,8 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
             "audio_in_channels": 4,
             "audio_out_channels": 4,
             "audio_num_attention_heads": 2,
-            "audio_attention_head_dim": 8,
-            "audio_cross_attention_dim": 16,
+            "audio_attention_head_dim": 4,
+            "audio_cross_attention_dim": 8,
             "num_layers": 2,
             "qk_norm": "rms_norm_across_heads",
             "caption_channels": 16,
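The smaller test dims stay self-consistent under the usual diffusers convention that a stream's width is `heads * head_dim` (an assumption here, not verified against the LTX 2 source):

```python
# Assumed convention: the audio stream width is heads * head_dim.
audio_num_attention_heads = 2
audio_attention_head_dim = 4
audio_dim = audio_num_attention_heads * audio_attention_head_dim  # 2 * 4 = 8
assert audio_dim == 8  # matches the new audio_cross_attention_dim value
```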