1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge branch 'main' into feat/autoencodermixin

This commit is contained in:
sayakpaul
2025-10-20 20:17:13 -10:00

View File

@@ -180,7 +180,6 @@ class QwenEmbedRope(nn.Module):
],
dim=1,
)
self.rope_cache = {}
# DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
self.scale_rope = scale_rope
@@ -195,10 +194,20 @@ class QwenEmbedRope(nn.Module):
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def forward(self, video_fhw, txt_seq_lens, device):
def forward(
self,
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
txt_seq_lens: List[int],
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
txt_length: [bs] a list of 1 integers representing the length of the text
Args:
video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
A list of 3 integers [frame, height, width] representing the shape of the video.
txt_seq_lens (`List[int]`):
A list of integers of length batch_size representing the length of each text prompt.
device: (`torch.device`):
The device on which to perform the RoPE computation.
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
@@ -213,14 +222,8 @@ class QwenEmbedRope(nn.Module):
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}"
if not torch.compiler.is_compiling():
if rope_key not in self.rope_cache:
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
video_freq = self.rope_cache[rope_key]
else:
video_freq = self._compute_video_freqs(frame, height, width, idx)
# RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
video_freq = self._compute_video_freqs(frame, height, width, idx)
video_freq = video_freq.to(device)
vid_freqs.append(video_freq)
@@ -235,8 +238,8 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
@functools.lru_cache(maxsize=128)
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)