1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add more LTX 2 transformer audio arguments

This commit is contained in:
Daniel Gu
2025-12-16 07:58:12 +01:00
parent a5f2d2da6c
commit d86f89ddea

View File

@@ -394,6 +394,7 @@ class LTX2VideoTransformerBlock(nn.Module):
ca_video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
ca_audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
audio_encoder_attention_mask: Optional[torch.Tensor] = None,
a2v_cross_attention_mask: Optional[torch.Tensor] = None,
v2a_cross_attention_mask: Optional[torch.Tensor] = None,
use_video_self_attn: bool = True,
@@ -453,7 +454,7 @@ class LTX2VideoTransformerBlock(nn.Module):
norm_audio_hidden_states,
encoder_hidden_states=audio_encoder_hidden_states,
query_rotary_emb=None,
attention_mask=encoder_attention_mask,
attention_mask=audio_encoder_attention_mask,
)
hidden_states = hidden_states + attn_hidden_states
@@ -1024,11 +1025,13 @@ class LTX2VideoTransformer3DModel(
encoder_hidden_states: torch.Tensor,
audio_encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
encoder_attention_mask: Optional[torch.Tensor] = None,
audio_encoder_attention_mask: Optional[torch.Tensor] = None,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
fps: float = 25.0,
audio_num_frames: Optional[int] = None,
video_coords: Optional[torch.Tensor] = None,
audio_coords: Optional[torch.Tensor] = None,
timestep_scale_multiplier: int = 1000,
@@ -1075,13 +1078,17 @@ class LTX2VideoTransformer3DModel(
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2:
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)
batch_size = hidden_states.size(0)
# 1. Prepare RoPE positional embeddings
if video_coords is None:
video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, hidden_states.device)
if audio_coords is None:
audio_coords = self.audio_rope.prepare_audio_coords(batch_size, num_frames, audio_hidden_states.device)
audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames, audio_hidden_states.device)
video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device)
audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)
@@ -1171,6 +1178,7 @@ class LTX2VideoTransformer3DModel(
video_cross_attn_rotary_emb,
audio_cross_attn_rotary_emb,
encoder_attention_mask,
audio_encoder_attention_mask,
)
else:
hidden_states, audio_hidden_states = block(
@@ -1189,6 +1197,7 @@ class LTX2VideoTransformer3DModel(
ca_video_rotary_emb=video_cross_attn_rotary_emb,
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
audio_encoder_attention_mask=audio_encoder_attention_mask,
)
# 6. Output layers (including unpatchification)