mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Use float32 on mps or npu in transformer_hidream_image's rope (#11316)
This commit is contained in:
@@ -95,7 +95,12 @@ class HiDreamImagePatchEmbed(nn.Module):
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0, "The dimension must be even."
|
||||
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
is_mps = pos.device.type == "mps"
|
||||
is_npu = pos.device.type == "npu"
|
||||
|
||||
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
||||
|
||||
scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
|
||||
batch_size, seq_length = pos.shape
|
||||
|
||||
Reference in New Issue
Block a user