mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Simplify LTX 2 RoPE forward by removing coords is None logic
This commit is contained in:
@@ -795,51 +795,22 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
coords: Optional[torch.Tensor] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
num_frames: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
fps: float = 25.0,
|
||||
shift: int = 0,
|
||||
coords: torch.Tensor,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if coords is not None:
|
||||
device = device or coords.device
|
||||
batch_size = batch_size or coords.size(0)
|
||||
else:
|
||||
device = device or "cpu"
|
||||
batch_size = batch_size or 1
|
||||
device = device or coords.device
|
||||
|
||||
# 1. Calculate the coordinate grid with respect to data space for the given modality (video, audio).
|
||||
if coords is None and self.modality == "video":
|
||||
coords = self.prepare_video_coords(
|
||||
batch_size,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
device=device,
|
||||
fps=fps,
|
||||
)
|
||||
elif coords is None and self.modality == "audio":
|
||||
coords = self.prepare_audio_coords(
|
||||
batch_size,
|
||||
num_frames,
|
||||
device=device,
|
||||
shift=shift,
|
||||
fps=fps,
|
||||
)
|
||||
# Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn)
|
||||
num_pos_dims = coords.shape[1]
|
||||
|
||||
# 2. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch
|
||||
# 1. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch
|
||||
# position index
|
||||
if coords.ndim == 4:
|
||||
coords_start, coords_end = coords.chunk(2, dim=-1)
|
||||
coords = (coords_start + coords_end) / 2.0
|
||||
coords = coords.squeeze(-1) # [B, num_pos_dims, num_patches]
|
||||
|
||||
# 3. Get coordinates as a fraction of the base data shape
|
||||
# 2. Get coordinates as a fraction of the base data shape
|
||||
if self.modality == "video":
|
||||
max_positions = (self.base_num_frames, self.base_height, self.base_width)
|
||||
elif self.modality == "audio":
|
||||
@@ -849,7 +820,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
# Number of spatiotemporal dimensions (3 for video, 1 for audio and cross attn) times 2 for cos, sin
|
||||
num_rope_elems = num_pos_dims * 2
|
||||
|
||||
# 4. Create a 1D grid of frequencies for RoPE
|
||||
# 3. Create a 1D grid of frequencies for RoPE
|
||||
freqs_dtype = torch.float64 if self.double_precision else torch.float32
|
||||
pow_indices = torch.pow(
|
||||
self.theta,
|
||||
@@ -857,12 +828,12 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
)
|
||||
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
|
||||
|
||||
# 5. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape
|
||||
# 4. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape
|
||||
# (self.dim // num_elems,)
|
||||
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, num_patches, num_pos_dims, self.dim // num_elems]
|
||||
freqs = freqs.transpose(-1, -2).flatten(2) # [B, num_patches, self.dim // 2]
|
||||
|
||||
# 6. Get real, interleaved (cos, sin) frequencies, padded to self.dim
|
||||
# 5. Get real, interleaved (cos, sin) frequencies, padded to self.dim
|
||||
# TODO: consider implementing this as a utility and reuse in `connectors.py`.
|
||||
# src/diffusers/pipelines/ltx2/connectors.py
|
||||
if self.rope_type == "interleaved":
|
||||
@@ -1212,7 +1183,7 @@ class LTX2VideoTransformer3DModel(
|
||||
batch_size, audio_num_frames, audio_hidden_states.device, fps=fps
|
||||
)
|
||||
|
||||
video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device)
|
||||
video_rotary_emb = self.rope(video_coords, device=hidden_states.device)
|
||||
audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)
|
||||
|
||||
video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device)
|
||||
|
||||
Reference in New Issue
Block a user