mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
refactor DPMSolverMultistepScheduler using sigmas (#4986)
--------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -22,6 +22,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
@@ -186,6 +187,14 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
@@ -225,17 +234,16 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
# when num_inference_steps == num_train_timesteps, we can end up with
|
||||
# duplicates in timesteps.
|
||||
_, unique_indices = np.unique(timesteps, return_index=True)
|
||||
timesteps = timesteps[np.sort(unique_indices)]
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
self.num_inference_steps = len(timesteps)
|
||||
@@ -245,6 +253,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -280,8 +291,57 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
log_sigma = np.log(sigma)
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = log_sigmas[low_idx]
|
||||
high = log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = np.clip(w, 0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DEIS algorithm needs.
|
||||
@@ -298,13 +358,26 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
if self.config.prediction_type == "epsilon":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -316,7 +389,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
|
||||
if self.config.algorithm_type == "deis":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
return (sample - alpha_t * x0_pred) / sigma_t
|
||||
else:
|
||||
raise NotImplementedError("only support log-rho multistep deis now")
|
||||
@@ -324,9 +396,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def deis_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DEIS (equivalent to DDIM).
|
||||
@@ -345,9 +417,33 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
||||
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "deis":
|
||||
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
||||
@@ -358,9 +454,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_deis_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order multistep DEIS.
|
||||
@@ -368,10 +464,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -379,10 +471,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1]
|
||||
sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1]
|
||||
|
||||
rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
|
||||
|
||||
@@ -403,9 +523,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_deis_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order multistep DEIS.
|
||||
@@ -413,10 +533,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by diffusion process.
|
||||
|
||||
@@ -424,15 +540,47 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2]
|
||||
sigma_t, sigma_s0, sigma_s1, simga_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2]
|
||||
|
||||
rho_t, rho_s0, rho_s1, rho_s2 = (
|
||||
sigma_t / alpha_t,
|
||||
sigma_s0 / alpha_s0,
|
||||
sigma_s1 / alpha_s1,
|
||||
simga_s2 / alpha_s2,
|
||||
sigma_s2 / alpha_s2,
|
||||
)
|
||||
|
||||
if self.config.algorithm_type == "deis":
|
||||
@@ -460,6 +608,25 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError("only support log-rho multistep deis now")
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -492,42 +659,34 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
lower_order_final = (
|
||||
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
lower_order_second = (
|
||||
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
|
||||
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
||||
prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample)
|
||||
prev_sample = self.deis_first_order_update(model_output, sample=sample)
|
||||
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
||||
timestep_list = [self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_deis_second_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample)
|
||||
else:
|
||||
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_deis_third_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample)
|
||||
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -548,28 +707,30 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -21,6 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
@@ -203,6 +204,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
@@ -242,19 +251,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
# when num_inference_steps == num_train_timesteps, we can end up with
|
||||
# duplicates in timesteps.
|
||||
_, unique_indices = np.unique(timesteps, return_index=True)
|
||||
timesteps = timesteps[np.sort(unique_indices)]
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
self.num_inference_steps = len(timesteps)
|
||||
@@ -264,6 +273,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -323,6 +335,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -338,7 +356,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
||||
@@ -355,8 +377,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -364,6 +384,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
||||
@@ -371,12 +403,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
||||
if self.config.variance_type in ["learned", "learned_range"]:
|
||||
model_output = model_output[:, :3]
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -398,10 +432,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
epsilon = model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -410,7 +446,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
x0_pred = (sample - sigma_t * epsilon) / alpha_t
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
epsilon = (sample - alpha_t * x0_pred) / sigma_t
|
||||
@@ -420,10 +457,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DPMSolver (equivalent to DDIM).
|
||||
@@ -431,10 +468,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -442,9 +475,33 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
||||
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
||||
@@ -469,10 +526,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order multistep DPMSolver.
|
||||
@@ -480,10 +537,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -491,11 +544,43 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
|
||||
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
||||
r0 = h_0 / h
|
||||
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
||||
@@ -564,9 +649,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order multistep DPMSolver.
|
||||
@@ -574,10 +659,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by diffusion process.
|
||||
|
||||
@@ -585,16 +666,47 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
||||
self.lambda_t[t],
|
||||
self.lambda_t[s0],
|
||||
self.lambda_t[s1],
|
||||
self.lambda_t[s2],
|
||||
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
|
||||
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
|
||||
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
||||
r0, r1 = h_0 / h, h_1 / h
|
||||
D0 = m0
|
||||
@@ -619,6 +731,25 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -654,22 +785,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
lower_order_final = (
|
||||
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
lower_order_second = (
|
||||
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
@@ -682,23 +808,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise = None
|
||||
|
||||
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
||||
prev_sample = self.dpm_solver_first_order_update(
|
||||
model_output, timestep, prev_timestep, sample, noise=noise
|
||||
)
|
||||
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
|
||||
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
||||
timestep_list = [self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
|
||||
)
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
|
||||
else:
|
||||
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
|
||||
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -719,28 +840,30 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -21,6 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
@@ -203,8 +204,16 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -244,11 +253,19 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
timesteps = timesteps.copy().astype(np.int64)
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_max = (
|
||||
(1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]
|
||||
) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
@@ -266,6 +283,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -325,6 +345,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -341,7 +368,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
||||
@@ -358,8 +389,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -367,6 +396,18 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
||||
@@ -374,12 +415,14 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
||||
if self.config.variance_type in ["learned", "learned_range"]:
|
||||
model_output = model_output[:, :3]
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -401,10 +444,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
epsilon = model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -413,20 +458,22 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
x0_pred = (sample - sigma_t * epsilon) / alpha_t
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
epsilon = (sample - alpha_t * x0_pred) / sigma_t
|
||||
|
||||
return epsilon
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DPMSolver (equivalent to DDIM).
|
||||
@@ -434,10 +481,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -445,27 +488,62 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
||||
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
||||
elif "sde" in self.config.algorithm_type:
|
||||
raise NotImplementedError(
|
||||
f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}."
|
||||
elif self.config.algorithm_type == "sde-dpmsolver++":
|
||||
assert noise is not None
|
||||
x_t = (
|
||||
(sigma_t / sigma_s * torch.exp(-h)) * sample
|
||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
|
||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config.algorithm_type == "sde-dpmsolver":
|
||||
assert noise is not None
|
||||
x_t = (
|
||||
(alpha_t / alpha_s) * sample
|
||||
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
||||
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
||||
)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
|
||||
def multistep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order multistep DPMSolver.
|
||||
@@ -473,10 +551,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -484,11 +558,43 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
|
||||
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
||||
r0 = h_0 / h
|
||||
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
||||
@@ -520,19 +626,47 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
||||
)
|
||||
elif "sde" in self.config.algorithm_type:
|
||||
raise NotImplementedError(
|
||||
f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}."
|
||||
)
|
||||
elif self.config.algorithm_type == "sde-dpmsolver++":
|
||||
assert noise is not None
|
||||
if self.config.solver_type == "midpoint":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
||||
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
|
||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config.solver_type == "heun":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
||||
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
|
||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config.algorithm_type == "sde-dpmsolver":
|
||||
assert noise is not None
|
||||
if self.config.solver_type == "midpoint":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample
|
||||
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- (sigma_t * (torch.exp(h) - 1.0)) * D1
|
||||
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
||||
)
|
||||
elif self.config.solver_type == "heun":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample
|
||||
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
||||
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
||||
)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
|
||||
def multistep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order multistep DPMSolver.
|
||||
@@ -540,10 +674,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by diffusion process.
|
||||
|
||||
@@ -551,16 +681,47 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
||||
self.lambda_t[t],
|
||||
self.lambda_t[s0],
|
||||
self.lambda_t[s1],
|
||||
self.lambda_t[s2],
|
||||
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
|
||||
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
|
||||
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
||||
r0, r1 = h_0 / h, h_1 / h
|
||||
D0 = m0
|
||||
@@ -585,6 +746,27 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -604,6 +786,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
||||
|
||||
@@ -618,24 +802,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = (
|
||||
self.noisiest_timestep if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
)
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
lower_order_final = (
|
||||
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
lower_order_second = (
|
||||
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
@@ -648,23 +825,18 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise = None
|
||||
|
||||
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
||||
prev_sample = self.dpm_solver_first_order_update(
|
||||
model_output, timestep, prev_timestep, sample, noise=noise
|
||||
)
|
||||
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
|
||||
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
||||
timestep_list = [self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
|
||||
)
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
|
||||
else:
|
||||
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
|
||||
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -686,28 +858,30 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -21,7 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from ..utils import deprecate, logging
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
@@ -197,6 +197,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.sample = None
|
||||
self.order_list = self.get_order_list(num_train_timesteps)
|
||||
self._step_index = None
|
||||
|
||||
def get_order_list(self, num_inference_steps: int) -> List[int]:
|
||||
"""
|
||||
@@ -232,6 +233,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
orders = [1] * steps
|
||||
return orders
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -256,11 +264,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
self.model_outputs = [None] * self.config.solver_order
|
||||
@@ -274,6 +287,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.order_list = self.get_order_list(num_inference_steps)
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -333,6 +349,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -348,7 +371,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
||||
@@ -365,8 +392,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -374,18 +399,32 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
if self.config.prediction_type == "epsilon":
|
||||
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
||||
if self.config.variance_type in ["learned_range"]:
|
||||
model_output = model_output[:, :3]
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -405,11 +444,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output = model_output[:, :3]
|
||||
return model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
return epsilon
|
||||
else:
|
||||
@@ -421,9 +462,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DPMSolver (equivalent to DDIM).
|
||||
@@ -442,9 +483,31 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
||||
@@ -455,9 +518,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def singlestep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
|
||||
@@ -477,11 +540,42 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
||||
alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1]
|
||||
sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1]
|
||||
|
||||
h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1
|
||||
r0 = h_0 / h
|
||||
D0, D1 = m1, (1.0 / r0) * (m0 - m1)
|
||||
@@ -518,9 +612,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def singlestep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
|
||||
@@ -540,16 +634,47 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
||||
self.lambda_t[t],
|
||||
self.lambda_t[s0],
|
||||
self.lambda_t[s1],
|
||||
self.lambda_t[s2],
|
||||
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2]
|
||||
sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2]
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
|
||||
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
|
||||
h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2
|
||||
r0, r1 = h_0 / h, h_1 / h
|
||||
D0 = m2
|
||||
@@ -591,10 +716,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def singlestep_dpm_solver_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
order: int,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
order: int = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the singlestep DPMSolver.
|
||||
@@ -615,19 +740,60 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if order is None:
|
||||
if len(args) > 3:
|
||||
order = args[3]
|
||||
else:
|
||||
raise ValueError(" missing `order` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if order == 1:
|
||||
return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample)
|
||||
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
|
||||
elif order == 2:
|
||||
return self.singlestep_dpm_solver_second_order_update(
|
||||
model_output_list, timestep_list, prev_timestep, sample
|
||||
)
|
||||
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
|
||||
elif order == 3:
|
||||
return self.singlestep_dpm_solver_third_order_update(
|
||||
model_output_list, timestep_list, prev_timestep, sample
|
||||
)
|
||||
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
|
||||
else:
|
||||
raise ValueError(f"Order must be 1, 2, 3, got {order}")
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -660,21 +826,15 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
|
||||
order = self.order_list[step_index]
|
||||
order = self.order_list[self.step_index]
|
||||
|
||||
# For img2img denoising might start with order>1 which is not possible
|
||||
# In this case make sure that the first two steps are both order=1
|
||||
@@ -685,10 +845,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if order == 1:
|
||||
self.sample = sample
|
||||
|
||||
timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep]
|
||||
prev_sample = self.singlestep_dpm_solver_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, self.sample, order
|
||||
)
|
||||
prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
@@ -710,28 +870,30 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -22,10 +22,16 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -38,19 +44,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
@@ -181,6 +198,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.disable_corrector = disable_corrector
|
||||
self.solver_p = solver_p
|
||||
self.last_sample = None
|
||||
self._step_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
@@ -220,17 +245,16 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
# when num_inference_steps == num_train_timesteps, we can end up with
|
||||
# duplicates in timesteps.
|
||||
_, unique_indices = np.unique(timesteps, return_index=True)
|
||||
timesteps = timesteps[np.sort(unique_indices)]
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
self.num_inference_steps = len(timesteps)
|
||||
@@ -243,6 +267,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.solver_p:
|
||||
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -302,6 +329,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -317,7 +351,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Convert the model output to the corresponding type the UniPC algorithm needs.
|
||||
@@ -334,14 +372,28 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
|
||||
if self.predict_x0:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -357,11 +409,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.config.prediction_type == "epsilon":
|
||||
return model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
return epsilon
|
||||
else:
|
||||
@@ -373,9 +423,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_uni_p_bh_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
order: int,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
order: int = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
|
||||
@@ -394,10 +445,26 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = self.timestep_list
|
||||
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if order is None:
|
||||
if len(args) > 2:
|
||||
order = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `order` as a required keyward argument")
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
model_output_list = self.model_outputs
|
||||
|
||||
s0, t = self.timestep_list[-1], prev_timestep
|
||||
s0 = self.timestep_list[-1]
|
||||
m0 = model_output_list[-1]
|
||||
x = sample
|
||||
|
||||
@@ -405,9 +472,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x_t = self.solver_p.step(model_output, s0, x).prev_sample
|
||||
return x_t
|
||||
|
||||
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
|
||||
h = lambda_t - lambda_s0
|
||||
device = sample.device
|
||||
@@ -415,9 +485,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
si = timestep_list[-(i + 1)]
|
||||
si = self.step_index - i
|
||||
mi = model_output_list[-(i + 1)]
|
||||
lambda_si = self.lambda_t[si]
|
||||
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
||||
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
||||
rk = (lambda_si - lambda_s0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
@@ -481,10 +552,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_uni_c_bh_update(
|
||||
self,
|
||||
this_model_output: torch.FloatTensor,
|
||||
this_timestep: int,
|
||||
last_sample: torch.FloatTensor,
|
||||
this_sample: torch.FloatTensor,
|
||||
order: int,
|
||||
*args,
|
||||
last_sample: torch.FloatTensor = None,
|
||||
this_sample: torch.FloatTensor = None,
|
||||
order: int = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the UniC (B(h) version).
|
||||
@@ -505,18 +577,42 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The corrected sample tensor at the current timestep.
|
||||
"""
|
||||
timestep_list = self.timestep_list
|
||||
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
|
||||
if last_sample is None:
|
||||
if len(args) > 1:
|
||||
last_sample = args[1]
|
||||
else:
|
||||
raise ValueError(" missing`last_sample` as a required keyward argument")
|
||||
if this_sample is None:
|
||||
if len(args) > 2:
|
||||
this_sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`this_sample` as a required keyward argument")
|
||||
if order is None:
|
||||
if len(args) > 3:
|
||||
order = args[3]
|
||||
else:
|
||||
raise ValueError(" missing`order` as a required keyward argument")
|
||||
if this_timestep is not None:
|
||||
deprecate(
|
||||
"this_timestep",
|
||||
"1.0.0",
|
||||
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
model_output_list = self.model_outputs
|
||||
|
||||
s0, t = timestep_list[-1], this_timestep
|
||||
m0 = model_output_list[-1]
|
||||
x = last_sample
|
||||
x_t = this_sample
|
||||
model_t = this_model_output
|
||||
|
||||
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
|
||||
h = lambda_t - lambda_s0
|
||||
device = this_sample.device
|
||||
@@ -524,9 +620,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
si = timestep_list[-(i + 1)]
|
||||
si = self.step_index - (i + 1)
|
||||
mi = model_output_list[-(i + 1)]
|
||||
lambda_si = self.lambda_t[si]
|
||||
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
||||
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
||||
rk = (lambda_si - lambda_s0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
@@ -589,6 +686,25 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x_t = x_t.to(x.dtype)
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -616,37 +732,27 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
use_corrector = (
|
||||
step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None
|
||||
self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
|
||||
)
|
||||
|
||||
model_output_convert = self.convert_model_output(model_output, timestep, sample)
|
||||
model_output_convert = self.convert_model_output(model_output, sample=sample)
|
||||
if use_corrector:
|
||||
sample = self.multistep_uni_c_bh_update(
|
||||
this_model_output=model_output_convert,
|
||||
this_timestep=timestep,
|
||||
last_sample=self.last_sample,
|
||||
this_sample=sample,
|
||||
order=self.this_order,
|
||||
)
|
||||
|
||||
# now prepare to run the predictor
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.timestep_list[i] = self.timestep_list[i + 1]
|
||||
@@ -655,7 +761,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timestep_list[-1] = timestep
|
||||
|
||||
if self.config.lower_order_final:
|
||||
this_order = min(self.config.solver_order, len(self.timesteps) - step_index)
|
||||
this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
|
||||
else:
|
||||
this_order = self.config.solver_order
|
||||
|
||||
@@ -665,7 +771,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.last_sample = sample
|
||||
prev_sample = self.multistep_uni_p_bh_update(
|
||||
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
|
||||
prev_timestep=prev_timestep,
|
||||
sample=sample,
|
||||
order=self.this_order,
|
||||
)
|
||||
@@ -673,6 +778,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -693,28 +801,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -51,6 +51,7 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
output, new_output = sample, sample
|
||||
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
|
||||
t = scheduler.timesteps[t]
|
||||
output = scheduler.step(residual, t, output, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
output, new_output = sample, sample
|
||||
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
|
||||
t = new_scheduler.timesteps[t]
|
||||
output = scheduler.step(residual, t, output, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
|
||||
|
||||
@@ -91,6 +92,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
# copy over dummy past residual (must be after setting timesteps)
|
||||
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
|
||||
|
||||
time_step = new_scheduler.timesteps[time_step]
|
||||
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
|
||||
|
||||
@@ -264,10 +266,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
assert sample.dtype == torch.float16
|
||||
|
||||
def test_unique_timesteps(self, **config):
|
||||
def test_duplicated_timesteps(self, **config):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler.set_timesteps(scheduler.config.num_train_timesteps)
|
||||
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
|
||||
assert len(scheduler.timesteps) == scheduler.num_inference_steps
|
||||
|
||||
@@ -54,6 +54,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
output, new_output = sample, sample
|
||||
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
|
||||
t = scheduler.timesteps[t]
|
||||
output = scheduler.step(residual, t, output, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
|
||||
|
||||
@@ -222,7 +223,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 1.7833) < 1e-3
|
||||
assert abs(result_mean.item() - 1.7833) < 2e-3
|
||||
|
||||
def test_switch(self):
|
||||
# make sure that iterating over schedulers with same config names gives same results
|
||||
|
||||
@@ -58,6 +58,7 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
output, new_output = sample, sample
|
||||
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
|
||||
t = scheduler.timesteps[t]
|
||||
output = scheduler.step(residual, t, output, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
|
||||
|
||||
@@ -248,3 +249,33 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
|
||||
sample = scheduler.step(residual, t, sample).prev_sample
|
||||
|
||||
assert sample.dtype == torch.float16
|
||||
|
||||
def test_step_shape(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
# copy over dummy past residuals (must be done after set_timesteps)
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
|
||||
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
|
||||
|
||||
time_step_0 = scheduler.timesteps[0]
|
||||
time_step_1 = scheduler.timesteps[1]
|
||||
|
||||
output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
|
||||
output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
|
||||
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
@@ -52,6 +52,7 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
output, new_output = sample, sample
|
||||
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
|
||||
t = scheduler.timesteps[t]
|
||||
output = scheduler.step(residual, t, output, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
|
||||
|
||||
@@ -241,11 +242,3 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
|
||||
sample = scheduler.step(residual, t, sample).prev_sample
|
||||
|
||||
assert sample.dtype == torch.float16
|
||||
|
||||
def test_unique_timesteps(self, **config):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler.set_timesteps(scheduler.config.num_train_timesteps)
|
||||
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
|
||||
|
||||
Reference in New Issue
Block a user