From dcf320f2937f19581bad195e35d6ba796d807c42 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 3 Sep 2024 07:18:33 -1000 Subject: [PATCH] small update on rotary embedding (#9354) * update * fix --------- Co-authored-by: Sayak Paul --- src/diffusers/models/embeddings.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index db8f4fd172..eb5067c377 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -608,8 +608,11 @@ def get_1d_rotary_pos_embed( pos = torch.from_numpy(pos) # type: ignore # [S] theta = theta * ntk_factor - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] - freqs = freqs.to(pos.device) + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2] freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: # flux, hunyuan-dit, cogvideox