1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Cont'd] Add the SDE variant of ~~DPM-Solver~~ and DPM-Solver++ to DPM Single Step (#8269)

* Add the SDE variant of DPM-Solver and DPM-Solver++ to DPM Single Step


---------

Co-authored-by: cmdr2 <secondary.cmdr2@gmail.com>
This commit is contained in:
Tolga Cangöz
2024-07-17 04:40:02 +03:00
committed by GitHub
parent 3b37fefee9
commit e87bf62940
2 changed files with 74 additions and 25 deletions

View File

@@ -22,6 +22,7 @@ import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
@@ -108,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
algorithm_type (`str`, defaults to `dpmsolver++`):
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type
implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is
recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in
Stable Diffusion.
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver`
type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
`dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
sampling like in Stable Diffusion.
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
@@ -186,7 +187,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
# settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
else:
@@ -197,7 +198,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
raise ValueError(
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead."
)
@@ -493,10 +494,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3]
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
@@ -517,34 +518,43 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3]
return model_output
if self.config.variance_type in ["learned", "learned_range"]:
epsilon = model_output[:, :3]
else:
epsilon = model_output
elif self.config.prediction_type == "sample":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DPMSolverSinglestepScheduler."
)
if self.config.thresholding:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t
return epsilon
def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -594,6 +604,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t
def singlestep_dpm_solver_second_order_update(
@@ -601,6 +618,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -688,6 +706,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
- (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s1 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s1 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t
def singlestep_dpm_solver_third_order_update(
@@ -800,6 +834,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
*args,
sample: torch.Tensor = None,
order: int = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -848,9 +883,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
)
if order == 1:
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample, noise=noise)
elif order == 2:
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
elif order == 3:
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
else:
@@ -894,6 +929,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -929,6 +965,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
if self.config.algorithm_type == "sde-dpmsolver++":
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
)
else:
noise = None
order = self.order_list[self.step_index]
# For img2img denoising might start with order>1 which is not possible
@@ -940,9 +983,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
if order == 1:
self.sample = sample
prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)
prev_sample = self.singlestep_dpm_solver_update(
self.model_outputs, sample=self.sample, order=order, noise=noise
)
# upon completion increase step index by one
# upon completion increase step index by one, noise=noise
self._step_index += 1
if not return_dict:

View File

@@ -194,16 +194,20 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
self.check_over_configs(prediction_type=prediction_type)
def test_solver_order_and_type(self):
for algorithm_type in ["dpmsolver", "dpmsolver++"]:
for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
for solver_type in ["midpoint", "heun"]:
for order in [1, 2, 3]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
if algorithm_type == "sde-dpmsolver++":
if order == 3:
continue
else:
self.check_over_configs(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
sample = self.full_loop(
solver_order=order,
solver_type=solver_type,