1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Improve docstrings and type hints in scheduling_pndm.py (#12676)

* Enhance docstrings and type hints in PNDMScheduler class

- Updated parameter descriptions to include default values and specific types using Literal for better clarity.
- Improved docstring formatting and consistency across methods, including detailed explanations for the `_get_prev_sample` method.
- Added type hints for method return types to enhance code readability and maintainability.

* Refactor docstring in PNDMScheduler class to enhance clarity

- Simplified the explanation of the method for computing the previous sample from the current sample.
- Updated the reference to the PNDM paper for better accessibility.
- Removed redundant notation explanations to streamline the documentation.
This commit is contained in:
David El Malih
2025-11-19 18:36:41 +01:00
committed by GitHub
parent a96b145304
commit 15370f8412

View File

@@ -79,15 +79,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
beta_start (`float`, defaults to `0.0001`):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
beta_end (`float`, defaults to `0.02`):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
skip_prk_steps (`bool`, defaults to `False`):
@@ -97,14 +96,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
prediction_type (`str`, defaults to `epsilon`, *optional*):
prediction_type (`"epsilon"` or `"v_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
paper).
timestep_spacing (`str`, defaults to `"leading"`):
or `v_prediction` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
steps_offset (`int`, defaults to `0`):
An offset added to the inference steps, as required by some model families.
"""
@@ -117,12 +115,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "leading",
prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
steps_offset: int = 0,
):
if trained_betas is not None:
@@ -164,7 +162,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.plms_timesteps = None
self.timesteps = None
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -243,7 +241,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
@@ -276,14 +274,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
@@ -335,14 +332,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
@@ -403,19 +399,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# See formula (9) of PNDM paper https://huggingface.co/papers/2202.09778
# this function computes x_(tδ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
def _get_prev_sample(
self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor
) -> torch.Tensor:
"""
Compute the previous sample x_(t-δ) from the current sample x_t using formula (9) from the [PNDM
paper](https://huggingface.co/papers/2202.09778).
# Notation (<variable name> -> <name in paper>
# alpha_prod_t -> α_t
# alpha_prod_t_prev -> α_(tδ)
# beta_prod_t -> (1 - α_t)
# beta_prod_t_prev -> (1 - α_(tδ))
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(tδ)
Args:
sample (`torch.Tensor`):
The current sample x_t.
timestep (`int`):
The current timestep t.
prev_timestep (`int`):
The previous timestep (t-δ).
model_output (`torch.Tensor`):
The model output e_θ(x_t, t).
Returns:
`torch.Tensor`:
The previous sample x_(t-δ).
"""
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
@@ -489,5 +493,5 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps