From ecadcdefe1c626bda12c4a593aa6681e6234796e Mon Sep 17 00:00:00 2001
From: Dudu Moshe <53430514+dudulightricks@users.noreply.github.com>
Date: Fri, 3 Feb 2023 10:42:42 +0200
Subject: [PATCH] [Bug] scheduling_ddpm: fix variance in the case of
 learned_range type. (#2090)

scheduling_ddpm: fix variance in the case of learned_range type.

In the case of the learned_range variance type, the log and the exponent
required by the theory are missing (see "Improved Denoising Diffusion
Probabilistic Models", section 3.1, equation 15:
https://arxiv.org/pdf/2102.09672.pdf).
---
 src/diffusers/schedulers/scheduling_ddpm.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 63b31033c9..9d8aa6fa5b 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -218,8 +218,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         elif variance_type == "learned":
             return predicted_variance
         elif variance_type == "learned_range":
-            min_log = variance
-            max_log = self.betas[t]
+            min_log = torch.log(variance)
+            max_log = torch.log(self.betas[t])
             frac = (predicted_variance + 1) / 2
             variance = frac * max_log + (1 - frac) * min_log
 
@@ -304,6 +304,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             )
             if self.variance_type == "fixed_small_log":
                 variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
+            elif self.variance_type == "learned_range":
+                variance = self._get_variance(t, predicted_variance=predicted_variance)
+                variance = torch.exp(0.5 * variance) * variance_noise
             else:
                 variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
 
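For context, the computation this patch aligns with equation 15 of the paper is
Sigma_theta(x_t, t) = exp(v * log(beta_t) + (1 - v) * log(beta_tilde_t)).
Below is a minimal, self-contained sketch of that interpolation, not the
diffusers API; the names learned_range_std, beta_t, beta_tilde_t, and
predicted_variance are illustrative assumptions, not identifiers taken from
scheduling_ddpm.py.

import torch

def learned_range_std(predicted_variance: torch.Tensor,
                      beta_t: torch.Tensor,
                      beta_tilde_t: torch.Tensor) -> torch.Tensor:
    """Standard deviation for the 'learned_range' variance type (sketch only).

    predicted_variance: raw model output v' in [-1, 1].
    beta_t: forward-process variance at step t (upper end of the range).
    beta_tilde_t: posterior variance at step t (lower end of the range).
    """
    min_log = torch.log(beta_tilde_t)      # log of the lower bound
    max_log = torch.log(beta_t)            # log of the upper bound
    frac = (predicted_variance + 1) / 2    # map v' from [-1, 1] to [0, 1]
    log_variance = frac * max_log + (1 - frac) * min_log
    return torch.exp(0.5 * log_variance)   # std = sqrt(variance)

# Usage: scale Gaussian noise by the learned standard deviation, mirroring
# the new `learned_range` branch added to step() in this patch.
noise = torch.randn(1)
std = learned_range_std(predicted_variance=torch.tensor(0.3),
                        beta_t=torch.tensor(2e-2),
                        beta_tilde_t=torch.tensor(1e-2))
scaled_noise = std * noise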