1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

small update on rotary embedding (#9354)

* update

* fix

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
YiYi Xu
2024-09-03 07:18:33 -10:00
committed by GitHub
parent 8ba90aa706
commit dcf320f293

View File

@@ -608,8 +608,11 @@ def get_1d_rotary_pos_embed(
pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
freqs = freqs.to(pos.device)
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox