mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Cont'd] Add the SDE variant of ~~DPM-Solver~~ and DPM-Solver++ to DPM Single Step (#8269)
* Add the SDE variant of DPM-Solver and DPM-Solver++ to DPM Single Step --------- Co-authored-by: cmdr2 <secondary.cmdr2@gmail.com>
This commit is contained in:
@@ -22,6 +22,7 @@ import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
@@ -108,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
||||
`algorithm_type="dpmsolver++"`.
|
||||
algorithm_type (`str`, defaults to `dpmsolver++`):
|
||||
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
|
||||
algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type
|
||||
implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is
|
||||
recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in
|
||||
Stable Diffusion.
|
||||
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver`
|
||||
type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
|
||||
`dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
|
||||
paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
|
||||
sampling like in Stable Diffusion.
|
||||
solver_type (`str`, defaults to `midpoint`):
|
||||
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
||||
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
||||
@@ -186,7 +187,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# settings for DPM-Solver
|
||||
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
|
||||
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
|
||||
if algorithm_type == "deis":
|
||||
self.register_to_config(algorithm_type="dpmsolver++")
|
||||
else:
|
||||
@@ -197,7 +198,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
||||
|
||||
if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
|
||||
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
|
||||
raise ValueError(
|
||||
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead."
|
||||
)
|
||||
@@ -493,10 +494,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
||||
if self.config.variance_type in ["learned_range"]:
|
||||
if self.config.variance_type in ["learned", "learned_range"]:
|
||||
model_output = model_output[:, :3]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
@@ -517,34 +518,43 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
|
||||
return x0_pred
|
||||
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
if self.config.prediction_type == "epsilon":
|
||||
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
||||
if self.config.variance_type in ["learned_range"]:
|
||||
model_output = model_output[:, :3]
|
||||
return model_output
|
||||
if self.config.variance_type in ["learned", "learned_range"]:
|
||||
epsilon = model_output[:, :3]
|
||||
else:
|
||||
epsilon = model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
return epsilon
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
" `v_prediction` for the DPMSolverSinglestepScheduler."
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * epsilon) / alpha_t
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
epsilon = (sample - alpha_t * x0_pred) / sigma_t
|
||||
|
||||
return epsilon
|
||||
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -594,6 +604,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
||||
elif self.config.algorithm_type == "sde-dpmsolver++":
|
||||
assert noise is not None
|
||||
x_t = (
|
||||
(sigma_t / sigma_s * torch.exp(-h)) * sample
|
||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
|
||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||
)
|
||||
return x_t
|
||||
|
||||
def singlestep_dpm_solver_second_order_update(
|
||||
@@ -601,6 +618,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -688,6 +706,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
||||
)
|
||||
elif self.config.algorithm_type == "sde-dpmsolver++":
|
||||
assert noise is not None
|
||||
if self.config.solver_type == "midpoint":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s1 * torch.exp(-h)) * sample
|
||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
||||
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
|
||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config.solver_type == "heun":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s1 * torch.exp(-h)) * sample
|
||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
||||
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
|
||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||
)
|
||||
return x_t
|
||||
|
||||
def singlestep_dpm_solver_third_order_update(
|
||||
@@ -800,6 +834,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
order: int = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -848,9 +883,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if order == 1:
|
||||
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
|
||||
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample, noise=noise)
|
||||
elif order == 2:
|
||||
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
|
||||
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
|
||||
elif order == 3:
|
||||
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
|
||||
else:
|
||||
@@ -894,6 +929,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: torch.Tensor,
|
||||
timestep: int,
|
||||
sample: torch.Tensor,
|
||||
generator=None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -929,6 +965,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
|
||||
if self.config.algorithm_type == "sde-dpmsolver++":
|
||||
noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
||||
)
|
||||
else:
|
||||
noise = None
|
||||
|
||||
order = self.order_list[self.step_index]
|
||||
|
||||
# For img2img denoising might start with order>1 which is not possible
|
||||
@@ -940,9 +983,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if order == 1:
|
||||
self.sample = sample
|
||||
|
||||
prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)
|
||||
prev_sample = self.singlestep_dpm_solver_update(
|
||||
self.model_outputs, sample=self.sample, order=order, noise=noise
|
||||
)
|
||||
|
||||
# upon completion increase step index by one
|
||||
# upon completion increase step index by one, noise=noise
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -194,16 +194,20 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
|
||||
self.check_over_configs(prediction_type=prediction_type)
|
||||
|
||||
def test_solver_order_and_type(self):
|
||||
for algorithm_type in ["dpmsolver", "dpmsolver++"]:
|
||||
for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
|
||||
for solver_type in ["midpoint", "heun"]:
|
||||
for order in [1, 2, 3]:
|
||||
for prediction_type in ["epsilon", "sample"]:
|
||||
self.check_over_configs(
|
||||
solver_order=order,
|
||||
solver_type=solver_type,
|
||||
prediction_type=prediction_type,
|
||||
algorithm_type=algorithm_type,
|
||||
)
|
||||
if algorithm_type == "sde-dpmsolver++":
|
||||
if order == 3:
|
||||
continue
|
||||
else:
|
||||
self.check_over_configs(
|
||||
solver_order=order,
|
||||
solver_type=solver_type,
|
||||
prediction_type=prediction_type,
|
||||
algorithm_type=algorithm_type,
|
||||
)
|
||||
sample = self.full_loop(
|
||||
solver_order=order,
|
||||
solver_type=solver_type,
|
||||
|
||||
Reference in New Issue
Block a user