From 194b0a425dfa0bcdb048ab8f37d1668682c1a91b Mon Sep 17 00:00:00 2001 From: Isotr0py <41363108+Isotr0py@users.noreply.github.com> Date: Mon, 22 May 2023 22:43:56 +0800 Subject: [PATCH] Add `use_Karras_sigmas` to DPMSolverSinglestepScheduler (#3476) * add use_karras_sigmas * add karras test * add doc --- .../scheduling_dpmsolver_singlestep.py | 52 +++++++++++++++++++ tests/schedulers/test_scheduler_dpm_single.py | 12 +++++ 2 files changed, 64 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 8ddd30b0a1..7fa8eabb5a 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -117,6 +117,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable this to use up all the function evaluations. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. lambda_min_clipped (`float`, default `-inf`): the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for cosine (squaredcos_cap_v2) noise schedule. @@ -150,6 +154,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, ): @@ -197,6 +202,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): self.model_outputs = [None] * solver_order self.sample = None self.order_list = self.get_order_list(num_train_timesteps) + self.use_karras_sigmas = use_karras_sigmas def get_order_list(self, num_inference_steps: int) -> List[int]: """ @@ -252,6 +258,14 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): .copy() .astype(np.int64) ) + + if self.use_karras_sigmas: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) self.model_outputs = [None] * self.config.solver_order self.sample = None @@ -299,6 +313,44 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): return sample + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + def convert_model_output( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor ) -> torch.FloatTensor: diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index 18a706a1f5..66be3d5d00 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -215,12 +215,24 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest): assert abs(result_mean.item() - 0.2791) < 1e-3 + def test_full_loop_with_karras(self): + sample = self.full_loop(use_karras_sigmas=True) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2248) < 1e-3 + def test_full_loop_with_v_prediction(self): sample = self.full_loop(prediction_type="v_prediction") result_mean = torch.mean(torch.abs(sample)) assert abs(result_mean.item() - 0.1453) < 1e-3 + def test_full_loop_with_karras_and_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.0649) < 1e-3 + def test_fp16_support(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)