From 1dc856e508ea3722a47915ee2a472f5091d49a40 Mon Sep 17 00:00:00 2001 From: William Berman Date: Thu, 6 Apr 2023 21:34:36 -0700 Subject: [PATCH] ddpm scheduler variance fixes --- src/diffusers/schedulers/scheduling_ddpm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index e047a553a2..481010fcb7 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -214,16 +214,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): # and sample from it to get previous sample # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t + variance = torch.clamp(variance, min=1e-20) if variance_type is None: variance_type = self.config.variance_type # hacks - were probably added for training stability if variance_type == "fixed_small": - variance = torch.clamp(variance, min=1e-20) + variance = variance # for rl-diffuser https://arxiv.org/abs/2205.09991 elif variance_type == "fixed_small_log": - variance = torch.log(torch.clamp(variance, min=1e-20)) + variance = torch.log(variance, min=1e-20) variance = torch.exp(0.5 * variance) elif variance_type == "fixed_large": variance = current_beta_t @@ -234,7 +235,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): return predicted_variance elif variance_type == "learned_range": min_log = torch.log(variance) - max_log = torch.log(self.betas[t]) + max_log = torch.log(current_beta_t) frac = (predicted_variance + 1) / 2 variance = frac * max_log + (1 - frac) * min_log