From aaaec06487ffc3cd5d1c1d4e1a64af80dbaf477b Mon Sep 17 00:00:00 2001 From: Wenliang Zhao Date: Fri, 17 Feb 2023 02:19:06 +0800 Subject: [PATCH] add the UniPC scheduler (#2373) * add UniPC scheduler * add the return type to the functions * code quality check * add tests * finish docs --------- Co-authored-by: Patrick von Platen --- docs/source/en/_toctree.yml | 4 +- docs/source/en/api/schedulers/overview.mdx | 3 +- docs/source/en/api/schedulers/unipc.mdx | 24 + src/diffusers/__init__.py | 1 + src/diffusers/schedulers/__init__.py | 1 + .../schedulers/scheduling_deis_multistep.py | 2 +- .../scheduling_dpmsolver_multistep.py | 2 +- .../scheduling_dpmsolver_singlestep.py | 2 +- .../schedulers/scheduling_unipc_multistep.py | 607 ++++++++++++++++++ src/diffusers/schedulers/scheduling_utils.py | 1 + src/diffusers/utils/dummy_pt_objects.py | 15 + tests/test_scheduler.py | 195 ++++++ 12 files changed, 852 insertions(+), 5 deletions(-) create mode 100644 docs/source/en/api/schedulers/unipc.mdx create mode 100644 src/diffusers/schedulers/scheduling_unipc_multistep.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index afc6add42c..531777290c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -148,7 +148,7 @@ - local: api/pipelines/stable_diffusion/upscale title: Super-Resolution - local: api/pipelines/stable_diffusion/latent_upscale - title: Stable-Diffusion-Latent-Upscaler + title: Stable-Diffusion-Latent-Upscaler - local: api/pipelines/stable_diffusion/pix2pix title: InstructPix2Pix - local: api/pipelines/stable_diffusion/pix2pix_zero @@ -204,6 +204,8 @@ title: Singlestep DPM-Solver - local: api/schedulers/stochastic_karras_ve title: Stochastic Kerras VE + - local: api/schedulers/unipc + title: UniPCMultistepScheduler - local: api/schedulers/score_sde_ve title: VE-SDE - local: api/schedulers/score_sde_vp diff --git a/docs/source/en/api/schedulers/overview.mdx b/docs/source/en/api/schedulers/overview.mdx index 35c26a7219..cfb319b342 100644 --- a/docs/source/en/api/schedulers/overview.mdx +++ b/docs/source/en/api/schedulers/overview.mdx @@ -43,11 +43,11 @@ To this end, the design of schedulers is such that: The following table summarizes all officially supported schedulers, their corresponding paper - | Scheduler | Paper | |---|---| | [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | | [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | +| [deis](./deis) | [**DEISMultistepScheduler**](https://arxiv.org/abs/2204.13902) | | [singlestep_dpm_solver](./singlestep_dpm_solver) | [**Singlestep DPM-Solver**](https://arxiv.org/abs/2206.00927) | | [multistep_dpm_solver](./multistep_dpm_solver) | [**Multistep DPM-Solver**](https://arxiv.org/abs/2206.00927) | | [heun](./heun) | [**Heun scheduler inspired by Karras et. al paper**](https://arxiv.org/abs/2206.00364) | @@ -62,6 +62,7 @@ The following table summarizes all officially supported schedulers, their corres | [euler](./euler) | [**Euler scheduler**](https://arxiv.org/abs/2206.00364) | | [euler_ancestral](./euler_ancestral) | [**Euler Ancestral scheduler**](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72) | | [vq_diffusion](./vq_diffusion) | [**VQDiffusionScheduler**](https://arxiv.org/abs/2111.14822) | +| [unipc](./unipc) | [**UniPCMultistepScheduler**](https://arxiv.org/abs/2302.04867) | | [repaint](./repaint) | [**RePaint scheduler**](https://arxiv.org/abs/2201.09865) | ## API diff --git a/docs/source/en/api/schedulers/unipc.mdx b/docs/source/en/api/schedulers/unipc.mdx new file mode 100644 index 0000000000..fec9a7c2d3 --- /dev/null +++ b/docs/source/en/api/schedulers/unipc.mdx @@ -0,0 +1,24 @@ + + +# UniPC + +## Overview + +UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders. + +For more details about the method, please refer to the [[paper]](https://arxiv.org/abs/2302.04867) and the [[code]](https://github.com/wl-zhao/UniPC). + +Fast Sampling of Diffusion Models with Exponential Integrator. + +## UniPCMultistepScheduler +[[autodoc]] UniPCMultistepScheduler diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 78bace948d..6be22f06da 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -84,6 +84,7 @@ else: SchedulerMixin, ScoreSdeVeScheduler, UnCLIPScheduler, + UniPCMultistepScheduler, VQDiffusionScheduler, ) from .training_utils import EMAModel diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 3746acd5b5..95af68be60 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -39,6 +39,7 @@ else: from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_unclip import UnCLIPScheduler + from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index b331f18659..b1382c59f5 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -159,7 +159,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") if solver_type not in ["logrho"]: - if solver_type in ["midpoint", "heun"]: + if solver_type in ["midpoint", "heun", "bh1", "bh2"]: solver_type = "logrho" else: raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}") diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 8b36fade6f..a256ec23f1 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -169,7 +169,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): else: raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") if solver_type not in ["midpoint", "heun"]: - if solver_type == "logrho": + if solver_type in ["logrho", "bh1", "bh2"]: solver_type = "midpoint" else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 7b40e1d48d..9ab539647b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -168,7 +168,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): else: raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") if solver_type not in ["midpoint", "heun"]: - if solver_type == "logrho": + if solver_type in ["logrho", "bh1", "bh2"]: solver_type = "midpoint" else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py new file mode 100644 index 0000000000..1dff93b19d --- /dev/null +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -0,0 +1,607 @@ +# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a + corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders. UniPC is + by desinged model-agnostic, supporting pixel-space/latent-space DPMs on unconditional/conditional sampling. It can + also be applied to both noise prediction model and data prediction model. The corrector UniC can be also applied + after any off-the-shelf solvers to increase the order of accuracy. + + For more details, see the original paper: https://arxiv.org/abs/2302.04867 + + Currently, we support the multistep UniPC for both noise prediction models and data prediction models. We recommend + to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of + accuracy is `solver_order + 1` due to the UniC. We recommend to use `solver_order=2` for guided sampling, + and `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, default `True`): + whether to use the updating algrithm on the predicted x0. See https://arxiv.org/abs/2211.01095 for details + solver_type (`str`, default `bh2`): + the solver type of UniPC. We recommend use `bh1` for unconditional sampling when steps < 10, and use `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + decide which step to disable the corrector. For large guidance scale, the misalignment between the + `epsilon_theta(x_t, c)`and `epsilon_theta(x_t^c, c)` might influence the convergence. This can be mitigated + by disable the corrector at the first few steps (e.g., disable_corrector=[0]) + solver_p (`SchedulerMixin`, default `None`): + can be any other scheduler. If specified, the algorithm will become solver_p + UniC. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh1", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + solver_type = "bh1" + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + timesteps = ( + np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(num_inference_steps, device=device) + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + r""" + Convert the model output to the corresponding type that the algorithm PC needs. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the converted model output. + """ + if self.predict_x0: + if self.config.prediction_type == "epsilon": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + # Dynamic thresholding in https://arxiv.org/abs/2205.11487 + orig_dtype = x0_pred.dtype + if orig_dtype not in [torch.float, torch.double]: + x0_pred = x0_pred.float() + dynamic_max_val = torch.quantile( + torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1 + ) + dynamic_max_val = torch.maximum( + dynamic_max_val, + self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device), + )[(...,) + (None,) * (x0_pred.ndim - 1)] + x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val + x0_pred = x0_pred.type(orig_dtype) + return x0_pred + else: + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + def multistep_uni_p_bh_update( + self, + model_output: torch.FloatTensor, + prev_timestep: int, + sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.FloatTensor`): + direct outputs from learned diffusion model at the current timestep. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + order (`int`): the order of UniP at this step, also the p in UniPC-p. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + timestep_list = self.timestep_list + model_output_list = self.model_outputs + + s0, t = self.timestep_list[-1], prev_timestep + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = timestep_list[-(i + 1)] + mi = model_output_list[-(i + 1)] + lambda_si = self.lambda_t[si] + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.FloatTensor, + this_timestep: int, + last_sample: torch.FloatTensor, + this_sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.FloatTensor`): the model outputs at `x_t` + this_timestep (`int`): the current timestep `t` + last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}` + this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}` + order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy + should be order + 1 + + Returns: + `torch.FloatTensor`: the corrected sample tensor at the current timestep. + """ + timestep_list = self.timestep_list + model_output_list = self.model_outputs + + s0, t = timestep_list[-1], this_timestep + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = timestep_list[-(i + 1)] + mi = model_output_list[-(i + 1)] + lambda_si = self.lambda_t[si] + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the multistep UniPC. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + + use_corrector = ( + step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, timestep, sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + this_timestep=timestep, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + # now prepare to run the predictor + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, len(self.timesteps) - step_index) + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + prev_timestep=prev_timestep, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 424447efbd..ba1842ec00 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -42,6 +42,7 @@ class KarrasDiffusionSchedulers(Enum): KDPM2DiscreteScheduler = 10 KDPM2AncestralDiscreteScheduler = 11 DEISMultistepScheduler = 12 + UniPCMultistepScheduler = 13 @dataclass diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 546992bc43..ea0cb77b5b 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -600,6 +600,21 @@ class UnCLIPScheduler(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class UniPCMultistepScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class VQDiffusionScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index f38b6b6b34..22a7d9e16c 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -40,6 +40,7 @@ from diffusers import ( PNDMScheduler, ScoreSdeVeScheduler, UnCLIPScheduler, + UniPCMultistepScheduler, VQDiffusionScheduler, logging, ) @@ -2636,6 +2637,200 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest): assert sample.dtype == torch.float16 +class UniPCMultistepSchedulerTest(SchedulerCommonTest): + scheduler_classes = (UniPCMultistepScheduler,) + forward_default_kwargs = (("num_inference_steps", 25),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "solver_order": 2, + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residuals + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + new_scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residuals + new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] + + output, new_output = sample, sample + for t in range(time_step, time_step + scheduler.config.solver_order + 1): + output = scheduler.step(residual, t, output, **kwargs).prev_sample + new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residuals (must be after setting timesteps) + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + # copy over dummy past residuals + new_scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residual (must be after setting timesteps) + new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] + + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + return sample + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] + + time_step_0 = scheduler.timesteps[5] + time_step_1 = scheduler.timesteps[6] + + output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [25, 50, 100, 999, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_thresholding(self): + self.check_over_configs(thresholding=False) + for order in [1, 2, 3]: + for solver_type in ["bh1", "bh2"]: + for threshold in [0.5, 1.0, 2.0]: + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs( + thresholding=True, + prediction_type=prediction_type, + sample_max_value=threshold, + solver_order=order, + solver_type=solver_type, + ) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_solver_order_and_type(self): + for solver_type in ["bh1", "bh2"]: + for order in [1, 2, 3]: + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + ) + sample = self.full_loop( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + ) + assert not torch.isnan(sample).any(), "Samples have nan numbers" + + def test_lower_order_final(self): + self.check_over_configs(lower_order_final=True) + self.check_over_configs(lower_order_final=False) + + def test_inference_steps(self): + for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: + self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0) + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2521) < 1e-3 + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.1096) < 1e-3 + + def test_fp16_support(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter.half() + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + assert sample.dtype == torch.float16 + + class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest): scheduler_classes = (KDPM2AncestralDiscreteScheduler,) num_inference_steps = 10