mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[docs sprint] schedulers docs, will update (#376)
* init schedulers docs
* add some docstrings, fix sidebar formatting
* add docstrings
* [Type hint] PNDM schedulers (#335)
* [Type hint] PNDM Schedulers
* ran make style
* updated timesteps type hint
* apply suggestions from code review
* ran make style
* removed unused import
* [Type hint] scheduling ddim (#343)
* [Type hint] scheduling ddim
* apply suggestions from code review
  apply suggestions to also return the return type
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* make style
* update class docstrings
* add docstrings
* missed merge edit
* add general docs page
* modify headings for right sidebar

Co-authored-by: Partho <parthodas6176@gmail.com>
Co-authored-by: Santiago Víquez <santi.viquez@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
docs/source/api/schedulers.mdx
@@ -10,19 +10,95 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->

-# Models
+# Schedulers
+
+Diffusers contains multiple pre-built schedule functions for the diffusion process.
+
+## What is a scheduler?
+
+The schedule functions, denoted *schedulers* in the library, take in the output of a trained model, a sample that the diffusion process is iterating on, and a timestep, and return a denoised sample.
+
+- Schedulers define the methodology for iteratively adding noise to an image or for updating a sample based on model outputs.
+  - Adding noise in different manners represents the algorithmic processes used to train a diffusion model.
+  - For inference, the scheduler defines how to update a sample based on the output of a pretrained model.
+- Schedulers are often defined by a *noise schedule* and an *update rule* for solving the underlying differential equation; a minimal sampling loop built on this interface is sketched below.
+
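At its core, this means every scheduler can drive the same sampling loop: the model predicts noise (or a less noisy sample), and the scheduler converts that prediction into the next sample. A minimal sketch, assuming a placeholder `model` callable that predicts the noise residual:

```python
import torch

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)  # run inference in 50 steps instead of all 1000

sample = torch.randn(1, 3, 32, 32)  # start from pure noise
for t in scheduler.timesteps:
    model_output = model(sample, t)  # placeholder: any noise-predicting UNet
    # the scheduler turns the model prediction into the next, less noisy sample
    sample = scheduler.step(model_output, t, sample).prev_sample
```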
+### Discrete versus continuous schedulers
+
+All schedulers take in a timestep to predict the updated version of the sample being diffused.
+The timesteps dictate where in the diffusion process the step is: data is generated by iterating forward in time, and inference is executed by propagating backwards through the timesteps.
+Different algorithms use timesteps that are either discrete (accepting `int` inputs), such as the [`DDPMScheduler`] or [`PNDMScheduler`], or continuous (accepting `float` inputs), such as the score-based schedulers [`ScoreSdeVeScheduler`] or [`ScoreSdeVpScheduler`].
+
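The distinction is visible in the timestep values the schedulers produce. A small sketch (exact values and dtypes may vary by configuration):

```python
from diffusers import DDPMScheduler, ScoreSdeVeScheduler

discrete = DDPMScheduler(num_train_timesteps=1000)
discrete.set_timesteps(10)
print(discrete.timesteps)  # integer timesteps counting down, e.g. 900, 800, ...

continuous = ScoreSdeVeScheduler()
continuous.set_timesteps(10)
print(continuous.timesteps)  # floats decreasing from 1.0 towards sampling_eps
```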
+## Designing Re-usable schedulers
+
+The core design principle behind the schedule functions is to be model, system, and framework independent.
+This allows for rapid experimentation and cleaner abstractions in the code, where the model prediction is separated from the sample update.
+To this end, the design of schedulers is such that:
+
+- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
+- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial NumPy support currently exists); see the conversion sketch below.
+
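Framework conversion is handled by the `set_format(...)` utility the schedulers inherit; a sketch of switching the expected array type (behavior hedged to this version of the library):

```python
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(skip_prk_steps=True)

# schedulers default to PyTorch tensors ("pt") ...
scheduler.set_format(tensor_format="pt")

# ... but can be switched to NumPy arrays for torch-free experimentation
scheduler.set_format(tensor_format="np")
```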
-Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
-
-The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
-
-The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
-
## API

+The core API for any new scheduler must follow a limited structure.
+
+- Schedulers should provide one or more `def step(...)` functions that are called to update the generated sample iteratively.
+- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
+- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch, with a `set_format(...)` method. A skeleton implementing this interface is sketched below.
+
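Taken together, a new scheduler skeleton might look as follows (a minimal sketch against this interface; the `register_to_config` usage mirrors the existing schedulers, and the update rule is a placeholder):

```python
import numpy as np

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput


class MyScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, tensor_format: str = "pt"):
        # setable values, filled in by `set_timesteps`
        self.num_inference_steps = None
        self.timesteps = None
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps: int):
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.arange(0, num_inference_steps)[::-1].copy()

    def step(self, model_output, timestep, sample, return_dict: bool = True):
        prev_sample = sample - model_output  # placeholder update rule
        if not return_dict:
            return (prev_sample,)
        return SchedulerOutput(prev_sample=prev_sample)
```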
-Models should provide the `def forward` function and initialization of the model.
-
-All saving, loading, and utilities should be in the base ['ModelMixin'] class.
+### Core
+
+The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.

## Examples

+#### SchedulerMixin
+
+[[autodoc]] SchedulerMixin
+
-- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
-
-- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
-
-- TODO: mention VAE / SDE score estimation
+#### SchedulerOutput
+
+The class [`SchedulerOutput`] contains the outputs from any scheduler's `step(...)` call, as shown below.
+
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
+
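Concretely, `step(...)` returns either the output dataclass or, with `return_dict=False`, a plain tuple. A short sketch (shapes illustrative):

```python
import torch

from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
scheduler.set_timesteps(10)

sample = torch.randn(1, 3, 32, 32)
model_output = torch.randn(1, 3, 32, 32)  # stand-in for a UNet prediction
t = scheduler.timesteps[0]

output = scheduler.step(model_output, t, sample)
print(output.prev_sample.shape)  # access via the SchedulerOutput dataclass

(prev_sample,) = scheduler.step(model_output, t, sample, return_dict=False)
```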
+### Existing Schedulers
+
+#### Denoising diffusion implicit models (DDIM)
+
+Original paper can be found [here](https://arxiv.org/abs/2010.02502).
+
+[[autodoc]] schedulers.scheduling_ddim.DDIMScheduler
+
+#### Denoising diffusion probabilistic models (DDPM)
+
+Original paper can be found [here](https://arxiv.org/abs/2006.11239).
+
+[[autodoc]] schedulers.scheduling_ddpm.DDPMScheduler
+
+#### Variance exploding, stochastic sampling from Karras et al.
+
+Original paper can be found [here](https://arxiv.org/abs/2206.00364).
+
+[[autodoc]] schedulers.scheduling_karras_ve.KarrasVeScheduler
+
+#### Linear multistep scheduler for discrete beta schedules
+
+Original implementation can be found [here](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181).
+
+[[autodoc]] schedulers.scheduling_lms_discrete.LMSDiscreteScheduler
+
+#### Pseudo numerical methods for diffusion models (PNDM)
+
+Original paper can be found [here](https://arxiv.org/abs/2202.09778).
+
+[[autodoc]] schedulers.scheduling_pndm.PNDMScheduler
+
+#### Variance exploding stochastic differential equation (SDE) scheduler
+
+Original paper can be found [here](https://arxiv.org/abs/2011.13456).
+
+[[autodoc]] schedulers.scheduling_sde_ve.ScoreSdeVeScheduler
+
+#### Variance preserving stochastic differential equation (SDE) scheduler
+
+Original paper can be found [here](https://arxiv.org/abs/2011.13456).
+
+<Tip warning={true}>
+
+Score SDE-VP is under construction.
+
+</Tip>
+
+[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler

src/diffusers/schedulers/scheduling_ddim.py
@@ -30,11 +30,17 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

-    :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
-    from 0 to 1 and
-        produces the cumulative product of (1-beta) up to that part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
+
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
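# For intuition, the body this docstring documents derives each beta from the
# ratio of successive alpha_bar values, capped at max_beta, along these lines
# (a sketch; the loop itself is outside the hunk shown above):
#
#     betas = []
#     for i in range(num_diffusion_timesteps):
#         t1 = i / num_diffusion_timesteps
#         t2 = (i + 1) / num_diffusion_timesteps
#         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
#     return np.array(betas, dtype=np.float32)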
@@ -49,6 +55,29 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):


class DDIMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
+    diffusion probabilistic models (DDPMs) with non-Markovian guidance.
+
+    For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional): TODO
+        timestep_values (`np.ndarray`, optional): TODO
+        clip_sample (`bool`, default `True`):
+            option to clip predicted sample between -1 and 1 for numerical stability.
+        set_alpha_to_one (`bool`, default `True`):
+            if alpha for final step is 1 or the final alpha of the "non-previous" one.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+    """

    @register_to_config
    def __init__(
        self,
@@ -62,7 +91,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        set_alpha_to_one: bool = True,
        tensor_format: str = "pt",
    ):
-
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
@@ -101,6 +131,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        return variance

    def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            offset (`int`): TODO
+        """
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.arange(
            0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
@@ -118,7 +156,24 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        generator=None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            eta (`float`): weight of noise for added noise in diffusion step.
+            use_clipped_model_output (`bool`): TODO
+            generator: random number generator.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            `SchedulerOutput`: updated sample in the diffusion chain.
+
+        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"

src/diffusers/schedulers/scheduling_ddpm.py
@@ -29,11 +29,17 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

-    :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
-    from 0 to 1 and
-        produces the cumulative product of (1-beta) up to that part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
+
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
@@ -48,6 +54,29 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):


class DDPMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
+    Langevin dynamics sampling.
+
+    For more details, see the original paper: https://arxiv.org/abs/2006.11239
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional): TODO
+        variance_type (`str`):
+            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+        clip_sample (`bool`, default `True`):
+            option to clip predicted sample between -1 and 1 for numerical stability.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+    """

    @register_to_config
    def __init__(
        self,
@@ -88,6 +117,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        self.variance_type = variance_type

    def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+        """
        num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.arange(
@@ -137,7 +173,25 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        generator=None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            predict_epsilon (`bool`):
+                optional flag to use when model predicts the samples directly instead of the noise, epsilon.
+            generator: random number generator.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            `SchedulerOutput`: updated sample in the diffusion chain.
+
+        """
        t = timestep

        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
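# When the variance is learned ("learned"/"learned_range"), the model returns
# twice as many channels; the branch above splits them along these lines
# (a sketch, the remainder of the function is outside this hunk):
#
#     model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)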

src/diffusers/schedulers/scheduling_karras_ve.py
@@ -49,6 +49,24 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
    differential equations." https://arxiv.org/abs/2011.13456

+    For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+    Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+    optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
+
+    Args:
+        sigma_min (`float`): minimum noise magnitude
+        sigma_max (`float`): maximum noise magnitude
+        s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
+            A reasonable range is [1.000, 1.011].
+        s_churn (`float`): the parameter controlling the overall amount of stochasticity.
+            A reasonable range is [0, 100].
+        s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
+            A reasonable range is [0, 10].
+        s_max (`float`): the end value of the sigma range where we add noise.
+            A reasonable range is [0.2, 80].
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
    """

    @register_to_config
@@ -62,23 +80,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
        s_max: float = 50,
        tensor_format: str = "pt",
    ):
-        """
-        For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
-        Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
-        optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
-
-        Args:
-            sigma_min (`float`): minimum noise magnitude
-            sigma_max (`float`): maximum noise magnitude
-            s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
-                A reasonable range is [1.000, 1.011].
-            s_churn (`float`): the parameter controlling the overall amount of stochasticity.
-                A reasonable range is [0, 100].
-            s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
-                A reasonable range is [0, 10].
-            s_max (`float`): the end value of the sigma range where we add noise.
-                A reasonable range is [0.2, 80].
-        """
        # setable values
        self.num_inference_steps = None
        self.timesteps = None
@@ -88,6 +89,14 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+
+        """
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
        self.schedule = [
@@ -104,6 +113,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
        """
        Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
        higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+
+        TODO Args:
        """
        if self.s_min <= sigma <= self.s_max:
            gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
@@ -125,6 +136,21 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
        sample_hat: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[KarrasVeOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            sigma_hat (`float`): TODO
+            sigma_prev (`float`): TODO
+            sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
+
+        """

        pred_original_sample = sample_hat + sigma_hat * model_output
        derivative = (sample_hat - pred_original_sample) / sigma_hat
@@ -145,7 +171,22 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
        derivative: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[KarrasVeOutput, Tuple]:
        """
        Correct the predicted sample based on the output model_output of the network. TODO complete description
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            sigma_hat (`float`): TODO
+            sigma_prev (`float`): TODO
+            sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+            sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
+            derivative (`torch.FloatTensor` or `np.ndarray`): TODO
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
+
        """
        pred_original_sample = sample_prev + sigma_prev * model_output
        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
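# How the three methods cooperate in a stochastic sampling loop (a rough
# sketch; the exact schedule lookups live in the corresponding pipeline, and
# `model` is a placeholder):
#
#     for i, t in enumerate(scheduler.timesteps):
#         sigma = scheduler.schedule[t]
#         sigma_prev = scheduler.schedule[t + 1] if i < len(scheduler.timesteps) - 1 else 0
#         # 1. churn: raise the noise level from sigma to sigma_hat
#         sample_hat, sigma_hat = scheduler.add_noise_to_input(sample, sigma)
#         # 2. Euler step from sigma_hat down to sigma_prev
#         output = scheduler.step(model(sample_hat, sigma_hat), sigma_hat, sigma_prev, sample_hat)
#         # 3. optional second-order correction reusing the stored derivative
#         if sigma_prev != 0:
#             output = scheduler.step_correct(
#                 model(output.prev_sample, sigma_prev), sigma_hat, sigma_prev,
#                 sample_hat, output.prev_sample, output.derivative,
#             )
#         sample = output.prev_sample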

src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -24,6 +24,26 @@ from .scheduling_utils import SchedulerMixin, SchedulerOutput


class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
+    Katherine Crowson:
+    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear` or `scaled_linear`.
+        trained_betas (`np.ndarray`, optional): TODO
+        timestep_values (`np.ndarray`, optional): TODO
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+    """

    @register_to_config
    def __init__(
        self,
@@ -35,12 +55,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
        timestep_values: Optional[np.ndarray] = None,
        tensor_format: str = "pt",
    ):
-        """
-        Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
-        Katherine Crowson:
-        https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
-        """
-
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
@@ -64,7 +80,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):

    def get_lms_coefficient(self, order, t, current_order):
        """
-        Compute a linear multistep coefficient
+        Compute a linear multistep coefficient.
+
+        Args:
+            order (TODO):
+            t (TODO):
+            current_order (TODO):
        """

        def lms_derivative(tau):
@@ -80,6 +101,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
        return integrated_coeff

    def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+        """
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)

@@ -102,6 +130,22 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            order: coefficient for multi-step inference.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
+
+        """
        sigma = self.sigmas[timestep]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise

src/diffusers/schedulers/scheduling_pndm.py
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

import numpy as np
import torch
@@ -29,11 +29,17 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

-    :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
-    from 0 to 1 and
-        produces the cumulative product of (1-beta) up to that part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
+
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
@@ -48,6 +54,27 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):


class PNDMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
+    namely the Runge-Kutta method and a linear multi-step method.
+
+    For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional): TODO
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
+        skip_prk_steps (`bool`):
+            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+            before plms steps; defaults to `False`.
+
+    """

    @register_to_config
    def __init__(
        self,
@@ -55,10 +82,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
        tensor_format: str = "pt",
        skip_prk_steps: bool = False,
    ):
-
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
@@ -98,6 +127,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            offset (`int`): TODO
+        """
        self.num_inference_steps = num_inference_steps
        self._timesteps = list(
            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
@@ -135,7 +172,23 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            `SchedulerOutput`: updated sample in the diffusion chain.
+
+        """
        if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
            return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
        else:
@@ -151,6 +204,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        """
        Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
        solution to the differential equation.
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
+
        """
        if self.num_inference_steps is None:
            raise ValueError(
@@ -194,6 +258,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        """
        Step function propagating the sample with the linear multi-step method. This requires only one forward pass,
        reusing multiple previous model outputs to approximate the solution.
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
+
        """
        if self.num_inference_steps is None:
            raise ValueError(
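# Usage note (a sketch): `step(...)` dispatches internally, so the caller's
# loop looks like any other scheduler's:
#
#     scheduler = PNDMScheduler()  # or PNDMScheduler(skip_prk_steps=True)
#     scheduler.set_timesteps(50)
#     for t in scheduler.timesteps:
#         sample = scheduler.step(model(sample, t), t, sample).prev_sample
#
# The first calls run the Runge-Kutta warm-up unless skip_prk_steps=True;
# afterwards the cheaper linear multi-step rule is used.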

src/diffusers/schedulers/scheduling_sde_ve.py
@@ -47,12 +47,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    """
    The variance exploding stochastic differential equation (SDE) scheduler.

-    :param snr: coefficient weighting the step from the model_output sample (from the network) to the random noise.
-    :param sigma_min: initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
-    distribution of the data.
-    :param sigma_max: :param sampling_eps: the end value of sampling, where timesteps decrease progessively from 1 to
-    epsilon. :param correct_steps: number of correction steps performed on a produced sample. :param tensor_format:
-    "np" or "pt" for the expected format of samples passed to the Scheduler.
+    For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+    Args:
+        snr (`float`):
+            coefficient weighting the step from the model_output sample (from the network) to the random noise.
+        sigma_min (`float`):
+            initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+            distribution of the data.
+        sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
+        sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
+            epsilon.
+        correct_steps (`int`): number of correction steps performed on a produced sample.
+        tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
    """

    @register_to_config
@@ -66,11 +73,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        correct_steps=1,
        tensor_format="pt",
    ):
-        # self.sigmas = None
-        # self.discrete_sigmas = None
-        #
-        # # setable values
-        # self.num_inference_steps = None
+        # setable values
        self.timesteps = None

        self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
@@ -79,6 +82,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps, sampling_eps=None):
+        """
+        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
@@ -89,6 +101,20 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

    def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None):
+        """
+        Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
+
+        The sigmas control the weight of the `drift` and `diffusion` components of the sample update.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            sigma_min (`float`, optional):
+                initial noise scale value (overrides value given at Scheduler instantiation).
+            sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
+            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
@@ -140,7 +166,20 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        **kwargs,
    ) -> Union[SdeVeOutput, Tuple]:
        """
-        Predict the sample at the previous timestep by reversing the SDE.
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            generator: random number generator.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
+
        """
        if "seed" in kwargs and kwargs["seed"] is not None:
            self.set_seed(kwargs["seed"])
@@ -186,6 +225,17 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        """
        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
        after making the prediction for the previous timestep.
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            generator: random number generator.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
+
        """
        if "seed" in kwargs and kwargs["seed"] is not None:
            self.set_seed(kwargs["seed"])
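# A rough predictor-corrector loop over these two methods (a sketch; the real
# loop lives in the corresponding pipeline, `model` is a placeholder, and the
# predictor method is assumed to be named `step_pred` as in this version of
# the file):
#
#     scheduler.set_timesteps(num_inference_steps)
#     scheduler.set_sigmas(num_inference_steps)
#     for t in scheduler.timesteps:
#         # corrector: `correct_steps` Langevin-style refinements at this noise level
#         for _ in range(scheduler.correct_steps):
#             sample = scheduler.step_correct(model(sample, t), sample).prev_sample
#         # predictor: one reverse-SDE step to the next noise level
#         output = scheduler.step_pred(model(sample, t), t, sample)
#         sample = output.prev_sample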

src/diffusers/schedulers/scheduling_sde_vp.py
@@ -24,6 +24,15 @@ from .scheduling_utils import SchedulerMixin


class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
+    """
+    The variance preserving stochastic differential equation (SDE) scheduler.
+
+    For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+    UNDER CONSTRUCTION
+
+    """

    @register_to_config
    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):

src/diffusers/schedulers/scheduling_utils.py
@@ -38,6 +38,9 @@ class SchedulerOutput(BaseOutput):


class SchedulerMixin:
+    """
+    Mixin containing common functions for the schedulers.
+    """

    config_name = SCHEDULER_CONFIG_NAME
    ignore_for_config = ["tensor_format"]