From e6110f68569c7b620306e678c3a3d9eee1a293e2 Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Thu, 8 Sep 2022 01:07:44 -0600
Subject: [PATCH] [docs sprint] schedulers docs, will update (#376)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* init schedulers docs

* add some docstrings, fix sidebar formatting

* add docstrings

* [Type hint] PNDM schedulers (#335)

* [Type hint] PNDM Schedulers

* ran make style

* updated timesteps type hint

* apply suggestions from code review

* ran make style

* removed unused import

* [Type hint] scheduling ddim (#343)

* [Type hint] scheduling ddim

* apply suggestions from code review

apply suggestions to also return the return type

Co-authored-by: Patrick von Platen

Co-authored-by: Patrick von Platen

* make style

* update class docstrings

* add docstrings

* missed merge edit

* add general docs page

* modify headings for right sidebar

Co-authored-by: Partho
Co-authored-by: Santiago Víquez
Co-authored-by: Patrick von Platen
---
 docs/source/api/schedulers.mdx                | 96 +++++++++++++++++--
 src/diffusers/schedulers/scheduling_ddim.py   | 65 ++++++++++++-
 src/diffusers/schedulers/scheduling_ddpm.py   | 62 +++++++++++-
 .../schedulers/scheduling_karras_ve.py        | 75 +++++++++++----
 .../schedulers/scheduling_lms_discrete.py     | 58 +++++++++--
 src/diffusers/schedulers/scheduling_pndm.py   | 87 +++++++++++++++--
 src/diffusers/schedulers/scheduling_sde_ve.py | 74 +++++++++++---
 src/diffusers/schedulers/scheduling_sde_vp.py |  9 ++
 src/diffusers/schedulers/scheduling_utils.py  |  3 +
 9 files changed, 468 insertions(+), 61 deletions(-)

diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx
index 5c435dc8e1..1deff1a4bb 100644
--- a/docs/source/api/schedulers.mdx
+++ b/docs/source/api/schedulers.mdx
@@ -10,19 +10,95 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
 specific language governing permissions and limitations under the License.
 -->
 
-# Models
+# Schedulers
+
+Diffusers contains multiple pre-built schedule functions for the diffusion process.
+
+## What is a scheduler?
+The schedule functions, denoted *Schedulers* in the library, take in the output of a trained model, a sample that the diffusion process is iterating on, and a timestep, and return a denoised sample.
+
+- Schedulers define the methodology for iteratively adding noise to an image or for updating a sample based on model outputs.
+  - adding noise in different manners represents the algorithmic processes used to train a diffusion model.
+  - for inference, the scheduler defines how to update a sample based on an output from a pretrained model.
+- Schedulers are often defined by a *noise schedule* and an *update rule* for solving the underlying differential equation.
+
+### Discrete versus continuous schedulers
+All schedulers take in a timestep to predict the updated version of the sample being diffused.
+The timesteps dictate where in the diffusion process the step is: data is generated by iterating forward in time, and inference is executed by propagating backwards through the timesteps.
+Different algorithms use timesteps that are either discrete (accepting `int` inputs), such as the [`DDPMScheduler`] or [`PNDMScheduler`], or continuous (accepting `float` inputs), such as the score-based schedulers [`ScoreSdeVeScheduler`] or [`ScoreSdeVpScheduler`].
+
+## Designing Re-usable schedulers
+The core design principle behind the schedule functions is to be model, system, and framework independent.
+This allows for rapid experimentation and cleaner abstractions in the code, where the model prediction is separated from the sample update.
+To this end, the design of schedulers is such that:
+- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
+- Schedulers currently default to PyTorch, but are designed to be framework independent (partial NumPy support currently exists).
 
-Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
-The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
-The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
 
 ## API
+The core API for any new scheduler must follow a limited structure.
+- Schedulers should provide one or more `def step(...)` functions that are called to update the generated sample iteratively (see the sketch after this list).
+- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
+- Schedulers should be framework-agnostic, but provide simple functionality to convert the scheduler into a specific framework, such as PyTorch,
+with a `set_format(...)` method.
 
-Models should provide the `def forward` function and initialization of the model.
-All saving, loading, and utilities should be in the base ['ModelMixin'] class.
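+
+To illustrate the pattern, a minimal denoising loop could look as follows (a sketch only: `model` stands in for any trained model that predicts the noise residual for a sample at a timestep, and is not part of the scheduler API):
+
+```python
+import torch
+
+from diffusers import DDPMScheduler
+
+scheduler = DDPMScheduler(num_train_timesteps=1000)
+scheduler.set_timesteps(50)  # configure the schedule for 50 inference steps
+
+sample = torch.randn(1, 3, 256, 256)  # start from pure noise
+for t in scheduler.timesteps:
+    model_output = model(sample, t)  # the trained model predicts the noise residual
+    # the scheduler's update rule converts the prediction into a less noisy sample
+    sample = scheduler.step(model_output, t, sample).prev_sample
+```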
+### Core
+The base class [`SchedulerMixin`] implements low-level utilities used by multiple schedulers.
 
-## Examples
+#### SchedulerMixin
+[[autodoc]] SchedulerMixin
 
-- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
-- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
-- TODO: mention VAE / SDE score estimation
\ No newline at end of file
+#### SchedulerOutput
+The class [`SchedulerOutput`] contains the outputs from any scheduler's `step(...)` call.
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
+
+### Existing Schedulers
+
+#### Denoising diffusion implicit models (DDIM)
+
+Original paper can be found [here](https://arxiv.org/abs/2010.02502).
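+
+As an example of the speed/quality trade-off, [`DDIMScheduler`]'s `step(...)` takes an `eta` argument weighting the noise added at each update. The snippet below is a sketch, reusing the `model` and `sample` conventions from the loop above:
+
+```python
+from diffusers import DDIMScheduler
+
+scheduler = DDIMScheduler(num_train_timesteps=1000)
+scheduler.set_timesteps(50)  # DDIM typically needs far fewer steps than DDPM
+
+for t in scheduler.timesteps:
+    model_output = model(sample, t)
+    # eta=0.0 gives the deterministic DDIM update; eta=1.0 approaches DDPM-like stochastic sampling
+    sample = scheduler.step(model_output, t, sample, eta=0.0).prev_sample
+```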
+
+[[autodoc]] schedulers.scheduling_ddim.DDIMScheduler
+
+#### Denoising diffusion probabilistic models (DDPM)
+
+Original paper can be found [here](https://arxiv.org/abs/2006.11239).
+
+[[autodoc]] schedulers.scheduling_ddpm.DDPMScheduler
+
+#### Variance exploding, stochastic sampling from Karras et al.
+
+Original paper can be found [here](https://arxiv.org/abs/2206.00364).
+
+[[autodoc]] schedulers.scheduling_karras_ve.KarrasVeScheduler
+
+#### Linear multistep scheduler for discrete beta schedules
+
+Original implementation can be found [here](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181).
+
+[[autodoc]] schedulers.scheduling_lms_discrete.LMSDiscreteScheduler
+
+#### Pseudo numerical methods for diffusion models (PNDM)
+
+Original paper can be found [here](https://arxiv.org/abs/2202.09778).
+
+[[autodoc]] schedulers.scheduling_pndm.PNDMScheduler
+
+#### Variance exploding stochastic differential equation (SDE) scheduler
+
+Original paper can be found [here](https://arxiv.org/abs/2011.13456).
+
+[[autodoc]] schedulers.scheduling_sde_ve.ScoreSdeVeScheduler
+
+#### Variance preserving stochastic differential equation (SDE) scheduler
+
+Original paper can be found [here](https://arxiv.org/abs/2011.13456).
+
+Score SDE-VP is under construction.
+
+[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 78c2f7353d..d15c55410c 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -30,11 +30,17 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
 
-    :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
-    from 0 to 1 and
-    produces the cumulative product of (1-beta) up to that part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                      prevent singularities.
+
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
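+
+    Example (a minimal sketch):
+
+    ```python
+    betas = betas_for_alpha_bar(1000)  # 1000 betas following the cosine ("squaredcos_cap_v2") schedule
+    ```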
     """
 
     def alpha_bar(time_step):
@@ -49,6 +55,29 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
 
 
 class DDIMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
+    diffusion probabilistic models (DDPMs) with non-Markovian guidance.
+
+    For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional): TODO
+        timestep_values (`np.ndarray`, optional): TODO
+        clip_sample (`bool`, default `True`):
+            option to clip predicted sample between -1 and 1 for numerical stability.
+        set_alpha_to_one (`bool`, default `True`):
+            each step uses the alpha product of its own and of the previous timestep; for the final step there is no
+            previous timestep, so this option controls whether that previous alpha product is fixed to `1` (`True`)
+            or uses the alpha product at step 0 (`False`).
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+    """
+
     @register_to_config
     def __init__(
         self,
@@ -62,7 +91,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         set_alpha_to_one: bool = True,
         tensor_format: str = "pt",
     ):
-
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
-        if beta_schedule == "linear":
+        elif beta_schedule == "linear":
             self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "scaled_linear":
@@ -101,6 +131,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return variance
 
     def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            offset (`int`): TODO
+        """
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.arange(
             0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
@@ -118,7 +156,24 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         generator=None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
 
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            eta (`float`): weight of noise for added noise in diffusion step.
+            use_clipped_model_output (`bool`): TODO
+            generator: random number generator.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            `SchedulerOutput`: updated sample in the diffusion chain.
+
+        """
         if self.num_inference_steps is None:
             raise ValueError(
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index b0f4b0819d..bdd87f508e 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -29,11 +29,17 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
 
-    :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
-    from 0 to 1 and
-    produces the cumulative product of (1-beta) up to that part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                      prevent singularities.
+
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
     """
 
     def alpha_bar(time_step):
@@ -48,6 +54,29 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
 
 
 class DDPMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
+    Langevin dynamics sampling.
+
+    For more details, see the original paper: https://arxiv.org/abs/2006.11239
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional): TODO
+        variance_type (`str`):
+            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+        clip_sample (`bool`, default `True`):
+            option to clip predicted sample between -1 and 1 for numerical stability.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
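+
+    Example (an illustrative sketch; `model_output`, `t`, and `sample` are assumed to come from a denoising loop as
+    described in the docs):
+
+    ```python
+    scheduler = DDPMScheduler(num_train_timesteps=1000, variance_type="fixed_small")
+    scheduler.set_timesteps(100)
+    sample = scheduler.step(model_output, t, sample).prev_sample  # one reverse-diffusion update
+    ```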
+
+    """
+
     @register_to_config
     def __init__(
         self,
@@ -88,6 +117,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         self.variance_type = variance_type
 
     def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+        """
         num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.arange(
@@ -137,7 +173,25 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         generator=None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
 
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            predict_epsilon (`bool`):
+                whether the model output is the predicted noise, epsilon (`True`, the default), or the denoised
+                sample directly (`False`).
+            generator: random number generator.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            `SchedulerOutput`: updated sample in the diffusion chain.
+
+        """
         t = timestep
 
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py
index 227996ec89..0352be6e3e 100644
--- a/src/diffusers/schedulers/scheduling_karras_ve.py
+++ b/src/diffusers/schedulers/scheduling_karras_ve.py
@@ -49,6 +49,24 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
 
     [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456 + + For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of + Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the + optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. + + Args: + sigma_min (`float`): minimum noise magnitude + sigma_max (`float`): maximum noise magnitude + s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. + A reasonable range is [1.000, 1.011]. + s_churn (`float`): the parameter controlling the overall amount of stochasticity. + A reasonable range is [0, 100]. + s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). + A reasonable range is [0, 10]. + s_max (`float`): the end value of the sigma range where we add noise. + A reasonable range is [0.2, 80]. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + """ @register_to_config @@ -62,23 +80,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): s_max: float = 50, tensor_format: str = "pt", ): - """ - For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of - Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the - optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. - - Args: - sigma_min (`float`): minimum noise magnitude - sigma_max (`float`): maximum noise magnitude - s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. - A reasonable range is [1.000, 1.011]. - s_churn (`float`): the parameter controlling the overall amount of stochasticity. - A reasonable range is [0, 100]. - s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). - A reasonable range is [0, 10]. - s_max (`float`): the end value of the sigma range where we add noise. - A reasonable range is [0.2, 80]. - """ # setable values self.num_inference_steps = None self.timesteps = None @@ -88,6 +89,14 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): self.set_format(tensor_format=tensor_format) def set_timesteps(self, num_inference_steps: int): + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + + """ self.num_inference_steps = num_inference_steps self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() self.schedule = [ @@ -104,6 +113,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): """ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. + + TODO Args: """ if self.s_min <= sigma <= self.s_max: gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1) @@ -125,6 +136,21 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): sample_hat: Union[torch.FloatTensor, np.ndarray], return_dict: bool = True, ) -> Union[KarrasVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. 
Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            sigma_hat (`float`): TODO
+            sigma_prev (`float`): TODO
+            sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
+
+        """
         pred_original_sample = sample_hat + sigma_hat * model_output
         derivative = (sample_hat - pred_original_sample) / sigma_hat
 
@@ -145,7 +171,22 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         derivative: Union[torch.FloatTensor, np.ndarray],
         return_dict: bool = True,
     ) -> Union[KarrasVeOutput, Tuple]:
+        """
+        Correct the predicted sample based on the output model_output of the network. TODO complete description
 
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            sigma_hat (`float`): TODO
+            sigma_prev (`float`): TODO
+            sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+            sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
+            derivative (`torch.FloatTensor` or `np.ndarray`): TODO
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (TODO): updated sample in the diffusion chain.
+            derivative (TODO): TODO
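+
+        Example (schematic only; shows how `add_noise_to_input`, `step`, and `step_correct` combine in a sampling
+        loop, with `sigma`/`sigma_prev` consecutive values from `self.schedule`, `model` a trained denoiser, and any
+        model-specific input/output scaling omitted):
+
+        ```python
+        sample_hat, sigma_hat = scheduler.add_noise_to_input(sample, sigma)
+        model_output = model(sample_hat, sigma_hat)
+        output = scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)
+        if sigma_prev != 0:
+            model_output = model(output.prev_sample, sigma_prev)
+            output = scheduler.step_correct(
+                model_output, sigma_hat, sigma_prev, sample_hat, output.prev_sample, output.derivative
+            )
+        sample = output.prev_sample
+        ```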
+
+        """
         pred_original_sample = sample_prev + sigma_prev * model_output
         derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
         sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index b35cb7bd1d..31d482ae59 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -24,6 +24,26 @@ from .scheduling_utils import SchedulerMixin, SchedulerOutput
 
 
 class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
+    Katherine Crowson:
+    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear` or `scaled_linear`.
+        trained_betas (`np.ndarray`, optional): TODO
+        timestep_values (`np.ndarray`, optional): TODO
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+    """
+
     @register_to_config
     def __init__(
         self,
@@ -35,12 +55,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timestep_values: Optional[np.ndarray] = None,
         tensor_format: str = "pt",
     ):
-        """
-        Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
-        Katherine Crowson:
-        https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
-        """
-
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
-        if beta_schedule == "linear":
+        elif beta_schedule == "linear":
             self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "scaled_linear":
@@ -64,7 +80,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def get_lms_coefficient(self, order, t, current_order):
         """
-        Compute a linear multistep coefficient
+        Compute a linear multistep coefficient.
+
+        Args:
+            order (TODO):
+            t (TODO):
+            current_order (TODO):
         """
 
         def lms_derivative(tau):
@@ -80,6 +101,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return integrated_coeff
 
     def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+        """
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
 
@@ -102,6 +130,22 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         order: int = 4,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            order (`int`): the order of the linear multistep method (how many previous derivatives are used).
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
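+
+        Example (a sketch; note that `timestep` indexes into the precomputed `self.sigmas`, so the loop position is
+        passed rather than the raw training timestep, and the initial noise is scaled by the largest sigma, following
+        the k-diffusion convention):
+
+        ```python
+        scheduler = LMSDiscreteScheduler(num_train_timesteps=1000)
+        scheduler.set_timesteps(50)
+        sample = noise * scheduler.sigmas[0]  # `noise` is standard Gaussian noise
+        for i, t in enumerate(scheduler.timesteps):
+            model_output = model(sample, t)  # `model` as elsewhere in these docs
+            sample = scheduler.step(model_output, i, sample, order=4).prev_sample
+        ```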
+
+        """
         sigma = self.sigmas[timestep]
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index a8778fed5d..171b509898 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -15,7 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -29,11 +29,17 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
 
-    :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
-    from 0 to 1 and
-    produces the cumulative product of (1-beta) up to that part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                      prevent singularities.
+
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
     """
 
     def alpha_bar(time_step):
@@ -48,6 +54,27 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
 
 
 class PNDMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
+    namely the Runge-Kutta method and a linear multi-step method.
+
+    For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional): TODO
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+        skip_prk_steps (`bool`):
+            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+            before PLMS steps; defaults to `False`.
+
+    """
+
     @register_to_config
     def __init__(
         self,
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
         tensor_format: str = "pt",
         skip_prk_steps: bool = False,
     ):
-
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
-        if beta_schedule == "linear":
+        elif beta_schedule == "linear":
             self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "scaled_linear":
@@ -98,6 +127,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         self.set_format(tensor_format=tensor_format)
 
     def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            offset (`int`): TODO
+        """
         self.num_inference_steps = num_inference_steps
         self._timesteps = list(
             range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
         )
@@ -135,7 +172,23 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.FloatTensor, np.ndarray],
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            `SchedulerOutput`: updated sample in the diffusion chain.
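+
+        Example (illustrative; with `skip_prk_steps=False` the first calls run the Runge-Kutta (PRK) warm-up steps,
+        after which `step` dispatches to the linear multi-step (PLMS) update):
+
+        ```python
+        scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=False)
+        scheduler.set_timesteps(50)
+        for t in scheduler.timesteps:
+            model_output = model(sample, t)  # `model` and `sample` as elsewhere in these docs
+            sample = scheduler.step(model_output, t, sample).prev_sample  # picks step_prk or step_plms internally
+        ```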
+
+        """
         if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
             return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
         else:
@@ -151,6 +204,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         """
         Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
         solution to the differential equation.
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
+
         """
         if self.num_inference_steps is None:
             raise ValueError(
@@ -194,6 +258,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         """
         Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
         times to approximate the solution.
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
+
         """
         if self.num_inference_steps is None:
             raise ValueError(
diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py
index f6b0ba936e..7e203db673 100644
--- a/src/diffusers/schedulers/scheduling_sde_ve.py
+++ b/src/diffusers/schedulers/scheduling_sde_ve.py
@@ -47,12 +47,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
     """
     The variance exploding stochastic differential equation (SDE) scheduler.
 
-    :param snr: coefficient weighting the step from the model_output sample (from the network) to the random noise.
-    :param sigma_min: initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
-    distribution of the data.
-    :param sigma_max: :param sampling_eps: the end value of sampling, where timesteps decrease progessively from 1 to
-    epsilon. :param correct_steps: number of correction steps performed on a produced sample. :param tensor_format:
-    "np" or "pt" for the expected format of samples passed to the Scheduler.
+    For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+    Args:
+        snr (`float`):
+            coefficient weighting the step from the model_output sample (from the network) to the random noise.
+        sigma_min (`float`):
+            initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+            distribution of the data.
+        sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
+        sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
+            epsilon.
+        correct_steps (`int`): number of correction steps performed on a produced sample.
+        tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
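+
+    Example (a schematic predictor-corrector loop; `model` is assumed to be a trained score model and
+    `num_inference_steps` an integer chosen for the sampling run):
+
+    ```python
+    scheduler = ScoreSdeVeScheduler()
+    scheduler.set_timesteps(num_inference_steps)
+    scheduler.set_sigmas(num_inference_steps)
+    for t in scheduler.timesteps:
+        # corrector: refine the current sample (repeated `correct_steps` times)
+        model_output = model(sample, t)
+        sample = scheduler.step_correct(model_output, sample).prev_sample
+        # predictor: reverse-SDE update toward the previous timestep
+        model_output = model(sample, t)
+        sample = scheduler.step(model_output, t, sample).prev_sample
+    ```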
""" @register_to_config @@ -66,11 +73,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): correct_steps=1, tensor_format="pt", ): - # self.sigmas = None - # self.discrete_sigmas = None - # - # # setable values - # self.num_inference_steps = None + # setable values self.timesteps = None self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) @@ -79,6 +82,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): self.set_format(tensor_format=tensor_format) def set_timesteps(self, num_inference_steps, sampling_eps=None): + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). + + """ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps tensor_format = getattr(self, "tensor_format", "pt") if tensor_format == "np": @@ -89,6 +101,20 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None): + """ + Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. + + The sigmas control the weight of the `drift` and `diffusion` components of sample update. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sigma_min (`float`, optional): + initial noise scale value (overrides value given at Scheduler instantiation). + sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation). + sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). + + """ sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps @@ -140,7 +166,20 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): **kwargs, ) -> Union[SdeVeOutput, Tuple]: """ - Predict the sample at the previous timestep by reversing the SDE. + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain. + """ if "seed" in kwargs and kwargs["seed"] is not None: self.set_seed(kwargs["seed"]) @@ -186,6 +225,17 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): """ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly after making the prediction for the previous timestep. 
+ + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain. + """ if "seed" in kwargs and kwargs["seed"] is not None: self.set_seed(kwargs["seed"]) diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index abb1d15bd3..d3482f4b00 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -24,6 +24,15 @@ from .scheduling_utils import SchedulerMixin class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): + """ + The variance preserving stochastic differential equation (SDE) scheduler. + + For more information, see the original paper: https://arxiv.org/abs/2011.13456 + + UNDER CONSTRUCTION + + """ + @register_to_config def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 7d176e6366..f2bcd73acf 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -38,6 +38,9 @@ class SchedulerOutput(BaseOutput): class SchedulerMixin: + """ + Mixin containing common functions for the schedulers. + """ config_name = SCHEDULER_CONFIG_NAME ignore_for_config = ["tensor_format"]