From 01a56927f1603f1e89d1e5ada74d2aa75da2d46b Mon Sep 17 00:00:00 2001
From: David Bertoin
Date: Sat, 15 Nov 2025 16:14:34 +0100
Subject: [PATCH] Rope in float32 for mps or npu compatibility (#12665)

rope in float32
---
 src/diffusers/models/transformers/transformer_prx.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py
index 9b2664b9cb..ccbc83ffca 100644
--- a/src/diffusers/models/transformers/transformer_prx.py
+++ b/src/diffusers/models/transformers/transformer_prx.py
@@ -275,7 +275,12 @@ class PRXEmbedND(nn.Module):
 
     def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
         assert dim % 2 == 0
-        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+
+        is_mps = pos.device.type == "mps"
+        is_npu = pos.device.type == "npu"
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+
+        scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
         omega = 1.0 / (theta**scale)
         out = pos.unsqueeze(-1) * omega.unsqueeze(0)
         out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
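
Note: the sketch below is a minimal standalone illustration of the change, not part of the patch. It reproduces the frequency computation from PRXEmbedND.rope() and shows the dtype fallback the commit targets: float64 tensors are not supported on the mps backend (and the commit message lists npu as well), so the angle table is built in float32 on those devices and in float64 elsewhere. The function name rope_angles and the usage lines are this sketch's own; the patch itself only changes the dtype passed to torch.arange.

    # Standalone sketch of the float32 fallback for the RoPE frequency table.
    # rope_angles is a hypothetical helper mirroring the start of PRXEmbedND.rope().
    import torch


    def rope_angles(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
        assert dim % 2 == 0

        # mps (and, per the commit message, npu) cannot allocate float64 tensors,
        # so fall back to float32 on those backends; keep float64 elsewhere.
        use_float32 = pos.device.type in ("mps", "npu")
        dtype = torch.float32 if use_float32 else torch.float64

        scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
        omega = 1.0 / (theta**scale)                  # per-pair rotation frequencies
        out = pos.unsqueeze(-1) * omega.unsqueeze(0)  # rotation angles, shape (..., dim // 2)
        return out


    # Example: positions on CPU keep the original float64 path.
    angles = rope_angles(torch.arange(4), dim=8, theta=10_000)
    print(angles.dtype)  # torch.float64 on cpu/cuda, torch.float32 on mps/npu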