From 673eb60f1c4d971e1a577bed767053e50578b461 Mon Sep 17 00:00:00 2001 From: Alan Du Date: Wed, 10 Jul 2024 21:54:44 -0400 Subject: [PATCH] Reformat docstring for `get_timestep_embedding` (#8811) * Reformat docstring for `get_timestep_embedding` --------- Co-authored-by: YiYi Xu --- src/diffusers/models/embeddings.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 8bc30f7cab..ec1c68b86c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -35,10 +35,21 @@ def get_timestep_embedding( """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. :return: an [N x dim] Tensor of positional embeddings. + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. """ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"