1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Improve docstrings and type hints in scheduling_ddim.py (#12622)

* Improve docstrings and type hints in scheduling_ddim.py

- Add complete type hints for all function parameters
- Enhance docstrings to follow project conventions
- Add missing parameter descriptions

Fixes #9567

* Enhance docstrings and type hints in scheduling_ddim.py

- Update parameter types and descriptions for clarity
- Improve explanations in method docstrings to align with project standards
- Add optional annotations for parameters where applicable

* Refine type hints and docstrings in scheduling_ddim.py

- Update parameter types to use Literal for specific string options
- Enhance docstring descriptions for clarity and consistency
- Ensure all parameters have appropriate type annotations and defaults

* Apply review feedback on scheduling_ddim.py

- Replace "prevent singularities" with "avoid numerical instability" for better clarity
- Add backticks around `alpha_bar` variable name for consistent formatting
- Convert Imagen Video paper URLs to Hugging Face papers references

* Propagate changes using 'make fix-copies'

* Add missing Literal
This commit is contained in:
David El Malih
2025-11-13 23:45:58 +01:00
committed by GitHub
parent 40de88af8c
commit 6fe4a6ff8e
11 changed files with 77 additions and 40 deletions

View File

@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -92,11 +92,10 @@ def betas_for_alpha_bar(
return torch.tensor(betas, dtype=torch.float32)
def rescale_zero_terminal_snr(betas):
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
@@ -143,9 +142,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Must be one
of `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
clip_sample (`bool`, defaults to `True`):
@@ -158,10 +157,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
otherwise it uses the alpha value at step 0.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`):
Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -169,9 +168,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
timestep_spacing (`Literal["leading", "trailing", "linspace"]`, defaults to `"leading"`):
The way the timesteps should be scaled. Must be one of `"leading"`, `"trailing"`, or `"linspace"`. Refer to
Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891) for more information.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -187,17 +187,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
@@ -250,7 +250,25 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def _get_variance(self, timestep, prev_timestep):
def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
"""
Computes the variance of the noise added at a given diffusion step.
For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
literature:
var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
where alpha_prod and beta_prod are cumulative products of alphas and betas, respectively.
Args:
timestep (`int`):
The current timestep in the diffusion process.
prev_timestep (`int`):
The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
Returns:
`torch.Tensor`:
The variance for the current timestep.
"""
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
@@ -294,13 +312,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`Union[str, torch.device]`, *optional*):
The device to use for the timesteps.
Raises:
ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
"""
if num_inference_steps > self.config.num_train_timesteps:
@@ -346,7 +369,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
generator: Optional[torch.Generator] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
@@ -357,20 +380,21 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
eta (`float`, *optional*, defaults to 0.0):
The weight of noise for added noise in diffusion step. A value of 0 corresponds to DDIM (deterministic)
and 1 corresponds to DDPM (fully stochastic).
use_clipped_model_output (`bool`, *optional*, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`):
A random number generator for reproducible sampling.
variance_noise (`torch.Tensor`, *optional*):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
@@ -517,5 +541,5 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -95,7 +95,6 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.

View File

@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -97,7 +97,6 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
@@ -194,17 +193,17 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
@@ -324,6 +323,11 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`Union[str, torch.device]`, *optional*):
The device to use for the timesteps.
Raises:
ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
"""
if num_inference_steps > self.config.num_train_timesteps:

View File

@@ -94,7 +94,6 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.

View File

@@ -96,7 +96,6 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.

View File

@@ -80,7 +80,6 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.

View File

@@ -97,7 +97,6 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.

View File

@@ -100,7 +100,6 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.

View File

@@ -99,7 +99,6 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.

View File

@@ -98,7 +98,6 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
@@ -316,6 +315,24 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler._get_variance
def _get_variance(self, timestep, prev_timestep):
"""
Computes the variance of the noise added at a given diffusion step.
For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
literature:
var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
where alpha_prod and beta_prod are cumulative products of alphas and betas, respectively.
Args:
timestep (`int`):
The current timestep in the diffusion process.
prev_timestep (`int`):
The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
Returns:
`torch.Tensor`:
The variance for the current timestep.
"""
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t

View File

@@ -80,7 +80,6 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.