From 461efc57c5921024875711cbd4719f100d700f73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=A5=87=E5=8B=8B?= <46553287+wangqixun@users.noreply.github.com> Date: Sat, 20 Jul 2024 12:57:36 +0800 Subject: [PATCH] [fix code annotation] Adjust the dimensions of the rotary positional embedding. (#8890) * 2d rotary pos emb dim * make style --------- Co-authored-by: haofanwang --- src/diffusers/models/embeddings.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0890842f57..7684fdf9cd 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -319,12 +319,16 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): assert embed_dim % 4 == 0 # use half of dimensions to encode grid_h - emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4) - emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4) + emb_h = get_1d_rotary_pos_embed( + embed_dim // 2, grid[0].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) + emb_w = get_1d_rotary_pos_embed( + embed_dim // 2, grid[1].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) if use_real: - cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2) - sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2) + cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D) + sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D) return cos, sin else: emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) @@ -371,6 +375,8 @@ def get_1d_rotary_pos_embed( Returns: `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] """ + assert dim % 2 == 0 + if isinstance(pos, int): pos = np.arange(pos) theta = theta * ntk_factor