mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix Graph Breaks When Compiling CogView4 (#10959)
* Fix Graph Breaks When Compiling CogView4 Eliminate this: ``` t]V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] Recompiling function forward in /home/zeyi/repos/diffusers/src/diffusers/models/transformers/transformer_cogview4.py:374 V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] triggered by the following guard failure(s): V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] - 0/3: ___check_obj_id(L['self'].rope.freqs_h, 139976127328032) V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] - 0/2: ___check_obj_id(L['self'].rope.freqs_h, 139976107780960) V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] - 0/1: ___check_obj_id(L['self'].rope.freqs_h, 140022511848960) V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] - 0/0: ___check_obj_id(L['self'].rope.freqs_h, 140024081342416) ``` * Update transformer_cogview4.py * fix cogview4 rotary pos embed * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -244,30 +244,34 @@ class CogView4RotaryPosEmbed(nn.Module):
|
||||
def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.patch_size = patch_size
|
||||
self.rope_axes_dim = rope_axes_dim
|
||||
|
||||
dim_h, dim_w = dim // 2, dim // 2
|
||||
h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h))
|
||||
w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
|
||||
h_seq = torch.arange(self.rope_axes_dim[0])
|
||||
w_seq = torch.arange(self.rope_axes_dim[1])
|
||||
self.freqs_h = torch.outer(h_seq, h_inv_freq)
|
||||
self.freqs_w = torch.outer(w_seq, w_inv_freq)
|
||||
self.theta = theta
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
height, width = height // self.patch_size, width // self.patch_size
|
||||
|
||||
h_idx = torch.arange(height)
|
||||
w_idx = torch.arange(width)
|
||||
dim_h, dim_w = self.dim // 2, self.dim // 2
|
||||
h_inv_freq = 1.0 / (
|
||||
self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
|
||||
)
|
||||
w_inv_freq = 1.0 / (
|
||||
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
|
||||
)
|
||||
h_seq = torch.arange(self.rope_axes_dim[0])
|
||||
w_seq = torch.arange(self.rope_axes_dim[1])
|
||||
freqs_h = torch.outer(h_seq, h_inv_freq)
|
||||
freqs_w = torch.outer(w_seq, w_inv_freq)
|
||||
|
||||
h_idx = torch.arange(height, device=freqs_h.device)
|
||||
w_idx = torch.arange(width, device=freqs_w.device)
|
||||
inner_h_idx = h_idx * self.rope_axes_dim[0] // height
|
||||
inner_w_idx = w_idx * self.rope_axes_dim[1] // width
|
||||
|
||||
self.freqs_h = self.freqs_h.to(hidden_states.device)
|
||||
self.freqs_w = self.freqs_w.to(hidden_states.device)
|
||||
freqs_h = self.freqs_h[inner_h_idx]
|
||||
freqs_w = self.freqs_w[inner_w_idx]
|
||||
freqs_h = freqs_h[inner_h_idx]
|
||||
freqs_w = freqs_w[inner_w_idx]
|
||||
|
||||
# Create position matrices for height and width
|
||||
# [height, 1, dim//4] and [1, width, dim//4]
|
||||
|
||||
Reference in New Issue
Block a user