feat : add log-rho deis multistep scheduler (#1432)
* feat: add log-rho DEIS multistep scheduler
* docs: fix typo
* docs: add docs for impl algo
* docs: remove duplicate ref
* finish deis
* add docs
* fix

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
docs/source/en/_toc.yml
@@ -155,6 +155,8 @@
     title: "DDIM"
   - local: api/schedulers/ddpm
     title: "DDPM"
+  - local: api/schedulers/deis
+    title: "DEIS"
   - local: api/schedulers/singlestep_dpm_solver
     title: "Singlestep DPM-Solver"
   - local: api/schedulers/multistep_dpm_solver
docs/source/en/api/schedulers/deis.mdx  (new file, 22 lines)
@@ -0,0 +1,22 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DEIS

Fast Sampling of Diffusion Models with Exponential Integrator.

## Overview

Original paper can be found [here](https://arxiv.org/abs/2204.13902). The original implementation can be found [here](https://github.com/qsh-zh/deis).

## DEISMultistepScheduler

[[autodoc]] DEISMultistepScheduler
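
A minimal usage sketch, not part of this commit, showing how the new scheduler can be swapped into an existing pipeline; the checkpoint id and prompt below are illustrative assumptions:

```python
# Illustrative sketch only (not part of this commit). The checkpoint id and prompt
# are placeholders; any Stable Diffusion-style pipeline works the same way.
import torch
from diffusers import DiffusionPipeline, DEISMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
# Rebuild the scheduler from the pipeline's existing scheduler config.
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# DEIS is a fast multistep solver, so comparatively few inference steps are usually enough.
image = pipe("an astronaut riding a horse on mars", num_inference_steps=20).images[0]
```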
src/diffusers/__init__.py
@@ -67,6 +67,7 @@ else:
     from .schedulers import (
         DDIMScheduler,
         DDPMScheduler,
+        DEISMultistepScheduler,
         DPMSolverMultistepScheduler,
         DPMSolverSinglestepScheduler,
         EulerAncestralDiscreteScheduler,
src/diffusers/schedulers/__init__.py
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
 else:
     from .scheduling_ddim import DDIMScheduler
     from .scheduling_ddpm import DDPMScheduler
+    from .scheduling_deis_multistep import DEISMultistepScheduler
     from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
     from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
     from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
src/diffusers/schedulers/scheduling_deis_multistep.py  (new file, 481 lines)
@@ -0,0 +1,481 @@
# Copyright 2022 FLAIR Lab and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: check https://arxiv.org/abs/2204.13902 and https://github.com/qsh-zh/deis for more info
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput

def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
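
# Note: `alpha_bar` above is the squared-cosine schedule, i.e. the "squaredcos_cap_v2"
# ("Glide cosine") option accepted by `beta_schedule` in the scheduler below.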

class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    DEIS (https://arxiv.org/abs/2204.13902) is a fast high-order solver for diffusion ODEs. We slightly modify the
    polynomial fitting formula to work in log-rho space instead of the original linear t space of the DEIS paper. The
    modification enjoys closed-form coefficients for the exponential multistep update instead of relying on a
    numerical solver. More variants of DEIS can be found in https://github.com/qsh-zh/deis.

    Currently, we support the log-rho multistep DEIS. We recommend using `solver_order=2` or `solver_order=3`;
    `solver_order=1` reduces to DDIM.

    We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
    diffusion models, you can set `thresholding=True` to use the dynamic thresholding.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        solver_order (`int`, default `2`):
            the order of DEIS; can be `1`, `2`, or `3`. We recommend `solver_order=2` for guided sampling and
            `solver_order=3` for unconditional sampling.
        prediction_type (`str`, default `epsilon`):
            indicates whether the model predicts the noise (epsilon) or the data / `x0`. One of `epsilon`, `sample`,
            or `v_prediction`.
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
            Note that the thresholding method is unsuitable for latent-space diffusion models (such as
            stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://arxiv.org/abs/2205.11487).
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Only takes effect when `thresholding=True`.
        algorithm_type (`str`, default `deis`):
            the algorithm type for the solver. Currently we only support multistep DEIS; other variants of DEIS will
            be added in the future.
        lower_order_final (`bool`, default `True`):
            whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
            find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10.

    """
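
    # Implementation note: with rho = sigma_t / alpha_t, the log-rho DEIS update fits a Lagrange
    # polynomial in log(rho) to the stored noise predictions and integrates it over rho in closed
    # form (see the `ind_fn` helpers in the second- and third-order updates below). With
    # `solver_order=1` the update reduces to a DDIM step.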

    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "deis",
        solver_type: str = "logrho",
        lower_order_final: bool = True,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # settings for DEIS
        if algorithm_type not in ["deis"]:
            if algorithm_type in ["dpmsolver", "dpmsolver++"]:
                algorithm_type = "deis"
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

        if solver_type not in ["logrho"]:
            if solver_type in ["midpoint", "heun"]:
                solver_type = "logrho"
            else:
                raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}")

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
        timesteps = (
            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
            .round()[::-1][:-1]
            .copy()
            .astype(np.int64)
        )
        self.timesteps = torch.from_numpy(timesteps).to(device)
        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0
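
    # Example (illustrative): with num_train_timesteps=1000 and num_inference_steps=10,
    # `set_timesteps` yields timesteps of roughly [999, 899, 799, ..., 200, 100]; the trailing 0
    # produced by the linspace is dropped by the [:-1] slice.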

    def convert_model_output(
        self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Convert the model output to the corresponding type that the DEIS algorithm needs.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.

        Returns:
            `torch.FloatTensor`: the converted model output.
        """
        if self.config.prediction_type == "epsilon":
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            x0_pred = alpha_t * sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DEISMultistepScheduler."
            )

        if self.config.thresholding:
            # Dynamic thresholding in https://arxiv.org/abs/2205.11487
            orig_dtype = x0_pred.dtype
            if orig_dtype not in [torch.float, torch.double]:
                x0_pred = x0_pred.float()
            dynamic_max_val = torch.quantile(
                torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
            )
            dynamic_max_val = torch.maximum(
                dynamic_max_val,
                self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
            )[(...,) + (None,) * (x0_pred.ndim - 1)]
            x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
            x0_pred = x0_pred.type(orig_dtype)

        if self.config.algorithm_type == "deis":
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            return (sample - alpha_t * x0_pred) / sigma_t
        else:
            raise NotImplementedError("only support log-rho multistep deis now")
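
    # The value returned by `convert_model_output` is the implied noise prediction
    # eps = (sample - alpha_t * x0_pred) / sigma_t; the update rules below integrate this
    # quantity. The first-order rule is equivalent to a DDIM step.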

    def deis_first_order_update(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        prev_timestep: int,
        sample: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        One step for the first-order DEIS (equivalent to DDIM).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            prev_timestep (`int`): previous discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.

        Returns:
            `torch.FloatTensor`: the sample tensor at the previous timestep.
        """
        lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
        alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
        sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep]
        h = lambda_t - lambda_s
        if self.config.algorithm_type == "deis":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
        else:
            raise NotImplementedError("only support log-rho multistep deis now")
        return x_t

    def multistep_deis_second_order_update(
        self,
        model_output_list: List[torch.FloatTensor],
        timestep_list: List[int],
        prev_timestep: int,
        sample: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        One step for the second-order multistep DEIS.

        Args:
            model_output_list (`List[torch.FloatTensor]`):
                direct outputs from learned diffusion model at current and latter timesteps.
            timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
            prev_timestep (`int`): previous discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.

        Returns:
            `torch.FloatTensor`: the sample tensor at the previous timestep.
        """
        t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
        m0, m1 = model_output_list[-1], model_output_list[-2]
        alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1]
        sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1]

        rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1

        if self.config.algorithm_type == "deis":

            def ind_fn(t, b, c):
                # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}]
                return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c))

            coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1)
            coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0)

            x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1)
            return x_t
        else:
            raise NotImplementedError("only support log-rho multistep deis now")

    def multistep_deis_third_order_update(
        self,
        model_output_list: List[torch.FloatTensor],
        timestep_list: List[int],
        prev_timestep: int,
        sample: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        One step for the third-order multistep DEIS.

        Args:
            model_output_list (`List[torch.FloatTensor]`):
                direct outputs from learned diffusion model at current and latter timesteps.
            timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
            prev_timestep (`int`): previous discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.

        Returns:
            `torch.FloatTensor`: the sample tensor at the previous timestep.
        """
        t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
        alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2]
        sigma_t, sigma_s0, sigma_s1, sigma_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2]
        rho_t, rho_s0, rho_s1, rho_s2 = (
            sigma_t / alpha_t,
            sigma_s0 / alpha_s0,
            sigma_s1 / alpha_s1,
            sigma_s2 / alpha_s2,
        )

        if self.config.algorithm_type == "deis":

            def ind_fn(t, b, c, d):
                # Integrate[(log(t) - log(c))(log(t) - log(d)) / (log(b) - log(c))(log(b) - log(d)), {t}]
                numerator = t * (
                    np.log(c) * (np.log(d) - np.log(t) + 1)
                    - np.log(d) * np.log(t)
                    + np.log(d)
                    + np.log(t) ** 2
                    - 2 * np.log(t)
                    + 2
                )
                denominator = (np.log(b) - np.log(c)) * (np.log(b) - np.log(d))
                return numerator / denominator

            coef1 = ind_fn(rho_t, rho_s0, rho_s1, rho_s2) - ind_fn(rho_s0, rho_s0, rho_s1, rho_s2)
            coef2 = ind_fn(rho_t, rho_s1, rho_s2, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s2, rho_s0)
            coef3 = ind_fn(rho_t, rho_s2, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s2, rho_s0, rho_s1)

            x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2)

            return x_t
        else:
            raise NotImplementedError("only support log-rho multistep deis now")

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the multistep DEIS.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_index = (self.timesteps == timestep).nonzero()
        if len(step_index) == 0:
            step_index = len(self.timesteps) - 1
        else:
            step_index = step_index.item()
        prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
        lower_order_final = (
            (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
        )
        lower_order_second = (
            (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
        )

        model_output = self.convert_model_output(model_output, timestep, sample)
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample)
        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
            timestep_list = [self.timesteps[step_index - 1], timestep]
            prev_sample = self.multistep_deis_second_order_update(
                self.model_outputs, timestep_list, prev_timestep, sample
            )
        else:
            timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
            prev_sample = self.multistep_deis_third_order_update(
                self.model_outputs, timestep_list, prev_timestep, sample
            )

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
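
For reference, a minimal standalone sketch (not part of the commit) of driving the new scheduler directly, with a stand-in "model" that simply returns random noise of the right shape:

```python
# Illustrative only: exercise DEISMultistepScheduler without a real denoising model.
import torch
from diffusers import DEISMultistepScheduler

scheduler = DEISMultistepScheduler(num_train_timesteps=1000, solver_order=2)
scheduler.set_timesteps(num_inference_steps=20)

sample = torch.randn(1, 3, 64, 64)  # start from pure Gaussian noise
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # placeholder for a real epsilon prediction
    sample = scheduler.step(model_output, t, sample).prev_sample
```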
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -174,9 +174,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
 
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
-            raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
+            if algorithm_type == "deis":
+                algorithm_type = "dpmsolver++"
+            else:
+                raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
         if solver_type not in ["midpoint", "heun"]:
-            raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
+            if solver_type == "logrho":
+                solver_type = "midpoint"
+            else:
+                raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
 
         # setable values
         self.num_inference_steps = None
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -163,9 +163,15 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
 
         # settings for DPM-Solver
        if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
-            raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
+            if algorithm_type == "deis":
+                algorithm_type = "dpmsolver++"
+            else:
+                raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
         if solver_type not in ["midpoint", "heun"]:
-            raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
+            if solver_type == "logrho":
+                solver_type = "midpoint"
+            else:
+                raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
 
         # setable values
         self.num_inference_steps = None
@@ -41,4 +41,7 @@ _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
     "EulerAncestralDiscreteScheduler",
     "DPMSolverMultistepScheduler",
     "DPMSolverSinglestepScheduler",
+    "KDPM2DiscreteScheduler",
+    "KDPM2AncestralDiscreteScheduler",
+    "DEISMultistepScheduler",
 ]
src/diffusers/utils/dummy_pt_objects.py
@@ -362,6 +362,21 @@ class DDPMScheduler(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
 
 
+class DEISMultistepScheduler(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class DPMSolverMultistepScheduler(metaclass=DummyObject):
     _backends = ["torch"]
@@ -27,6 +27,7 @@ import diffusers
 from diffusers import (
     DDIMScheduler,
     DDPMScheduler,
+    DEISMultistepScheduler,
     DPMSolverMultistepScheduler,
     DPMSolverSinglestepScheduler,
     EulerAncestralDiscreteScheduler,
@@ -2505,6 +2506,207 @@ class KDPM2DiscreteSchedulerTest(SchedulerCommonTest):
        assert abs(result_mean.item() - 0.0266) < 1e-3


class DEISMultistepSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (DEISMultistepScheduler,)
    forward_default_kwargs = (("num_inference_steps", 25),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "solver_order": 2,
        }

        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)
                # copy over dummy past residuals
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output, new_output = sample, sample
            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
                output = scheduler.step(residual, t, output, **kwargs).prev_sample
                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

                assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        pass

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)

            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                # copy over dummy past residuals
                new_scheduler.set_timesteps(num_inference_steps)

                # copy over dummy past residual (must be after setting timesteps)
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        return sample

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            time_step_0 = scheduler.timesteps[5]
            time_step_1 = scheduler.timesteps[6]

            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [25, 50, 100, 999, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_thresholding(self):
        self.check_over_configs(thresholding=False)
        for order in [1, 2, 3]:
            for solver_type in ["logrho"]:
                for threshold in [0.5, 1.0, 2.0]:
                    for prediction_type in ["epsilon", "sample"]:
                        self.check_over_configs(
                            thresholding=True,
                            prediction_type=prediction_type,
                            sample_max_value=threshold,
                            algorithm_type="deis",
                            solver_order=order,
                            solver_type=solver_type,
                        )

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_solver_order_and_type(self):
        for algorithm_type in ["deis"]:
            for solver_type in ["logrho"]:
                for order in [1, 2, 3]:
                    for prediction_type in ["epsilon", "sample"]:
                        self.check_over_configs(
                            solver_order=order,
                            solver_type=solver_type,
                            prediction_type=prediction_type,
                            algorithm_type=algorithm_type,
                        )
                        sample = self.full_loop(
                            solver_order=order,
                            solver_type=solver_type,
                            prediction_type=prediction_type,
                            algorithm_type=algorithm_type,
                        )
                        assert not torch.isnan(sample).any(), "Samples have nan numbers"

    def test_lower_order_final(self):
        self.check_over_configs(lower_order_final=True)
        self.check_over_configs(lower_order_final=False)

    def test_inference_steps(self):
        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.23916) < 1e-3

    def test_full_loop_with_v_prediction(self):
        sample = self.full_loop(prediction_type="v_prediction")
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.091) < 1e-3

    def test_fp16_support(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter.half()
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        assert sample.dtype == torch.float16


class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (KDPM2AncestralDiscreteScheduler,)
    num_inference_steps = 10