From a3e8d3f7deed140f57a28d82dd0b5d965bd0fb09 Mon Sep 17 00:00:00 2001 From: wony617 <49024958+Jwaminju@users.noreply.github.com> Date: Tue, 15 Oct 2024 22:45:14 +0900 Subject: [PATCH] [docs] refactoring docstrings in `models/embeddings_flax.py` (#9592) * [docs] refactoring docstrings in `models/embeddings_flax.py` * Update src/diffusers/models/embeddings_flax.py * make style --------- Co-authored-by: Aryan --- src/diffusers/models/embeddings_flax.py | 32 ++++++++++++++++++------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 8e343be0d3..92b5a6c358 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -29,11 +29,21 @@ def get_sinusoidal_embeddings( """Returns the positional encoding (same as Tensor2Tensor). Args: - timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - embedding_dim: The number of output channels. - min_timescale: The smallest time unit (should probably be 0.0). - max_timescale: The largest time unit. + timesteps (`jnp.ndarray` of shape `(N,)`): + A 1-D array of N indices, one per batch element. These may be fractional. + embedding_dim (`int`): + The number of output channels. + freq_shift (`float`, *optional*, defaults to `1`): + Shift applied to the frequency scaling of the embeddings. + min_timescale (`float`, *optional*, defaults to `1`): + The smallest time unit used in the sinusoidal calculation (should probably be 0.0). + max_timescale (`float`, *optional*, defaults to `1.0e4`): + The largest time unit used in the sinusoidal calculation. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the order of sinusoidal components to cosine first. + scale (`float`, *optional*, defaults to `1.0`): + A scaling factor applied to the positional embeddings. + Returns: a Tensor of timing signals [N, num_channels] """ @@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module): Args: time_embed_dim (`int`, *optional*, defaults to `32`): - Time step embedding dimension - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` + Time step embedding dimension. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + The data type for the embedding parameters. """ time_embed_dim: int = 32 @@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module): Args: dim (`int`, *optional*, defaults to `32`): - Time step embedding dimension + Time step embedding dimension. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sinusoidal function from sine to cosine. + freq_shift (`float`, *optional*, defaults to `1`): + Frequency shift applied to the sinusoidal embeddings. """ dim: int = 32