1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Rope in float32 for mps or npu compatibility (#12665)

rope in float32
This commit is contained in:
David Bertoin
2025-11-15 16:14:34 +01:00
committed by GitHub
parent a9e4883b6a
commit 01a56927f1

View File

@@ -275,7 +275,12 @@ class PRXEmbedND(nn.Module):
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
is_mps = pos.device.type == "mps"
is_npu = pos.device.type == "npu"
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)