1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add the UniPC scheduler (#2373)

* add UniPC scheduler

* add the return type to the functions

* code quality check

* add tests

* finish docs

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Wenliang Zhao
2023-02-17 02:19:06 +08:00
committed by GitHub
parent 2777264ee8
commit aaaec06487
12 changed files with 852 additions and 5 deletions

View File

@@ -148,7 +148,7 @@
- local: api/pipelines/stable_diffusion/upscale
title: Super-Resolution
- local: api/pipelines/stable_diffusion/latent_upscale
title: Stable-Diffusion-Latent-Upscaler
title: Stable-Diffusion-Latent-Upscaler
- local: api/pipelines/stable_diffusion/pix2pix
title: InstructPix2Pix
- local: api/pipelines/stable_diffusion/pix2pix_zero
@@ -204,6 +204,8 @@
title: Singlestep DPM-Solver
- local: api/schedulers/stochastic_karras_ve
title: Stochastic Kerras VE
- local: api/schedulers/unipc
title: UniPCMultistepScheduler
- local: api/schedulers/score_sde_ve
title: VE-SDE
- local: api/schedulers/score_sde_vp

View File

@@ -43,11 +43,11 @@ To this end, the design of schedulers is such that:
The following table summarizes all officially supported schedulers, their corresponding paper
| Scheduler | Paper |
|---|---|
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) |
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) |
| [deis](./deis) | [**DEISMultistepScheduler**](https://arxiv.org/abs/2204.13902) |
| [singlestep_dpm_solver](./singlestep_dpm_solver) | [**Singlestep DPM-Solver**](https://arxiv.org/abs/2206.00927) |
| [multistep_dpm_solver](./multistep_dpm_solver) | [**Multistep DPM-Solver**](https://arxiv.org/abs/2206.00927) |
| [heun](./heun) | [**Heun scheduler inspired by Karras et. al paper**](https://arxiv.org/abs/2206.00364) |
@@ -62,6 +62,7 @@ The following table summarizes all officially supported schedulers, their corres
| [euler](./euler) | [**Euler scheduler**](https://arxiv.org/abs/2206.00364) |
| [euler_ancestral](./euler_ancestral) | [**Euler Ancestral scheduler**](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72) |
| [vq_diffusion](./vq_diffusion) | [**VQDiffusionScheduler**](https://arxiv.org/abs/2111.14822) |
| [unipc](./unipc) | [**UniPCMultistepScheduler**](https://arxiv.org/abs/2302.04867) |
| [repaint](./repaint) | [**RePaint scheduler**](https://arxiv.org/abs/2201.09865) |
## API

View File

@@ -0,0 +1,24 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# UniPC
## Overview
UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders.
For more details about the method, please refer to the [[paper]](https://arxiv.org/abs/2302.04867) and the [[code]](https://github.com/wl-zhao/UniPC).
Fast Sampling of Diffusion Models with Exponential Integrator.
## UniPCMultistepScheduler
[[autodoc]] UniPCMultistepScheduler

View File

@@ -84,6 +84,7 @@ else:
SchedulerMixin,
ScoreSdeVeScheduler,
UnCLIPScheduler,
UniPCMultistepScheduler,
VQDiffusionScheduler,
)
from .training_utils import EMAModel

View File

@@ -39,6 +39,7 @@ else:
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_unclip import UnCLIPScheduler
from .scheduling_unipc_multistep import UniPCMultistepScheduler
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler

View File

@@ -159,7 +159,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
if solver_type not in ["logrho"]:
if solver_type in ["midpoint", "heun"]:
if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
solver_type = "logrho"
else:
raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}")

View File

@@ -169,7 +169,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
if solver_type not in ["midpoint", "heun"]:
if solver_type == "logrho":
if solver_type in ["logrho", "bh1", "bh2"]:
solver_type = "midpoint"
else:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")

View File

@@ -168,7 +168,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
if solver_type not in ["midpoint", "heun"]:
if solver_type == "logrho":
if solver_type in ["logrho", "bh1", "bh2"]:
solver_type = "midpoint"
else:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")

View File

@@ -0,0 +1,607 @@
# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a
corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders. UniPC is
by desinged model-agnostic, supporting pixel-space/latent-space DPMs on unconditional/conditional sampling. It can
also be applied to both noise prediction model and data prediction model. The corrector UniC can be also applied
after any off-the-shelf solvers to increase the order of accuracy.
For more details, see the original paper: https://arxiv.org/abs/2302.04867
Currently, we support the multistep UniPC for both noise prediction models and data prediction models. We recommend
to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
solver_order (`int`, default `2`):
the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of
accuracy is `solver_order + 1` due to the UniC. We recommend to use `solver_order=2` for guided sampling,
and `solver_order=3` for unconditional sampling.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
models (such as stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487).
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
predict_x0 (`bool`, default `True`):
whether to use the updating algrithm on the predicted x0. See https://arxiv.org/abs/2211.01095 for details
solver_type (`str`, default `bh2`):
the solver type of UniPC. We recommend use `bh1` for unconditional sampling when steps < 10, and use `bh2`
otherwise.
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
disable_corrector (`list`, default `[]`):
decide which step to disable the corrector. For large guidance scale, the misalignment between the
`epsilon_theta(x_t, c)`and `epsilon_theta(x_t^c, c)` might influence the convergence. This can be mitigated
by disable the corrector at the first few steps (e.g., disable_corrector=[0])
solver_p (`SchedulerMixin`, default `None`):
can be any other scheduler. If specified, the algorithm will become solver_p + UniC.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
predict_x0: bool = True,
solver_type: str = "bh1",
lower_order_final: bool = True,
disable_corrector: List[int] = [],
solver_p: SchedulerMixin = None,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Currently we only support VP-type noise schedule
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
if solver_type not in ["bh1", "bh2"]:
if solver_type in ["midpoint", "heun", "logrho"]:
solver_type = "bh1"
else:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
self.predict_x0 = predict_x0
# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order
self.timestep_list = [None] * solver_order
self.lower_order_nums = 0
self.disable_corrector = disable_corrector
self.solver_p = solver_p
self.last_sample = None
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
timesteps = (
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.model_outputs = [
None,
] * self.config.solver_order
self.lower_order_nums = 0
self.last_sample = None
if self.solver_p:
self.solver_p.set_timesteps(num_inference_steps, device=device)
def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
r"""
Convert the model output to the corresponding type that the algorithm PC needs.
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
Returns:
`torch.FloatTensor`: the converted model output.
"""
if self.predict_x0:
if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DPMSolverMultistepScheduler."
)
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float()
dynamic_max_val = torch.quantile(
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.type(orig_dtype)
return x0_pred
else:
if self.config.prediction_type == "epsilon":
return model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DPMSolverMultistepScheduler."
)
def multistep_uni_p_bh_update(
self,
model_output: torch.FloatTensor,
prev_timestep: int,
sample: torch.FloatTensor,
order: int,
) -> torch.FloatTensor:
"""
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
Args:
model_output (`torch.FloatTensor`):
direct outputs from learned diffusion model at the current timestep.
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
order (`int`): the order of UniP at this step, also the p in UniPC-p.
Returns:
`torch.FloatTensor`: the sample tensor at the previous timestep.
"""
timestep_list = self.timestep_list
model_output_list = self.model_outputs
s0, t = self.timestep_list[-1], prev_timestep
m0 = model_output_list[-1]
x = sample
if self.solver_p:
x_t = self.solver_p.step(model_output, s0, x).prev_sample
return x_t
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h = lambda_t - lambda_s0
device = sample.device
rks = []
D1s = []
for i in range(1, order):
si = timestep_list[-(i + 1)]
mi = model_output_list[-(i + 1)]
lambda_si = self.lambda_t[si]
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = torch.tensor(rks, device=device)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K)
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
else:
D1s = None
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - alpha_t * B_h * pred_res
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - sigma_t * B_h * pred_res
x_t = x_t.to(x.dtype)
return x_t
def multistep_uni_c_bh_update(
self,
this_model_output: torch.FloatTensor,
this_timestep: int,
last_sample: torch.FloatTensor,
this_sample: torch.FloatTensor,
order: int,
) -> torch.FloatTensor:
"""
One step for the UniC (B(h) version).
Args:
this_model_output (`torch.FloatTensor`): the model outputs at `x_t`
this_timestep (`int`): the current timestep `t`
last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}`
this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}`
order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy
should be order + 1
Returns:
`torch.FloatTensor`: the corrected sample tensor at the current timestep.
"""
timestep_list = self.timestep_list
model_output_list = self.model_outputs
s0, t = timestep_list[-1], this_timestep
m0 = model_output_list[-1]
x = last_sample
x_t = this_sample
model_t = this_model_output
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h = lambda_t - lambda_s0
device = this_sample.device
rks = []
D1s = []
for i in range(1, order):
si = timestep_list[-(i + 1)]
mi = model_output_list[-(i + 1)]
lambda_si = self.lambda_t[si]
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = torch.tensor(rks, device=device)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1)
else:
D1s = None
# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_c = torch.linalg.solve(R, b)
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
x_t = x_t.to(x.dtype)
return x_t
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Step function propagating the sample with the multistep UniPC.
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
use_corrector = (
step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None
)
model_output_convert = self.convert_model_output(model_output, timestep, sample)
if use_corrector:
sample = self.multistep_uni_c_bh_update(
this_model_output=model_output_convert,
this_timestep=timestep,
last_sample=self.last_sample,
this_sample=sample,
order=self.this_order,
)
# now prepare to run the predictor
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.timestep_list[i] = self.timestep_list[i + 1]
self.model_outputs[-1] = model_output_convert
self.timestep_list[-1] = timestep
if self.config.lower_order_final:
this_order = min(self.config.solver_order, len(self.timesteps) - step_index)
else:
this_order = self.config.solver_order
self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
assert self.this_order > 0
self.last_sample = sample
prev_sample = self.multistep_uni_p_bh_update(
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
prev_timestep=prev_timestep,
sample=sample,
order=self.this_order,
)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`): input sample
Returns:
`torch.FloatTensor`: scaled input sample
"""
return sample
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -42,6 +42,7 @@ class KarrasDiffusionSchedulers(Enum):
KDPM2DiscreteScheduler = 10
KDPM2AncestralDiscreteScheduler = 11
DEISMultistepScheduler = 12
UniPCMultistepScheduler = 13
@dataclass

View File

@@ -600,6 +600,21 @@ class UnCLIPScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class UniPCMultistepScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class VQDiffusionScheduler(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -40,6 +40,7 @@ from diffusers import (
PNDMScheduler,
ScoreSdeVeScheduler,
UnCLIPScheduler,
UniPCMultistepScheduler,
VQDiffusionScheduler,
logging,
)
@@ -2636,6 +2637,200 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
assert sample.dtype == torch.float16
class UniPCMultistepSchedulerTest(SchedulerCommonTest):
scheduler_classes = (UniPCMultistepScheduler,)
forward_default_kwargs = (("num_inference_steps", 25),)
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"solver_order": 2,
}
config.update(**kwargs)
return config
def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
output, new_output = sample, sample
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
output = scheduler.step(residual, t, output, **kwargs).prev_sample
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals (must be after setting timesteps)
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residual (must be after setting timesteps)
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def full_loop(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
return sample
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
# copy over dummy past residuals (must be done after set_timesteps)
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
time_step_0 = scheduler.timesteps[5]
time_step_1 = scheduler.timesteps[6]
output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
def test_timesteps(self):
for timesteps in [25, 50, 100, 999, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_thresholding(self):
self.check_over_configs(thresholding=False)
for order in [1, 2, 3]:
for solver_type in ["bh1", "bh2"]:
for threshold in [0.5, 1.0, 2.0]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
thresholding=True,
prediction_type=prediction_type,
sample_max_value=threshold,
solver_order=order,
solver_type=solver_type,
)
def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
def test_solver_order_and_type(self):
for solver_type in ["bh1", "bh2"]:
for order in [1, 2, 3]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
)
sample = self.full_loop(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
)
assert not torch.isnan(sample).any(), "Samples have nan numbers"
def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
self.check_over_configs(lower_order_final=False)
def test_inference_steps(self):
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
def test_full_loop_no_noise(self):
sample = self.full_loop()
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.2521) < 1e-3
def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction")
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.1096) < 1e-3
def test_fp16_support(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter.half()
scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
assert sample.dtype == torch.float16
class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
scheduler_classes = (KDPM2AncestralDiscreteScheduler,)
num_inference_steps = 10