Mirror of https://github.com/huggingface/diffusers.git
Get LTX 2 transformer tests working
@@ -577,6 +577,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         # Audio-specific
         self.sampling_rate = sampling_rate
         self.hop_length = hop_length
+        self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0])

         self.scale_factors = scale_factors
         self.theta = theta
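The added rate converts audio time into latent frames: spectrogram frames per second (sampling_rate / hop_length) divided by the temporal scale factor. A minimal sketch with assumed values (44.1 kHz, hop 512, temporal scale factor 4; the diff does not confirm these defaults):

```python
# Illustrative numbers only; the actual config values are not shown in this diff.
sampling_rate = 44100        # Hz (assumed)
hop_length = 512             # STFT hop in samples (assumed)
temporal_scale_factor = 4    # scale_factors[0] (assumed)

spec_frames_per_second = sampling_rate / hop_length                        # ~86.13 spectrogram frames/s
audio_latents_per_second = spec_frames_per_second / temporal_scale_factor  # ~21.53 latent frames/s
print(f"{audio_latents_per_second:.2f} audio latent frames per second")
```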
@@ -657,6 +658,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         batch_size: int,
         num_frames: int,
         device: torch.device,
+        fps: float = 25.0,
         shift: int = 0,
     ) -> torch.Tensor:
         """
@@ -682,9 +684,11 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         """

         # 1. Generate coordinates in the frame (time) dimension.
+        audio_duration_s = num_frames / fps
+        latent_frames = int(audio_duration_s * self.audio_latents_per_second)
         # Always compute rope in fp32
         grid_f = torch.arange(
-            start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device
+            start=shift, end=latent_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device
         )

         # 2. Calculate start timestamps in seconds with respect to the original spectrogram grid
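Before this change the audio grid ran over raw video frame indices; now the video frame count and fps determine the audio duration, which is converted into latent frames. A sketch with assumed numbers (121 frames at 25 fps, the rate from the earlier sketch, patch_size_t=1, shift=0):

```python
import torch

num_frames, fps = 121, 25.0                 # assumed video length
audio_latents_per_second = 44100 / 512 / 4  # assumed rate from the earlier sketch
patch_size_t, shift = 1, 0                  # assumed patching and offset

audio_duration_s = num_frames / fps                               # 4.84 s of audio
latent_frames = int(audio_duration_s * audio_latents_per_second)  # 104 latent frames

grid_f = torch.arange(shift, latent_frames + shift, patch_size_t, dtype=torch.float32)
print(grid_f.shape)  # torch.Size([104])
```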
@@ -748,10 +752,11 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
             device=device,
             shift=shift,
         )
-        # Number of spatiotemporal dimensions (3 for video, 1 for audio)
+        # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn)
         num_pos_dims = coords.shape[1]

-        # 2. If the coords are patch boundaries [start, end), use the midpoint of these boundaries
+        # 2. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch
+        # position index
         if coords.ndim == 4:
             coords_start, coords_end = coords.chunk(2, dim=-1)
             coords = (coords_start + coords_end) / 2.0
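When the RoPE coordinates arrive as [start, end) patch boundaries (an extra trailing dim of size 2, hence coords.ndim == 4), the patch's position index is the boundary midpoint. A sketch with made-up shapes:

```python
import torch

batch_size, num_pos_dims, num_patches = 2, 3, 8  # assumed sizes
# Trailing dim of size 2 holds the [start, end) boundary of each patch.
coords = torch.rand(batch_size, num_pos_dims, num_patches, 2)

coords_start, coords_end = coords.chunk(2, dim=-1)  # two [B, 3, N, 1] halves
coords = (coords_start + coords_end) / 2.0          # midpoint, [B, 3, N, 1]
print(coords.shape)  # torch.Size([2, 3, 8, 1])
```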
@@ -762,8 +767,9 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
             max_positions = (self.base_num_frames, self.base_height, self.base_width)
         elif self.modality == "audio":
             max_positions = (self.base_num_frames,)
+        # [B, num_pos_dims, num_patches] --> [B, num_patches, num_pos_dims]
         grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device)
-        # Number of spatiotemporal dimensions (3 for video, 1 for audio) times 2 for cos, sin
+        # Number of spatiotemporal dimensions (3 for video, 1 for audio and cross attn) times 2 for cos, sin
         num_rope_elems = num_pos_dims * 2

         # 4. Create a 1D grid of frequencies for RoPE
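The new comment documents the restack: normalization divides each axis by its maximum and moves the position axis to the trailing dim. A sketch with made-up sizes:

```python
import torch

batch_size, num_pos_dims, num_patches = 2, 3, 8  # assumed sizes
coords = torch.rand(batch_size, num_pos_dims, num_patches) * 64
max_positions = (128, 64, 64)  # assumed (frames, height, width) maxima

# Normalize each axis to [0, 1] and move the position axis to the trailing dim:
# [B, num_pos_dims, num_patches] --> [B, num_patches, num_pos_dims]
grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1)
print(grid.shape)  # torch.Size([2, 8, 3])
```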
@@ -778,11 +784,10 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
         )
         freqs = freqs * math.pi / 2.0

-        # 5. Tensor-vector outer product between pos ids tensor of shape [B, 3, num_patches] and freqs vector of shape
-        # self.dim // num_elems
-        freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs  # [B, 3, num_patches, self.dim // num_elems]
-        freqs = freqs.transpose(1, 2).flatten(2)  # [B, num_patches, self.dim // 2]
-        # freqs = freqs.transpose(-1, -2).flatten(2) # [B, 3, num_patches * self.dim // num_elems]???
+        # 5. Tensor-vector outer product between pos ids tensor of shape (B, num_patches, num_pos_dims) and freqs
+        # vector of shape (self.dim // num_elems,)
+        freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs  # [B, num_patches, num_pos_dims, self.dim // num_elems]
+        freqs = freqs.transpose(-1, -2).flatten(2)  # [B, num_patches, self.dim // 2]

         # 6. Get real, interleaved (cos, sin) frequencies, padded to self.dim
         cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
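The rewritten lines match what the code does after the grid restack: grid is [B, num_patches, num_pos_dims], so the old transpose(1, 2) (and the commented-out ??? variant) no longer applied, and transpose(-1, -2) moves the frequency axis into place before flattening. A runnable sketch of steps 5-6 with assumed sizes; the frequency ramp is a stand-in for the real step 4:

```python
import math
import torch

# Assumed sizes: dim=24 and 3 position dims, so num_elems = 6 and each axis
# gets dim // num_elems = 4 frequencies; with these numbers no padding is needed.
batch_size, num_patches, num_pos_dims, dim = 2, 8, 3, 24
num_elems = num_pos_dims * 2  # times 2 for cos, sin
grid = torch.rand(batch_size, num_patches, num_pos_dims)  # normalized positions from step 3
freqs = torch.linspace(0.0, 1.0, dim // num_elems) * math.pi / 2.0  # stand-in frequency ramp

# Step 5: outer product between positions (mapped from [0, 1] to [-1, 1]) and frequencies.
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs  # [B, num_patches, num_pos_dims, dim // num_elems]
freqs = freqs.transpose(-1, -2).flatten(2)    # [B, num_patches, dim // 2]

# Step 6: interleaved (cos, sin) pairs, each angle repeated for its rotated pair.
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)  # [B, num_patches, dim]
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
print(cos_freqs.shape, sin_freqs.shape)  # torch.Size([2, 8, 24]) twice
```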
@@ -888,7 +893,7 @@ class LTX2VideoTransformer3DModel(

         # 1. Patchification input projections
         self.proj_in = nn.Linear(in_channels, inner_dim)
-        self.audio_proj_in = nn.Linear(audio_in_channels, inner_dim)
+        self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim)

         # 2. Prompt embeddings
         self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
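The fix here is a width mismatch: the audio stream runs at its own inner width, so its patch projection must land in audio_inner_dim rather than the video inner_dim. A sketch with placeholder channel sizes:

```python
import torch
from torch import nn

in_channels, audio_in_channels = 128, 64  # placeholder channel counts
inner_dim, audio_inner_dim = 2048, 1024   # placeholder stream widths

proj_in = nn.Linear(in_channels, inner_dim)                    # video tokens -> video width
audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim)  # audio tokens -> audio width

video_tokens = torch.randn(1, 256, in_channels)
audio_tokens = torch.randn(1, 32, audio_in_channels)
print(proj_in(video_tokens).shape, audio_proj_in(audio_tokens).shape)
# torch.Size([1, 256, 2048]) torch.Size([1, 32, 1024])
```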
@@ -990,6 +995,10 @@ class LTX2VideoTransformer3DModel(
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
                     cross_attention_dim=cross_attention_dim,
+                    audio_dim=audio_inner_dim,
+                    audio_num_attention_heads=audio_num_attention_heads,
+                    audio_attention_head_dim=audio_attention_head_dim,
+                    audio_cross_attention_dim=audio_cross_attention_dim,
                     qk_norm=qk_norm,
                     activation_fn=activation_fn,
                     attention_bias=attention_bias,
@@ -1015,6 +1024,7 @@ class LTX2VideoTransformer3DModel(
         hidden_states: torch.Tensor,
         audio_hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
+        audio_encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         encoder_attention_mask: torch.Tensor,
         num_frames: Optional[int] = None,
@@ -1077,9 +1087,13 @@ class LTX2VideoTransformer3DModel(

         video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device)
         audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)
+        print(f"Video RoPE cos shape: {video_rotary_emb[0].shape} | sin shape: {video_rotary_emb[1].shape}")
+        print(f"Audio RoPE cos shape: {audio_rotary_emb[0].shape} | sin shape: {audio_rotary_emb[1].shape}")

         video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device)
         audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :], device=audio_hidden_states.device)
+        print(f"Video CA RoPE cos shape: {video_cross_attn_rotary_emb[0].shape} | sin shape: {video_cross_attn_rotary_emb[1].shape}")
+        print(f"Audio CA RoPE cos shape: {audio_cross_attn_rotary_emb[0].shape} | sin shape: {audio_cross_attn_rotary_emb[1].shape}")

         # 2. Patchify input projections
         hidden_states = self.proj_in(hidden_states)
@@ -1110,12 +1124,12 @@ class LTX2VideoTransformer3DModel(
         audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))

         # 3.2. Prepare global modality cross attention modulation parameters
-        video_cross_attn_scale_shift = self.av_cross_attn_video_scale_shift(
+        video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
             timestep.flatten(),
             batch_size=batch_size,
             hidden_dtype=hidden_states.dtype,
         )
-        video_cross_attn_a2v_gate = self.av_cross_attn_video_a2v_gate(
+        video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
             timestep.flatten() * timestep_cross_attn_gate_scale_factor,
             batch_size=batch_size,
             hidden_dtype=hidden_states.dtype,
@@ -1123,12 +1137,12 @@ class LTX2VideoTransformer3DModel(
         video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(batch_size, -1, video_cross_attn_scale_shift.shape[-1])
         video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])

-        audio_cross_attn_scale_shift = self.av_cross_attn_audio_scale_shift(
+        audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
             timestep.flatten(),
             batch_size=batch_size,
             hidden_dtype=audio_hidden_states.dtype,
         )
-        audio_cross_attn_v2a_gate = self.av_cross_attn_audio_a2v_gate(
+        audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
             timestep.flatten() * timestep_cross_attn_gate_scale_factor,
             batch_size=batch_size,
             hidden_dtype=audio_hidden_states.dtype,
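Two fixes land in these hunks: the modulation calls now unpack a tuple, and the mistyped attribute av_cross_attn_audio_a2v_gate is corrected to av_cross_attn_audio_v2a_gate. Assuming these modules follow the contract of diffusers' AdaLayerNormSingle, which returns (modulation, embedded_timestep), only the first element is wanted here. A toy module showing that contract:

```python
import torch
from torch import nn

class ToyAdaNormSingle(nn.Module):
    """Mimics a module that returns a (modulation, embedded_timestep) tuple."""

    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, timestep_emb: torch.Tensor):
        embedded_timestep = timestep_emb  # stands in for the real timestep embedding
        return self.linear(embedded_timestep), embedded_timestep

mod = ToyAdaNormSingle(dim=16)
scale_shift, _ = mod(torch.randn(2, 16))  # keep the modulation, drop the embedding
print(scale_shift.shape)  # torch.Size([2, 16])
```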
@@ -1137,13 +1151,12 @@ class LTX2VideoTransformer3DModel(
         audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])

         # 4. Prepare prompt embeddings
-        # TODO: does the audio prompt embedding start from the same text embeddings as the video one?
-        audio_encoder_hidden_states = self.audio_caption_projection(encoder_hidden_states)
-        audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1))
-
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

+        audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
+        audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1))
+
         # 5. Run transformer blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
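Step 4 is also reordered so each stream is projected from its own prompt: previously the audio projection consumed the video text embeddings (the removed TODO questioned exactly this); now it consumes the new audio_encoder_hidden_states input. A sketch with placeholder dims, using plain Linear layers in place of the actual projection modules:

```python
import torch
from torch import nn

caption_channels, inner_dim, audio_inner_dim = 4096, 2048, 1024  # placeholders

caption_projection = nn.Linear(caption_channels, inner_dim)              # stand-in for the video text projection
audio_caption_projection = nn.Linear(caption_channels, audio_inner_dim)  # stand-in for the audio text projection

encoder_hidden_states = torch.randn(2, 77, caption_channels)        # video prompt embeddings
audio_encoder_hidden_states = torch.randn(2, 77, caption_channels)  # separate audio prompt embeddings

video_prompt = caption_projection(encoder_hidden_states)              # [2, 77, 2048]
audio_prompt = audio_caption_projection(audio_encoder_hidden_states)  # [2, 77, 1024]
print(video_prompt.shape, audio_prompt.shape)
```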
@@ -1152,6 +1165,7 @@ class LTX2VideoTransformer3DModel(
                     hidden_states,
                     audio_hidden_states,
                     encoder_hidden_states,
+                    audio_encoder_hidden_states,
                     temb,
                     temb_audio,
                     video_cross_attn_scale_shift,
@@ -1169,6 +1183,7 @@ class LTX2VideoTransformer3DModel(
                     hidden_states=hidden_states,
                     audio_hidden_states=audio_hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
+                    audio_encoder_hidden_states=audio_encoder_hidden_states,
                     temb=temb,
                     temb_audio=temb_audio,
                     temb_ca_scale_shift=video_cross_attn_scale_shift,