1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

When using split RoPE, make sure that the output dtype is same as input dtype

This commit is contained in:
Daniel Gu
2026-01-06 00:19:39 +01:00
parent c5b52d6c9f
commit 2fa4f8471f

View File

@@ -47,6 +47,7 @@ def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, tor
def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
cos, sin = freqs
x_dtype = x.dtype
needs_reshape = False
if x.ndim != 4 and cos.ndim == 4:
# cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
@@ -61,7 +62,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten
r = last // 2
# (..., 2, r)
split_x = x.reshape(*x.shape[:-1], 2, r)
split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float
first_x = split_x[..., :1, :] # (..., 1, r)
second_x = split_x[..., 1:, :] # (..., 1, r)
@@ -80,6 +81,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten
if needs_reshape:
out = out.swapaxes(1, 2).reshape(b, t, -1)
out = out.to(dtype=x_dtype)
return out