mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[fix code annotation] Adjust the dimensions of the rotary positional embedding. (#8890)
* 2d rotary pos emb dim * make style --------- Co-authored-by: haofanwang <haofanwang.ai@gmail.com>
This commit is contained in:
@@ -319,12 +319,16 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
|
||||
assert embed_dim % 4 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
|
||||
emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
|
||||
emb_h = get_1d_rotary_pos_embed(
|
||||
embed_dim // 2, grid[0].reshape(-1), use_real=use_real
|
||||
) # (H*W, D/2) if use_real else (H*W, D/4)
|
||||
emb_w = get_1d_rotary_pos_embed(
|
||||
embed_dim // 2, grid[1].reshape(-1), use_real=use_real
|
||||
) # (H*W, D/2) if use_real else (H*W, D/4)
|
||||
|
||||
if use_real:
|
||||
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
|
||||
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
|
||||
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
|
||||
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
|
||||
@@ -371,6 +375,8 @@ def get_1d_rotary_pos_embed(
|
||||
Returns:
|
||||
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
|
||||
if isinstance(pos, int):
|
||||
pos = np.arange(pos)
|
||||
theta = theta * ntk_factor
|
||||
|
||||
Reference in New Issue
Block a user