From d86f89ddea76952279af1da5ff188562f615325f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 16 Dec 2025 07:58:12 +0100 Subject: [PATCH] Add more LTX 2 transformer audio arguments --- .../models/transformers/transformer_ltx2.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index c1ad5f180f..2ce6106eec 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -394,6 +394,7 @@ class LTX2VideoTransformerBlock(nn.Module): ca_video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ca_audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, a2v_cross_attention_mask: Optional[torch.Tensor] = None, v2a_cross_attention_mask: Optional[torch.Tensor] = None, use_video_self_attn: bool = True, @@ -453,7 +454,7 @@ class LTX2VideoTransformerBlock(nn.Module): norm_audio_hidden_states, encoder_hidden_states=audio_encoder_hidden_states, query_rotary_emb=None, - attention_mask=encoder_attention_mask, + attention_mask=audio_encoder_attention_mask, ) hidden_states = hidden_states + attn_hidden_states @@ -1024,11 +1025,13 @@ class LTX2VideoTransformer3DModel( encoder_hidden_states: torch.Tensor, audio_encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, - encoder_attention_mask: torch.Tensor, + encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, num_frames: Optional[int] = None, height: Optional[int] = None, width: Optional[int] = None, fps: float = 25.0, + audio_num_frames: Optional[int] = None, video_coords: Optional[torch.Tensor] = None, audio_coords: Optional[torch.Tensor] = None, timestep_scale_multiplier: int = 1000, @@ -1075,13 +1078,17 @@ class LTX2VideoTransformer3DModel( encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2: + audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0 + audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) + batch_size = hidden_states.size(0) # 1. Prepare RoPE positional embeddings if video_coords is None: video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, hidden_states.device) if audio_coords is None: - audio_coords = self.audio_rope.prepare_audio_coords(batch_size, num_frames, audio_hidden_states.device) + audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames, audio_hidden_states.device) video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) @@ -1171,6 +1178,7 @@ class LTX2VideoTransformer3DModel( video_cross_attn_rotary_emb, audio_cross_attn_rotary_emb, encoder_attention_mask, + audio_encoder_attention_mask, ) else: hidden_states, audio_hidden_states = block( @@ -1189,6 +1197,7 @@ class LTX2VideoTransformer3DModel( ca_video_rotary_emb=video_cross_attn_rotary_emb, ca_audio_rotary_emb=audio_cross_attn_rotary_emb, encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, ) # 6. Output layers (including unpatchification)