From aa190259896877bd13390f444404c68bfbaba4d6 Mon Sep 17 00:00:00 2001 From: Beinsezii <39478211+Beinsezii@users.noreply.github.com> Date: Tue, 2 Apr 2024 20:23:55 -0700 Subject: [PATCH] UniPC Multistep add `rescale_betas_zero_snr` (#7531) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * UniPC Multistep add `rescale_betas_zero_snr` Same patch as DPM and Euler with the patched final alpha cumprod BF16 doesn't seem to break down, I think cause UniPC upcasts during some phases already? We could still force an upcast since it only loses ≈ 0.005 it/s for me but the difference in output is very small. A better endeavor might upcasting in step() and removing all the other upcasts elsewhere? * UniPC ZSNR UT * Re-add `rescale_betas_zsnr` doc oops --- .../schedulers/scheduling_unipc_multistep.py | 51 +++++++++++++++++++ tests/schedulers/test_scheduler_unipc.py | 4 ++ 2 files changed, 55 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 0bf03f057f..74e97a33f1 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -71,6 +71,43 @@ def betas_for_alpha_bar( return torch.tensor(betas, dtype=torch.float32) +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): """ `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. @@ -130,6 +167,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): final_sigmas_type (`str`, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -157,6 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): timestep_spacing: str = "linspace", steps_offset: int = 0, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -171,8 +213,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + # Currently we only support VP-type noise schedule self.alpha_t = torch.sqrt(self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index a591a60105..5eb4d5ceef 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -180,6 +180,10 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest): for prediction_type in ["epsilon", "v_prediction"]: self.check_over_configs(prediction_type=prediction_type) + def test_rescale_betas_zero_snr(self): + for rescale_betas_zero_snr in [True, False]: + self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_solver_order_and_type(self): for solver_type in ["bh1", "bh2"]: for order in [1, 2, 3]: