1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[fix code annotation] Adjust the dimensions of the rotary positional embedding. (#8890)

* 2d rotary pos emb dim

* make style

---------

Co-authored-by: haofanwang <haofanwang.ai@gmail.com>
This commit is contained in:
王奇勋
2024-07-20 12:57:36 +08:00
committed by GitHub
parent 3b04cdc816
commit 461efc57c5

View File

@@ -319,12 +319,16 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
emb_h = get_1d_rotary_pos_embed(
embed_dim // 2, grid[0].reshape(-1), use_real=use_real
) # (H*W, D/2) if use_real else (H*W, D/4)
emb_w = get_1d_rotary_pos_embed(
embed_dim // 2, grid[1].reshape(-1), use_real=use_real
) # (H*W, D/2) if use_real else (H*W, D/4)
if use_real:
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
@@ -371,6 +375,8 @@ def get_1d_rotary_pos_embed(
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = np.arange(pos)
theta = theta * ntk_factor