mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[docs] refactoring docstrings in models/embeddings_flax.py (#9592)
* [docs] refactoring docstrings in `models/embeddings_flax.py` * Update src/diffusers/models/embeddings_flax.py * make style --------- Co-authored-by: Aryan <aryan@huggingface.co>
This commit is contained in:
@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
|
||||
"""Returns the positional encoding (same as Tensor2Tensor).
|
||||
|
||||
Args:
|
||||
timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
embedding_dim: The number of output channels.
|
||||
min_timescale: The smallest time unit (should probably be 0.0).
|
||||
max_timescale: The largest time unit.
|
||||
timesteps (`jnp.ndarray` of shape `(N,)`):
|
||||
A 1-D array of N indices, one per batch element. These may be fractional.
|
||||
embedding_dim (`int`):
|
||||
The number of output channels.
|
||||
freq_shift (`float`, *optional*, defaults to `1`):
|
||||
Shift applied to the frequency scaling of the embeddings.
|
||||
min_timescale (`float`, *optional*, defaults to `1`):
|
||||
The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
|
||||
max_timescale (`float`, *optional*, defaults to `1.0e4`):
|
||||
The largest time unit used in the sinusoidal calculation.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip the order of sinusoidal components to cosine first.
|
||||
scale (`float`, *optional*, defaults to `1.0`):
|
||||
A scaling factor applied to the positional embeddings.
|
||||
|
||||
Returns:
|
||||
a Tensor of timing signals [N, num_channels]
|
||||
"""
|
||||
@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
|
||||
|
||||
Args:
|
||||
time_embed_dim (`int`, *optional*, defaults to `32`):
|
||||
Time step embedding dimension
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
Time step embedding dimension.
|
||||
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
|
||||
The data type for the embedding parameters.
|
||||
"""
|
||||
|
||||
time_embed_dim: int = 32
|
||||
@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
|
||||
|
||||
Args:
|
||||
dim (`int`, *optional*, defaults to `32`):
|
||||
Time step embedding dimension
|
||||
Time step embedding dimension.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip the sinusoidal function from sine to cosine.
|
||||
freq_shift (`float`, *optional*, defaults to `1`):
|
||||
Frequency shift applied to the sinusoidal embeddings.
|
||||
"""
|
||||
|
||||
dim: int = 32
|
||||
|
||||
Reference in New Issue
Block a user