1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix MPS compatibility in get_1d_sincos_pos_embed_from_grid #12432 (#12449)

* Fix MPS compatibility in get_1d_sincos_pos_embed_from_grid #12432

* Fix trailing whitespace in docstring

* Apply style fixes

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Aishwarya Badlani
2025-10-23 15:48:07 +05:00
committed by GitHub
parent 85eb505672
commit 74b5fed434

View File

@@ -319,13 +319,17 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False, dtype=None):
"""
This function generates 1D positional embeddings from a grid.
Args:
embed_dim (`int`): The embedding dimension `D`
pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
`torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.
Returns:
`torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
@@ -341,7 +345,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
# Auto-detect appropriate dtype if not specified
if dtype is None:
dtype = torch.float32 if pos.device.type == "mps" else torch.float64
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)