Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Fix some documentation in ./src/diffusers/models/embeddings.py for demo (#9579)
* Fix some documentation in ./src/diffusers/models/embeddings.py as demonstration.

---------

Co-authored-by: DaAccursed05 <68813178+DaAccursed05@users.noreply.github.com>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
@@ -86,12 +86,25 @@ def get_3d_sincos_pos_embed(
    temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
    r"""
    Creates 3D sinusoidal positional embeddings.

    Args:
        embed_dim (`int`):
            The embedding dimension of inputs. It must be divisible by 16.
        spatial_size (`int` or `Tuple[int, int]`):
            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to
            both spatial dimensions (height and width).
        temporal_size (`int`):
            The temporal dimension of positional embeddings (number of frames).
        spatial_interpolation_scale (`float`, defaults to 1.0):
            Scale factor for spatial grid interpolation.
        temporal_interpolation_scale (`float`, defaults to 1.0):
            Scale factor for temporal grid interpolation.

    Returns:
        `np.ndarray`:
            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
            embed_dim]`.
    """
    if embed_dim % 4 != 0:
        raise ValueError("`embed_dim` must be divisible by 4")
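A quick usage sketch for the docstring above (the sizes here are illustrative; `get_3d_sincos_pos_embed` is importable from `diffusers.models.embeddings` and, in the version this commit targets, returns a numpy array):

from diffusers.models.embeddings import get_3d_sincos_pos_embed

# 64 embedding channels (divisible by 16), a 32x32 spatial grid, 8 frames.
pos_embed = get_3d_sincos_pos_embed(
    embed_dim=64,
    spatial_size=32,
    temporal_size=8,
    spatial_interpolation_scale=1.0,
    temporal_interpolation_scale=1.0,
)
print(pos_embed.shape)  # (8, 1024, 64) == [temporal_size, H * W, embed_dim]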
@@ -129,8 +142,24 @@ def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    Creates 2D sinusoidal positional embeddings.

    Args:
        embed_dim (`int`):
            The embedding dimension.
        grid_size (`int`):
            The size of the grid height and width.
        cls_token (`bool`, defaults to `False`):
            Whether or not to add a classification token.
        extra_tokens (`int`, defaults to `0`):
            The number of extra tokens to add.
        interpolation_scale (`float`, defaults to `1.0`):
            The scale of the interpolation.

    Returns:
        pos_embed (`np.ndarray`):
            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size *
            grid_size, embed_dim]` if using cls_token.
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
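A short calling sketch (dimension and grid values are my own; at this commit the function is numpy-based):

from diffusers.models.embeddings import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)
print(pos_embed.shape)  # (256, 768) == [grid_size * grid_size, embed_dim]

# With a classification token, extra_tokens rows (zero-filled in this version) are prepended:
pos_embed_cls = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16, cls_token=True, extra_tokens=1)
print(pos_embed_cls.shape)  # (257, 768)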
@@ -148,6 +177,16 @@ def get_2d_sincos_pos_embed(


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    r"""
    This function generates 2D sinusoidal positional embeddings from a grid.

    Args:
        embed_dim (`int`): The embedding dimension.
        grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.

    Returns:
        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")
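To see what a compatible grid looks like, a sketch (the meshgrid construction mirrors what `get_2d_sincos_pos_embed` does internally before calling this helper; sizes are illustrative):

import numpy as np

from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid

h, w = 8, 8
# First axis holds the two coordinate planes; each plane is flattened internally.
grid = np.stack(np.meshgrid(np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32)), axis=0)
emb = get_2d_sincos_pos_embed_from_grid(embed_dim=64, grid=grid)
print(emb.shape)  # (64, 64) == (H * W, embed_dim)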
@@ -161,7 +200,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    This function generates 1D sinusoidal positional embeddings from a grid.

    Args:
        embed_dim (`int`): The embedding dimension `D`.
        pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`.

    Returns:
        `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")
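Concretely, column i of the first half is sin(pos * omega_i) with omega_i = 1 / 10000^(2i / D), and the second half holds the matching cosines. A minimal sketch (values are my own):

import numpy as np

from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

positions = np.arange(10, dtype=np.float64)  # M = 10
emb = get_1d_sincos_pos_embed_from_grid(embed_dim=16, pos=positions)
print(emb.shape)  # (10, 16): the first 8 columns are sines, the last 8 are cosines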
@@ -181,7 +227,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding with support for SD3 cropping.

    Args:
        height (`int`, defaults to `224`): The height of the image.
        width (`int`, defaults to `224`): The width of the image.
        patch_size (`int`, defaults to `16`): The size of the patches.
        in_channels (`int`, defaults to `3`): The number of input channels.
        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
        layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
        flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
        bias (`bool`, defaults to `True`): Whether or not to use bias.
        interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
        pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
        pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
    """

    def __init__(
        self,
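A usage sketch (the latent shape and sizes are my own; `PatchEmbed` is importable from `diffusers.models.embeddings`):

import torch

from diffusers.models.embeddings import PatchEmbed

patch_embed = PatchEmbed(height=64, width=64, patch_size=16, in_channels=4, embed_dim=768)
latent = torch.randn(1, 4, 64, 64)
tokens = patch_embed(latent)  # patchify, project, and add the sincos positional embedding
print(tokens.shape)  # torch.Size([1, 16, 768]) -> (batch, (64 // 16) ** 2 patches, embed_dim)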
@@ -289,7 +350,15 @@ class PatchEmbed(nn.Module):


class LuminaPatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding with support for Lumina-T2X.

    Args:
        patch_size (`int`, defaults to `2`): The size of the patches.
        in_channels (`int`, defaults to `4`): The number of input channels.
        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
        bias (`bool`, defaults to `True`): Whether or not to use bias.
    """

    def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
        super().__init__()
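Construction is straightforward; note that, unlike `PatchEmbed`, the Lumina variant's forward pass in this version also consumes precomputed rotary frequencies (see `get_2d_rotary_pos_embed_lumina` below), so this sketch only builds the module:

from diffusers.models.embeddings import LuminaPatchEmbed

patch_embed = LuminaPatchEmbed(patch_size=2, in_channels=4, embed_dim=768)
# Each (patch_size * patch_size * in_channels)-element patch is projected to embed_dim features.
print(patch_embed)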
@@ -675,6 +744,20 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):


def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
    """
    Get 2D RoPE from grid.

    Args:
        embed_dim (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        grid (`np.ndarray`):
            The grid of the positional embedding.
        use_real (`bool`):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.

    Returns:
        `torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim/2)`.
    """
    assert embed_dim % 4 == 0

    # use half of dimensions to encode grid_h
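A sketch with a hand-built grid (sizes are mine; with `use_real=True` the helper returns separate cosine and sine tensors rather than complex numbers):

import numpy as np

from diffusers.models.embeddings import get_2d_rotary_pos_embed_from_grid

h, w = 8, 8
grid = np.stack(np.meshgrid(np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32)), axis=0)

cos, sin = get_2d_rotary_pos_embed_from_grid(embed_dim=64, grid=grid, use_real=True)
print(cos.shape, sin.shape)  # expected (H * W, embed_dim) each, i.e. torch.Size([64, 64])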
@@ -695,6 +778,23 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):


def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
    """
    Get 2D RoPE from grid (Lumina-T2X variant).

    Args:
        embed_dim (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        len_h (`int`):
            The height of the grid of the positional embedding.
        len_w (`int`):
            The width of the grid of the positional embedding.
        linear_factor (`float`):
            The linear factor of the positional embedding, which is used to scale the positional embedding in the
            linear layer.
        ntk_factor (`float`):
            The ntk factor of the positional embedding, which is used to scale the positional embedding in the ntk
            layer.

    Returns:
        `torch.Tensor`: positional embedding with shape `(len_h * len_w, embed_dim/2)`.
    """
    assert embed_dim % 4 == 0

    emb_h = get_1d_rotary_pos_embed(
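A sketch (sizes are mine; per the Returns note above, the result is a single complex-valued frequency tensor covering the flattened grid):

from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina

# Frequencies for an 8x8 patch grid; embed_dim must be divisible by 4.
freqs_cis = get_2d_rotary_pos_embed_lumina(embed_dim=64, len_h=8, len_w=8)
print(freqs_cis.shape)  # expected (64, 32) == (len_h * len_w, embed_dim / 2)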