From 5e0cf2b2f0f0aff1d8c76aab787be2f88095da7c Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 6 Jan 2026 23:32:59 +0100 Subject: [PATCH] Simplify LTX 2 RoPE forward by removing coords is None logic --- .../models/transformers/transformer_ltx2.py | 45 ++++--------------- 1 file changed, 8 insertions(+), 37 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index ab369ebfe4..413ceb24fa 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -795,51 +795,22 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): def forward( self, - coords: Optional[torch.Tensor] = None, - batch_size: Optional[int] = None, - num_frames: Optional[int] = None, - height: Optional[int] = None, - width: Optional[int] = None, - fps: float = 25.0, - shift: int = 0, + coords: torch.Tensor, device: Optional[Union[str, torch.device]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - if coords is not None: - device = device or coords.device - batch_size = batch_size or coords.size(0) - else: - device = device or "cpu" - batch_size = batch_size or 1 + device = device or coords.device - # 1. Calculate the coordinate grid with respect to data space for the given modality (video, audio). - if coords is None and self.modality == "video": - coords = self.prepare_video_coords( - batch_size, - num_frames, - height, - width, - device=device, - fps=fps, - ) - elif coords is None and self.modality == "audio": - coords = self.prepare_audio_coords( - batch_size, - num_frames, - device=device, - shift=shift, - fps=fps, - ) # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn) num_pos_dims = coords.shape[1] - # 2. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch + # 1. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch # position index if coords.ndim == 4: coords_start, coords_end = coords.chunk(2, dim=-1) coords = (coords_start + coords_end) / 2.0 coords = coords.squeeze(-1) # [B, num_pos_dims, num_patches] - # 3. Get coordinates as a fraction of the base data shape + # 2. Get coordinates as a fraction of the base data shape if self.modality == "video": max_positions = (self.base_num_frames, self.base_height, self.base_width) elif self.modality == "audio": @@ -849,7 +820,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): # Number of spatiotemporal dimensions (3 for video, 1 for audio and cross attn) times 2 for cos, sin num_rope_elems = num_pos_dims * 2 - # 4. Create a 1D grid of frequencies for RoPE + # 3. Create a 1D grid of frequencies for RoPE freqs_dtype = torch.float64 if self.double_precision else torch.float32 pow_indices = torch.pow( self.theta, @@ -857,12 +828,12 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): ) freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) - # 5. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape + # 4. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape # (self.dim // num_elems,) freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, num_patches, num_pos_dims, self.dim // num_elems] freqs = freqs.transpose(-1, -2).flatten(2) # [B, num_patches, self.dim // 2] - # 6. Get real, interleaved (cos, sin) frequencies, padded to self.dim + # 5. Get real, interleaved (cos, sin) frequencies, padded to self.dim # TODO: consider implementing this as a utility and reuse in `connectors.py`. # src/diffusers/pipelines/ltx2/connectors.py if self.rope_type == "interleaved": @@ -1212,7 +1183,7 @@ class LTX2VideoTransformer3DModel( batch_size, audio_num_frames, audio_hidden_states.device, fps=fps ) - video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) + video_rotary_emb = self.rope(video_coords, device=hidden_states.device) audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device)