mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* Fix MPS compatibility in get_1d_sincos_pos_embed_from_grid #12432 * Fix trailing whitespace in docstring * Apply style fixes --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
85eb505672
commit
74b5fed434
@@ -319,13 +319,17 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False, dtype=None):
|
||||
"""
|
||||
This function generates 1D positional embeddings from a grid.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`): The embedding dimension `D`
|
||||
pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
|
||||
output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
|
||||
dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
|
||||
`torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
|
||||
@@ -341,7 +345,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin
|
||||
if embed_dim % 2 != 0:
|
||||
raise ValueError("embed_dim must be divisible by 2")
|
||||
|
||||
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
|
||||
# Auto-detect appropriate dtype if not specified
|
||||
if dtype is None:
|
||||
dtype = torch.float32 if pos.device.type == "mps" else torch.float64
|
||||
|
||||
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user