diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 80775d477c..91451fa9aa 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -86,12 +86,25 @@ def get_3d_sincos_pos_embed(
     temporal_interpolation_scale: float = 1.0,
 ) -> np.ndarray:
     r"""
+    Creates 3D sinusoidal positional embeddings.
+
     Args:
         embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
         spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to
+            both spatial dimensions (height and width).
         temporal_size (`int`):
+            The temporal dimension of positional embeddings (number of frames).
         spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
         temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `np.ndarray`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
     """
     if embed_dim % 4 != 0:
         raise ValueError("`embed_dim` must be divisible by 4")
@@ -129,8 +142,24 @@ def get_2d_sincos_pos_embed(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
     """
-    grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
-    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`np.ndarray`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using `cls_token`, or
+            `[1 + grid_size * grid_size, embed_dim]` if using `cls_token`.
     """
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
@@ -148,6 +177,16 @@ def get_2d_sincos_pos_embed(
 
 
 def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
+
+    Returns:
+        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`.
+    """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
@@ -161,7 +200,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
 
 def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     """
-    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
+    This function generates 1D sinusoidal positional embeddings from a grid of positions.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`.
+        pos (`np.ndarray`): 1D array of positions with shape `(M,)`.
+
+    Returns:
+        `np.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
     """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
@@ -181,7 +227,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
 
 class PatchEmbed(nn.Module):
-    """2D Image to Patch Embedding with support for SD3 cropping."""
+    """
+    2D Image to Patch Embedding with support for SD3 cropping.
+
+    Args:
+        height (`int`, defaults to `224`): The height of the image.
+        width (`int`, defaults to `224`): The width of the image.
+        patch_size (`int`, defaults to `16`): The size of the patches.
+        in_channels (`int`, defaults to `3`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
+        flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+        interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
+        pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
+        pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
+    """
 
     def __init__(
         self,
@@ -289,7 +350,15 @@ class PatchEmbed(nn.Module):
 
 
 class LuminaPatchEmbed(nn.Module):
-    """2D Image to Patch Embedding with support for Lumina-T2X"""
+    """
+    2D Image to Patch Embedding with support for Lumina-T2X.
+
+    Args:
+        patch_size (`int`, defaults to `2`): The size of the patches.
+        in_channels (`int`, defaults to `4`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+    """
 
     def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
         super().__init__()
@@ -675,6 +744,20 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
 
 
 def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+    """
+    Get 2D RoPE from grid.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension size, corresponding to `hidden_size_head`.
+        grid (`np.ndarray`):
+            The grid of the positional embedding.
+        use_real (`bool`):
+            If `True`, return the real and imaginary parts separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: The positional embedding with shape `(grid_size * grid_size, embed_dim / 2)`.
+    """
     assert embed_dim % 4 == 0
 
     # use half of dimensions to encode grid_h
@@ -695,6 +778,24 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
 
 
 def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+    """
+    Get 2D RoPE for Lumina-T2X from the given grid height and width.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension size, corresponding to `hidden_size_head`.
+        len_h (`int`):
+            The height of the positional embedding grid.
+        len_w (`int`):
+            The width of the positional embedding grid.
+        linear_factor (`float`):
+            The linear scaling factor applied to the RoPE frequencies.
+        ntk_factor (`float`):
+            The NTK-aware scaling factor applied to the RoPE base frequency.
+
+    Returns:
+        `torch.Tensor`: The positional embedding with shape `(len_h * len_w, embed_dim / 2)`.
+    """
    assert embed_dim % 4 == 0
 
     emb_h = get_1d_rotary_pos_embed(
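
As a quick sanity check on the return shapes documented above, here is a minimal sketch. It assumes these helpers are importable from `diffusers.models.embeddings` and return `np.ndarray`, as the new docstrings state; the concrete dimensions are illustrative only.

    from diffusers.models.embeddings import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed

    # 2D: [grid_size * grid_size, embed_dim] when no cls_token is used.
    pos_2d = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)
    assert pos_2d.shape == (16 * 16, 768)

    # 3D: [temporal_size, spatial_size[0] * spatial_size[1], embed_dim];
    # embed_dim is chosen divisible by 16, per the docstring above.
    pos_3d = get_3d_sincos_pos_embed(embed_dim=64, spatial_size=(8, 8), temporal_size=4)
    assert pos_3d.shape == (4, 8 * 8, 64)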