mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add more LTX 2 transformer audio arguments
This commit is contained in:
@@ -394,6 +394,7 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
ca_video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
ca_audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
+audio_encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
a2v_cross_attention_mask: Optional[torch.Tensor] = None,
|
||||
v2a_cross_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_video_self_attn: bool = True,
|
||||
@@ -453,7 +454,7 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
norm_audio_hidden_states,
|
||||
encoder_hidden_states=audio_encoder_hidden_states,
|
||||
query_rotary_emb=None,
|
||||
-attention_mask=encoder_attention_mask,
|
||||
+attention_mask=audio_encoder_attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
|
||||
@@ -1024,11 +1025,13 @@ class LTX2VideoTransformer3DModel(
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
audio_encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
-encoder_attention_mask: torch.Tensor,
|
||||
+encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
+audio_encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
num_frames: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
fps: float = 25.0,
|
||||
+audio_num_frames: Optional[int] = None,
|
||||
video_coords: Optional[torch.Tensor] = None,
|
||||
audio_coords: Optional[torch.Tensor] = None,
|
||||
timestep_scale_multiplier: int = 1000,
|
||||
@@ -1075,13 +1078,17 @@ class LTX2VideoTransformer3DModel(
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
+if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2:
|
||||
+audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
|
||||
+audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
# 1. Prepare RoPE positional embeddings
|
||||
if video_coords is None:
|
||||
video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, hidden_states.device)
|
||||
if audio_coords is None:
|
||||
-audio_coords = self.audio_rope.prepare_audio_coords(batch_size, num_frames, audio_hidden_states.device)
|
||||
+audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames, audio_hidden_states.device)
|
||||
|
||||
video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device)
|
||||
audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)
|
||||
@@ -1171,6 +1178,7 @@ class LTX2VideoTransformer3DModel(
|
||||
video_cross_attn_rotary_emb,
|
||||
audio_cross_attn_rotary_emb,
|
||||
encoder_attention_mask,
|
||||
+audio_encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
hidden_states, audio_hidden_states = block(
|
||||
@@ -1189,6 +1197,7 @@ class LTX2VideoTransformer3DModel(
|
||||
ca_video_rotary_emb=video_cross_attn_rotary_emb,
|
||||
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
+audio_encoder_attention_mask=audio_encoder_attention_mask,
|
||||
)
|
||||
|
||||
# 6. Output layers (including unpatchification)
|
||||
|
||||
Reference in New Issue
Block a user