mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
ddpm scheduler variance fixes
This commit is contained in:
committed by
Will Berman
parent
2cbdc586de
commit
1dc856e508
@@ -214,16 +214,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# and sample from it to get previous sample
|
||||
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
|
||||
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
|
||||
variance = torch.clamp(variance, min=1e-20)
|
||||
|
||||
if variance_type is None:
|
||||
variance_type = self.config.variance_type
|
||||
|
||||
# hacks - were probably added for training stability
|
||||
if variance_type == "fixed_small":
|
||||
variance = torch.clamp(variance, min=1e-20)
|
||||
variance = variance
|
||||
# for rl-diffuser https://arxiv.org/abs/2205.09991
|
||||
elif variance_type == "fixed_small_log":
|
||||
variance = torch.log(torch.clamp(variance, min=1e-20))
|
||||
variance = torch.log(variance, min=1e-20)
|
||||
variance = torch.exp(0.5 * variance)
|
||||
elif variance_type == "fixed_large":
|
||||
variance = current_beta_t
|
||||
@@ -234,7 +235,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
return predicted_variance
|
||||
elif variance_type == "learned_range":
|
||||
min_log = torch.log(variance)
|
||||
max_log = torch.log(self.betas[t])
|
||||
max_log = torch.log(current_beta_t)
|
||||
frac = (predicted_variance + 1) / 2
|
||||
variance = frac * max_log + (1 - frac) * min_log
|
||||
|
||||
|
||||
Reference in New Issue
Block a user