From 457abdf2cf31956a15df7233187b0b358307c7d1 Mon Sep 17 00:00:00 2001 From: Beinsezii <39478211+Beinsezii@users.noreply.github.com> Date: Tue, 19 Dec 2023 23:39:25 -0800 Subject: [PATCH] EulerAncestral add `rescale_betas_zero_snr` (#6187) * EulerAncestral add `rescale_betas_zero_snr` Uses same infinite sigma fix from EulerDiscrete. Interestingly the ancestral version had the opposite problem: too much contrast instead of too little. * UT for EulerAncestral `rescale_betas_zero_snr` * EulerAncestral upcast samples during step() It helps this scheduler too, particularly when the model is using bf16. While the noise dtype is still the model's it's automatically upcasted for the add so all it affects is determinism. --------- Co-authored-by: Sayak Paul --- .../scheduling_euler_ancestral_discrete.py | 56 +++++++++++++++++++ .../test_scheduler_euler_ancestral.py | 4 ++ 2 files changed, 60 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index e476c32945..ca188378a3 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -92,6 +92,43 @@ def betas_for_alpha_bar( return torch.tensor(betas, dtype=torch.float32) +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Ancestral sampling with Euler method steps. @@ -122,6 +159,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): An offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable Diffusion. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -138,6 +179,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -152,9 +194,17 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) @@ -327,6 +377,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): sigma = self.sigmas[self.step_index] + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": pred_original_sample = sample - sigma * model_output @@ -357,6 +410,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prev_sample = prev_sample + noise * sigma_up + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + # upon completion increase step index by one self._step_index += 1 diff --git a/tests/schedulers/test_scheduler_euler_ancestral.py b/tests/schedulers/test_scheduler_euler_ancestral.py index a0818042fa..9f22ab38dd 100644 --- a/tests/schedulers/test_scheduler_euler_ancestral.py +++ b/tests/schedulers/test_scheduler_euler_ancestral.py @@ -37,6 +37,10 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): for prediction_type in ["epsilon", "v_prediction"]: self.check_over_configs(prediction_type=prediction_type) + def test_rescale_betas_zero_snr(self): + for rescale_betas_zero_snr in [True, False]: + self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config()