mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add the UniPC scheduler (#2373)
* add UniPC scheduler * add the return type to the functions * code quality check * add tests * finish docs --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -148,7 +148,7 @@
|
||||
- local: api/pipelines/stable_diffusion/upscale
|
||||
title: Super-Resolution
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Stable-Diffusion-Latent-Upscaler
|
||||
title: Stable-Diffusion-Latent-Upscaler
|
||||
- local: api/pipelines/stable_diffusion/pix2pix
|
||||
title: InstructPix2Pix
|
||||
- local: api/pipelines/stable_diffusion/pix2pix_zero
|
||||
@@ -204,6 +204,8 @@
|
||||
title: Singlestep DPM-Solver
|
||||
- local: api/schedulers/stochastic_karras_ve
|
||||
title: Stochastic Kerras VE
|
||||
- local: api/schedulers/unipc
|
||||
title: UniPCMultistepScheduler
|
||||
- local: api/schedulers/score_sde_ve
|
||||
title: VE-SDE
|
||||
- local: api/schedulers/score_sde_vp
|
||||
|
||||
@@ -43,11 +43,11 @@ To this end, the design of schedulers is such that:
|
||||
|
||||
The following table summarizes all officially supported schedulers, their corresponding paper
|
||||
|
||||
|
||||
| Scheduler | Paper |
|
||||
|---|---|
|
||||
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) |
|
||||
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) |
|
||||
| [deis](./deis) | [**DEISMultistepScheduler**](https://arxiv.org/abs/2204.13902) |
|
||||
| [singlestep_dpm_solver](./singlestep_dpm_solver) | [**Singlestep DPM-Solver**](https://arxiv.org/abs/2206.00927) |
|
||||
| [multistep_dpm_solver](./multistep_dpm_solver) | [**Multistep DPM-Solver**](https://arxiv.org/abs/2206.00927) |
|
||||
| [heun](./heun) | [**Heun scheduler inspired by Karras et. al paper**](https://arxiv.org/abs/2206.00364) |
|
||||
@@ -62,6 +62,7 @@ The following table summarizes all officially supported schedulers, their corres
|
||||
| [euler](./euler) | [**Euler scheduler**](https://arxiv.org/abs/2206.00364) |
|
||||
| [euler_ancestral](./euler_ancestral) | [**Euler Ancestral scheduler**](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72) |
|
||||
| [vq_diffusion](./vq_diffusion) | [**VQDiffusionScheduler**](https://arxiv.org/abs/2111.14822) |
|
||||
| [unipc](./unipc) | [**UniPCMultistepScheduler**](https://arxiv.org/abs/2302.04867) |
|
||||
| [repaint](./repaint) | [**RePaint scheduler**](https://arxiv.org/abs/2201.09865) |
|
||||
|
||||
## API
|
||||
|
||||
24
docs/source/en/api/schedulers/unipc.mdx
Normal file
24
docs/source/en/api/schedulers/unipc.mdx
Normal file
@@ -0,0 +1,24 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# UniPC
|
||||
|
||||
## Overview
|
||||
|
||||
UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders.
|
||||
|
||||
For more details about the method, please refer to the [[paper]](https://arxiv.org/abs/2302.04867) and the [[code]](https://github.com/wl-zhao/UniPC).
|
||||
|
||||
Fast Sampling of Diffusion Models with Exponential Integrator.
|
||||
|
||||
## UniPCMultistepScheduler
|
||||
[[autodoc]] UniPCMultistepScheduler
|
||||
@@ -84,6 +84,7 @@ else:
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
UnCLIPScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
VQDiffusionScheduler,
|
||||
)
|
||||
from .training_utils import EMAModel
|
||||
|
||||
@@ -39,6 +39,7 @@ else:
|
||||
from .scheduling_sde_ve import ScoreSdeVeScheduler
|
||||
from .scheduling_sde_vp import ScoreSdeVpScheduler
|
||||
from .scheduling_unclip import UnCLIPScheduler
|
||||
from .scheduling_unipc_multistep import UniPCMultistepScheduler
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
||||
from .scheduling_vq_diffusion import VQDiffusionScheduler
|
||||
|
||||
|
||||
@@ -159,7 +159,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
|
||||
|
||||
if solver_type not in ["logrho"]:
|
||||
if solver_type in ["midpoint", "heun"]:
|
||||
if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
|
||||
solver_type = "logrho"
|
||||
else:
|
||||
raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}")
|
||||
|
||||
@@ -169,7 +169,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
|
||||
if solver_type not in ["midpoint", "heun"]:
|
||||
if solver_type == "logrho":
|
||||
if solver_type in ["logrho", "bh1", "bh2"]:
|
||||
solver_type = "midpoint"
|
||||
else:
|
||||
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
|
||||
|
||||
@@ -168,7 +168,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
|
||||
if solver_type not in ["midpoint", "heun"]:
|
||||
if solver_type == "logrho":
|
||||
if solver_type in ["logrho", "bh1", "bh2"]:
|
||||
solver_type = "midpoint"
|
||||
else:
|
||||
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
|
||||
|
||||
607
src/diffusers/schedulers/scheduling_unipc_multistep.py
Normal file
607
src/diffusers/schedulers/scheduling_unipc_multistep.py
Normal file
@@ -0,0 +1,607 @@
|
||||
# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
|
||||
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
||||
to that part of the diffusion process.
|
||||
|
||||
|
||||
Args:
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a
|
||||
corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders. UniPC is
|
||||
by desinged model-agnostic, supporting pixel-space/latent-space DPMs on unconditional/conditional sampling. It can
|
||||
also be applied to both noise prediction model and data prediction model. The corrector UniC can be also applied
|
||||
after any off-the-shelf solvers to increase the order of accuracy.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2302.04867
|
||||
|
||||
Currently, we support the multistep UniPC for both noise prediction models and data prediction models. We recommend
|
||||
to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
|
||||
|
||||
We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
|
||||
diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
|
||||
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
|
||||
stable-diffusion).
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
beta_start (`float`): the starting `beta` value of inference.
|
||||
beta_end (`float`): the final `beta` value.
|
||||
beta_schedule (`str`):
|
||||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
solver_order (`int`, default `2`):
|
||||
the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of
|
||||
accuracy is `solver_order + 1` due to the UniC. We recommend to use `solver_order=2` for guided sampling,
|
||||
and `solver_order=3` for unconditional sampling.
|
||||
prediction_type (`str`, default `epsilon`, optional):
|
||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
||||
https://imagen.research.google/video/paper.pdf)
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
|
||||
use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
|
||||
models (such as stable-diffusion).
|
||||
dynamic_thresholding_ratio (`float`, default `0.995`):
|
||||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
(https://arxiv.org/abs/2205.11487).
|
||||
sample_max_value (`float`, default `1.0`):
|
||||
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
|
||||
predict_x0 (`bool`, default `True`):
|
||||
whether to use the updating algrithm on the predicted x0. See https://arxiv.org/abs/2211.01095 for details
|
||||
solver_type (`str`, default `bh2`):
|
||||
the solver type of UniPC. We recommend use `bh1` for unconditional sampling when steps < 10, and use `bh2`
|
||||
otherwise.
|
||||
lower_order_final (`bool`, default `True`):
|
||||
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
|
||||
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
|
||||
disable_corrector (`list`, default `[]`):
|
||||
decide which step to disable the corrector. For large guidance scale, the misalignment between the
|
||||
`epsilon_theta(x_t, c)`and `epsilon_theta(x_t^c, c)` might influence the convergence. This can be mitigated
|
||||
by disable the corrector at the first few steps (e.g., disable_corrector=[0])
|
||||
solver_p (`SchedulerMixin`, default `None`):
|
||||
can be any other scheduler. If specified, the algorithm will become solver_p + UniC.
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
solver_order: int = 2,
|
||||
prediction_type: str = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
predict_x0: bool = True,
|
||||
solver_type: str = "bh1",
|
||||
lower_order_final: bool = True,
|
||||
disable_corrector: List[int] = [],
|
||||
solver_p: SchedulerMixin = None,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = (
|
||||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
# Currently we only support VP-type noise schedule
|
||||
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
||||
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
||||
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
if solver_type not in ["bh1", "bh2"]:
|
||||
if solver_type in ["midpoint", "heun", "logrho"]:
|
||||
solver_type = "bh1"
|
||||
else:
|
||||
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
|
||||
|
||||
self.predict_x0 = predict_x0
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.timestep_list = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self.disable_corrector = disable_corrector
|
||||
self.solver_p = solver_p
|
||||
self.last_sample = None
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, optional):
|
||||
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
timesteps = (
|
||||
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
|
||||
.round()[::-1][:-1]
|
||||
.copy()
|
||||
.astype(np.int64)
|
||||
)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
self.model_outputs = [
|
||||
None,
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
self.last_sample = None
|
||||
if self.solver_p:
|
||||
self.solver_p.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Convert the model output to the corresponding type that the algorithm PC needs.
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: the converted model output.
|
||||
"""
|
||||
if self.predict_x0:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
" `v_prediction` for the DPMSolverMultistepScheduler."
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
dynamic_max_val = torch.quantile(
|
||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
|
||||
)
|
||||
dynamic_max_val = torch.maximum(
|
||||
dynamic_max_val,
|
||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
x0_pred = x0_pred.type(orig_dtype)
|
||||
return x0_pred
|
||||
else:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
return model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
return epsilon
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
" `v_prediction` for the DPMSolverMultistepScheduler."
|
||||
)
|
||||
|
||||
def multistep_uni_p_bh_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
order: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
direct outputs from learned diffusion model at the current timestep.
|
||||
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
order (`int`): the order of UniP at this step, also the p in UniPC-p.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: the sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = self.timestep_list
|
||||
model_output_list = self.model_outputs
|
||||
|
||||
s0, t = self.timestep_list[-1], prev_timestep
|
||||
m0 = model_output_list[-1]
|
||||
x = sample
|
||||
|
||||
if self.solver_p:
|
||||
x_t = self.solver_p.step(model_output, s0, x).prev_sample
|
||||
return x_t
|
||||
|
||||
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
|
||||
h = lambda_t - lambda_s0
|
||||
device = sample.device
|
||||
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
si = timestep_list[-(i + 1)]
|
||||
mi = model_output_list[-(i + 1)]
|
||||
lambda_si = self.lambda_t[si]
|
||||
rk = (lambda_si - lambda_s0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
|
||||
rks.append(1.0)
|
||||
rks = torch.tensor(rks, device=device)
|
||||
|
||||
R = []
|
||||
b = []
|
||||
|
||||
hh = -h if self.predict_x0 else h
|
||||
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
||||
h_phi_k = h_phi_1 / hh - 1
|
||||
|
||||
factorial_i = 1
|
||||
|
||||
if self.config.solver_type == "bh1":
|
||||
B_h = hh
|
||||
elif self.config.solver_type == "bh2":
|
||||
B_h = torch.expm1(hh)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
for i in range(1, order + 1):
|
||||
R.append(torch.pow(rks, i - 1))
|
||||
b.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= i + 1
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = torch.stack(R)
|
||||
b = torch.tensor(b, device=device)
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1) # (B, K)
|
||||
# for order 2, we use a simplified version
|
||||
if order == 2:
|
||||
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
|
||||
else:
|
||||
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
||||
else:
|
||||
D1s = None
|
||||
|
||||
if self.predict_x0:
|
||||
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - alpha_t * B_h * pred_res
|
||||
else:
|
||||
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - sigma_t * B_h * pred_res
|
||||
|
||||
x_t = x_t.to(x.dtype)
|
||||
return x_t
|
||||
|
||||
def multistep_uni_c_bh_update(
|
||||
self,
|
||||
this_model_output: torch.FloatTensor,
|
||||
this_timestep: int,
|
||||
last_sample: torch.FloatTensor,
|
||||
this_sample: torch.FloatTensor,
|
||||
order: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the UniC (B(h) version).
|
||||
|
||||
Args:
|
||||
this_model_output (`torch.FloatTensor`): the model outputs at `x_t`
|
||||
this_timestep (`int`): the current timestep `t`
|
||||
last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}`
|
||||
this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}`
|
||||
order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy
|
||||
should be order + 1
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: the corrected sample tensor at the current timestep.
|
||||
"""
|
||||
timestep_list = self.timestep_list
|
||||
model_output_list = self.model_outputs
|
||||
|
||||
s0, t = timestep_list[-1], this_timestep
|
||||
m0 = model_output_list[-1]
|
||||
x = last_sample
|
||||
x_t = this_sample
|
||||
model_t = this_model_output
|
||||
|
||||
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
|
||||
h = lambda_t - lambda_s0
|
||||
device = this_sample.device
|
||||
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
si = timestep_list[-(i + 1)]
|
||||
mi = model_output_list[-(i + 1)]
|
||||
lambda_si = self.lambda_t[si]
|
||||
rk = (lambda_si - lambda_s0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
|
||||
rks.append(1.0)
|
||||
rks = torch.tensor(rks, device=device)
|
||||
|
||||
R = []
|
||||
b = []
|
||||
|
||||
hh = -h if self.predict_x0 else h
|
||||
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
||||
h_phi_k = h_phi_1 / hh - 1
|
||||
|
||||
factorial_i = 1
|
||||
|
||||
if self.config.solver_type == "bh1":
|
||||
B_h = hh
|
||||
elif self.config.solver_type == "bh2":
|
||||
B_h = torch.expm1(hh)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
for i in range(1, order + 1):
|
||||
R.append(torch.pow(rks, i - 1))
|
||||
b.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= i + 1
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = torch.stack(R)
|
||||
b = torch.tensor(b, device=device)
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1)
|
||||
else:
|
||||
D1s = None
|
||||
|
||||
# for order 1, we use a simplified version
|
||||
if order == 1:
|
||||
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
|
||||
else:
|
||||
rhos_c = torch.linalg.solve(R, b)
|
||||
|
||||
if self.predict_x0:
|
||||
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = model_t - m0
|
||||
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
||||
else:
|
||||
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = model_t - m0
|
||||
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
||||
x_t = x_t.to(x.dtype)
|
||||
return x_t
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
"""
|
||||
Step function propagating the sample with the multistep UniPC.
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
|
||||
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
|
||||
use_corrector = (
|
||||
step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None
|
||||
)
|
||||
|
||||
model_output_convert = self.convert_model_output(model_output, timestep, sample)
|
||||
if use_corrector:
|
||||
sample = self.multistep_uni_c_bh_update(
|
||||
this_model_output=model_output_convert,
|
||||
this_timestep=timestep,
|
||||
last_sample=self.last_sample,
|
||||
this_sample=sample,
|
||||
order=self.this_order,
|
||||
)
|
||||
|
||||
# now prepare to run the predictor
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.timestep_list[i] = self.timestep_list[i + 1]
|
||||
|
||||
self.model_outputs[-1] = model_output_convert
|
||||
self.timestep_list[-1] = timestep
|
||||
|
||||
if self.config.lower_order_final:
|
||||
this_order = min(self.config.solver_order, len(self.timesteps) - step_index)
|
||||
else:
|
||||
this_order = self.config.solver_order
|
||||
|
||||
self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
|
||||
assert self.this_order > 0
|
||||
|
||||
self.last_sample = sample
|
||||
prev_sample = self.multistep_uni_p_bh_update(
|
||||
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
|
||||
prev_timestep=prev_timestep,
|
||||
sample=sample,
|
||||
order=self.this_order,
|
||||
)
|
||||
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
current timestep.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): input sample
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: scaled input sample
|
||||
"""
|
||||
return sample
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
@@ -42,6 +42,7 @@ class KarrasDiffusionSchedulers(Enum):
|
||||
KDPM2DiscreteScheduler = 10
|
||||
KDPM2AncestralDiscreteScheduler = 11
|
||||
DEISMultistepScheduler = 12
|
||||
UniPCMultistepScheduler = 13
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -600,6 +600,21 @@ class UnCLIPScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class UniPCMultistepScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class VQDiffusionScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ from diffusers import (
|
||||
PNDMScheduler,
|
||||
ScoreSdeVeScheduler,
|
||||
UnCLIPScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
VQDiffusionScheduler,
|
||||
logging,
|
||||
)
|
||||
@@ -2636,6 +2637,200 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
|
||||
assert sample.dtype == torch.float16
|
||||
|
||||
|
||||
class UniPCMultistepSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_classes = (UniPCMultistepScheduler,)
|
||||
forward_default_kwargs = (("num_inference_steps", 25),)
|
||||
|
||||
def get_scheduler_config(self, **kwargs):
|
||||
config = {
|
||||
"num_train_timesteps": 1000,
|
||||
"beta_start": 0.0001,
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
"solver_order": 2,
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
return config
|
||||
|
||||
def check_over_configs(self, time_step=0, **config):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
# copy over dummy past residuals
|
||||
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
|
||||
|
||||
output, new_output = sample, sample
|
||||
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
|
||||
output = scheduler.step(residual, t, output, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
|
||||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# copy over dummy past residuals (must be after setting timesteps)
|
||||
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# copy over dummy past residual (must be after setting timesteps)
|
||||
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
|
||||
|
||||
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
|
||||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def full_loop(self, **config):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
num_inference_steps = 10
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
for i, t in enumerate(scheduler.timesteps):
|
||||
residual = model(sample, t)
|
||||
sample = scheduler.step(residual, t, sample).prev_sample
|
||||
|
||||
return sample
|
||||
|
||||
def test_step_shape(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
# copy over dummy past residuals (must be done after set_timesteps)
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
|
||||
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
|
||||
|
||||
time_step_0 = scheduler.timesteps[5]
|
||||
time_step_1 = scheduler.timesteps[6]
|
||||
|
||||
output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
|
||||
output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
|
||||
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
def test_timesteps(self):
|
||||
for timesteps in [25, 50, 100, 999, 1000]:
|
||||
self.check_over_configs(num_train_timesteps=timesteps)
|
||||
|
||||
def test_thresholding(self):
|
||||
self.check_over_configs(thresholding=False)
|
||||
for order in [1, 2, 3]:
|
||||
for solver_type in ["bh1", "bh2"]:
|
||||
for threshold in [0.5, 1.0, 2.0]:
|
||||
for prediction_type in ["epsilon", "sample"]:
|
||||
self.check_over_configs(
|
||||
thresholding=True,
|
||||
prediction_type=prediction_type,
|
||||
sample_max_value=threshold,
|
||||
solver_order=order,
|
||||
solver_type=solver_type,
|
||||
)
|
||||
|
||||
def test_prediction_type(self):
|
||||
for prediction_type in ["epsilon", "v_prediction"]:
|
||||
self.check_over_configs(prediction_type=prediction_type)
|
||||
|
||||
def test_solver_order_and_type(self):
|
||||
for solver_type in ["bh1", "bh2"]:
|
||||
for order in [1, 2, 3]:
|
||||
for prediction_type in ["epsilon", "sample"]:
|
||||
self.check_over_configs(
|
||||
solver_order=order,
|
||||
solver_type=solver_type,
|
||||
prediction_type=prediction_type,
|
||||
)
|
||||
sample = self.full_loop(
|
||||
solver_order=order,
|
||||
solver_type=solver_type,
|
||||
prediction_type=prediction_type,
|
||||
)
|
||||
assert not torch.isnan(sample).any(), "Samples have nan numbers"
|
||||
|
||||
def test_lower_order_final(self):
|
||||
self.check_over_configs(lower_order_final=True)
|
||||
self.check_over_configs(lower_order_final=False)
|
||||
|
||||
def test_inference_steps(self):
|
||||
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
|
||||
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
sample = self.full_loop()
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.2521) < 1e-3
|
||||
|
||||
def test_full_loop_with_v_prediction(self):
|
||||
sample = self.full_loop(prediction_type="v_prediction")
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.1096) < 1e-3
|
||||
|
||||
def test_fp16_support(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
num_inference_steps = 10
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter.half()
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
for i, t in enumerate(scheduler.timesteps):
|
||||
residual = model(sample, t)
|
||||
sample = scheduler.step(residual, t, sample).prev_sample
|
||||
|
||||
assert sample.dtype == torch.float16
|
||||
|
||||
|
||||
class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_classes = (KDPM2AncestralDiscreteScheduler,)
|
||||
num_inference_steps = 10
|
||||
|
||||
Reference in New Issue
Block a user