mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[2064]: Add Karras to DPMSolverMultistepScheduler (#3001)
* [2737]: Add Karras DPMSolverMultistepScheduler * [2737]: Add Karras DPMSolverMultistepScheduler * Add test * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fix: repo consistency. * remove Copied from statement from the set_timestep method. * fix: test * Empty commit. Co-authored-by: njindal <njindal@adobe.com> --------- Co-authored-by: njindal <njindal@adobe.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -171,7 +171,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_timesteps
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
@@ -114,7 +114,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
lower_order_final (`bool`, default `True`):
|
||||
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
|
||||
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
|
||||
|
||||
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
|
||||
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
|
||||
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
@@ -136,6 +139,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
@@ -181,6 +185,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
@@ -199,6 +204,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
.astype(np.int64)
|
||||
)
|
||||
|
||||
if self.use_karras_sigmas:
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
# when num_inference_steps == num_train_timesteps, we can end up with
|
||||
# duplicates in timesteps.
|
||||
_, unique_indices = np.unique(timesteps, return_index=True)
|
||||
@@ -248,6 +260,44 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
log_sigma = np.log(sigma)
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = log_sigmas[low_idx]
|
||||
high = log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = np.clip(w, 0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
@@ -206,7 +206,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if self.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas)
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
@@ -241,14 +241,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
return t
|
||||
|
||||
# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, self.num_inference_steps)
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
|
||||
@@ -209,6 +209,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
assert abs(result_mean.item() - 0.2251) < 1e-3
|
||||
|
||||
def test_full_loop_with_karras_and_v_prediction(self):
|
||||
sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.2096) < 1e-3
|
||||
|
||||
def test_switch(self):
|
||||
# make sure that iterating over schedulers with same config names gives same results
|
||||
# for defaults
|
||||
|
||||
Reference in New Issue
Block a user