mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Improve docstrings and type hints in scheduling_ddim.py (#12622)
* Improve docstrings and type hints in scheduling_ddim.py
  - Add complete type hints for all function parameters
  - Enhance docstrings to follow project conventions
  - Add missing parameter descriptions
  Fixes #9567
* Enhance docstrings and type hints in scheduling_ddim.py
  - Update parameter types and descriptions for clarity
  - Improve explanations in method docstrings to align with project standards
  - Add optional annotations for parameters where applicable
* Refine type hints and docstrings in scheduling_ddim.py
  - Update parameter types to use Literal for specific string options
  - Enhance docstring descriptions for clarity and consistency
  - Ensure all parameters have appropriate type annotations and defaults
* Apply review feedback on scheduling_ddim.py
  - Replace "prevent singularities" with "avoid numerical instability" for better clarity
  - Add backticks around `alpha_bar` variable name for consistent formatting
  - Convert Imagen Video paper URLs to Hugging Face papers references
* Propagate changes using 'make fix-copies'
* Add missing Literal
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -92,11 +92,10 @@ def betas_for_alpha_bar(
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.Tensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
@@ -143,9 +142,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to 0.02):
|
||||
The final `beta` value.
|
||||
beta_schedule (`str`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||||
beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Must be one
|
||||
of `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`.
|
||||
trained_betas (`np.ndarray` or `List[float]`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
clip_sample (`bool`, defaults to `True`):
|
||||
@@ -158,10 +157,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
otherwise it uses the alpha value at step 0.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
An offset added to the inference steps, as required by some model families.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
Video](https://imagen.research.google/video/paper.pdf) paper).
|
||||
prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`):
|
||||
Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
|
||||
process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
|
||||
Video](https://huggingface.co/papers/2210.02303) paper).
|
||||
thresholding (`bool`, defaults to `False`):
|
||||
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
||||
as Stable Diffusion.
|
||||
@@ -169,9 +168,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
||||
sample_max_value (`float`, defaults to 1.0):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
timestep_spacing (`str`, defaults to `"leading"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
timestep_spacing (`Literal["leading", "trailing", "linspace"]`, defaults to `"leading"`):
|
||||
The way the timesteps should be scaled. Must be one of `"leading"`, `"trailing"`, or `"linspace"`. Refer to
|
||||
Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891) paper for more information.
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
@@ -187,17 +187,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
clip_sample_range: float = 1.0,
|
||||
sample_max_value: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
@@ -250,7 +250,25 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
def _get_variance(self, timestep, prev_timestep):
|
||||
def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
|
||||
"""
|
||||
Computes the variance of the noise added at a given diffusion step.
|
||||
|
||||
For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
|
||||
literature:
|
||||
var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
||||
where `alpha_prod_t` is the cumulative product of alphas at `timestep` (`alphas_cumprod[timestep]`) and `beta_prod_t = 1 - alpha_prod_t` (it is not a cumulative product of betas).
|
||||
|
||||
Args:
|
||||
timestep (`int`):
|
||||
The current timestep in the diffusion process.
|
||||
prev_timestep (`int`):
|
||||
The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The variance for the current timestep.
|
||||
"""
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
@@ -294,13 +312,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`Union[str, torch.device]`, *optional*):
|
||||
The device to use for the timesteps.
|
||||
|
||||
Raises:
|
||||
ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
|
||||
"""
|
||||
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
@@ -346,7 +369,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample: torch.Tensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DDIMSchedulerOutput, Tuple]:
|
||||
@@ -357,20 +380,21 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`float`):
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
eta (`float`):
|
||||
The weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`, defaults to `False`):
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
The weight of the noise added in the diffusion step. A value of 0 corresponds to DDIM (deterministic)
|
||||
and 1 corresponds to DDPM (fully stochastic).
|
||||
use_clipped_model_output (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
|
||||
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
|
||||
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
|
||||
`use_clipped_model_output` has no effect.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
variance_noise (`torch.Tensor`):
|
||||
A random number generator for reproducible sampling.
|
||||
variance_noise (`torch.Tensor`, *optional*):
|
||||
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
||||
itself. Useful for methods such as [`CycleDiffusion`].
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
@@ -517,5 +541,5 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return velocity
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -95,7 +95,6 @@ def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.Tensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -97,7 +97,6 @@ def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.Tensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
@@ -194,17 +193,17 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
clip_sample_range: float = 1.0,
|
||||
sample_max_value: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
@@ -324,6 +323,11 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`Union[str, torch.device]`, *optional*):
|
||||
The device to use for the timesteps.
|
||||
|
||||
Raises:
|
||||
ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
|
||||
"""
|
||||
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
|
||||
@@ -94,7 +94,6 @@ def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.Tensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
@@ -96,7 +96,6 @@ def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.Tensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
@@ -80,7 +80,6 @@ def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.Tensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
@@ -97,7 +97,6 @@ def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.Tensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
@@ -100,7 +100,6 @@ def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.Tensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
@@ -99,7 +99,6 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.Tensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
@@ -98,7 +98,6 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.Tensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
@@ -316,6 +315,24 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler._get_variance
|
||||
def _get_variance(self, timestep, prev_timestep):
|
||||
"""
|
||||
Computes the variance of the noise added at a given diffusion step.
|
||||
|
||||
For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
|
||||
literature:
|
||||
var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
||||
where `alpha_prod_t` is the cumulative product of alphas at `timestep` (`alphas_cumprod[timestep]`) and `beta_prod_t = 1 - alpha_prod_t` (it is not a cumulative product of betas).
|
||||
|
||||
Args:
|
||||
timestep (`int`):
|
||||
The current timestep in the diffusion process.
|
||||
prev_timestep (`int`):
|
||||
The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The variance for the current timestep.
|
||||
"""
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
@@ -80,7 +80,6 @@ def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.Tensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
Reference in New Issue
Block a user