mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge branch 'main' into feat/autoencodermixin
This commit is contained in:
@@ -180,7 +180,6 @@ class QwenEmbedRope(nn.Module):
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
self.rope_cache = {}
|
||||
|
||||
# DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
|
||||
self.scale_rope = scale_rope
|
||||
@@ -195,10 +194,20 @@ class QwenEmbedRope(nn.Module):
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
def forward(
|
||||
self,
|
||||
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
|
||||
txt_seq_lens: List[int],
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
||||
txt_length: [bs] a list of 1 integers representing the length of the text
|
||||
Args:
|
||||
video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
|
||||
A list of 3 integers [frame, height, width] representing the shape of the video.
|
||||
txt_seq_lens (`List[int]`):
|
||||
A list of integers of length batch_size representing the length of each text prompt.
|
||||
device: (`torch.device`):
|
||||
The device on which to perform the RoPE computation.
|
||||
"""
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
@@ -213,14 +222,8 @@ class QwenEmbedRope(nn.Module):
|
||||
max_vid_index = 0
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
rope_key = f"{idx}_{height}_{width}"
|
||||
|
||||
if not torch.compiler.is_compiling():
|
||||
if rope_key not in self.rope_cache:
|
||||
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
|
||||
video_freq = self.rope_cache[rope_key]
|
||||
else:
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
||||
# RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
||||
video_freq = video_freq.to(device)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
@@ -235,8 +238,8 @@ class QwenEmbedRope(nn.Module):
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0):
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
Reference in New Issue
Block a user