mirror of https://github.com/huggingface/diffusers.git

ddpm scheduler variance fixes

Author:    William Berman
Date:      2023-04-06 21:34:36 -07:00
Committer: Will Berman
Parent:    2cbdc586de
Commit:    1dc856e508


@@ -214,16 +214,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # and sample from it to get previous sample
         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
+        variance = torch.clamp(variance, min=1e-20)
 
         if variance_type is None:
             variance_type = self.config.variance_type
 
         # hacks - were probably added for training stability
         if variance_type == "fixed_small":
-            variance = torch.clamp(variance, min=1e-20)
+            variance = variance
         # for rl-diffuser https://arxiv.org/abs/2205.09991
         elif variance_type == "fixed_small_log":
-            variance = torch.log(torch.clamp(variance, min=1e-20))
+            variance = torch.log(variance)
             variance = torch.exp(0.5 * variance)
         elif variance_type == "fixed_large":
             variance = current_beta_t
@@ -234,7 +235,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         elif variance_type == "learned":
             return predicted_variance
         elif variance_type == "learned_range":
             min_log = torch.log(variance)
-            max_log = torch.log(self.betas[t])
+            max_log = torch.log(current_beta_t)
             frac = (predicted_variance + 1) / 2
             variance = frac * max_log + (1 - frac) * min_log
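
For context, the three changes read as one rule: clamp the posterior variance once, right after it is computed, so every branch that later takes a log operates on a strictly positive value, and make the learned_range branch interpolate toward log(current_beta_t), which is recomputed from the alphas_cumprod ratio and so stays consistent when timesteps are respaced, rather than log(self.betas[t]). The sketch below replays the patched branch logic as a free-standing function. It is an illustration only, not the scheduler's API: the name get_variance_sketch is made up, tensor inputs are assumed, and the variance_type == "learned" condition is inferred from the `return predicted_variance` context line in the diff.

import torch

def get_variance_sketch(alpha_prod_t, alpha_prod_t_prev, current_beta_t,
                        variance_type="fixed_small", predicted_variance=None):
    # posterior variance, eq. (7) of the DDPM paper (arXiv:2006.11239)
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
    # single up-front clamp (the first added line in the diff): the log()
    # calls below must never see an exact zero
    variance = torch.clamp(variance, min=1e-20)

    if variance_type == "fixed_small":
        pass  # the already-clamped variance is used as-is
    elif variance_type == "fixed_small_log":
        # exp(0.5 * log(v)) == sqrt(v): this branch yields a standard deviation
        variance = torch.exp(0.5 * torch.log(variance))
    elif variance_type == "fixed_large":
        variance = current_beta_t
    elif variance_type == "learned":
        # branch condition inferred from the diff's context line
        return predicted_variance
    elif variance_type == "learned_range":
        # interpolate in log space between the small and large fixed variances,
        # using current_beta_t rather than indexing self.betas directly
        min_log = torch.log(variance)
        max_log = torch.log(current_beta_t)
        frac = (predicted_variance + 1) / 2
        variance = frac * max_log + (1 - frac) * min_log
    return variance

As a quick check of what the clamp buys: at the first denoising step alpha_prod_t_prev is 1.0, so the unclamped posterior variance is exactly zero and torch.log(variance) would be -inf; the clamp turns it into log(1e-20) instead.

v = get_variance_sketch(torch.tensor(0.9), torch.tensor(1.0),
                        torch.tensor(0.1), variance_type="fixed_small_log")
# v ~= 1e-10, i.e. sqrt(1e-20); without the clamp, torch.log(0.) == -inf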