
Allow DDPM scheduler to use model's predicted variance (#132)

* Extended the DDPM scheduler so it can utilize models that also predict the variance.

* Update src/diffusers/schedulers/scheduling_ddpm.py

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
Author: Eyal Mazuz
Date: 2022-08-03 13:40:04 +03:00
Committed by: GitHub
Parent: b6cadcef98
Commit: b6447fa87e


@@ -82,6 +82,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)
 
+        self.variance_type = variance_type
+
     def set_timesteps(self, num_inference_steps):
         num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
         self.num_inference_steps = num_inference_steps
@@ -90,7 +92,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         )[::-1].copy()
         self.set_format(tensor_format=self.tensor_format)
 
-    def _get_variance(self, t, variance_type=None):
+    def _get_variance(self, t, predicted_variance=None, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
@@ -113,6 +115,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         elif variance_type == "fixed_large_log":
             # Glide max_log
             variance = self.log(self.betas[t])
+        elif variance_type == "learned":
+            return predicted_variance
+        elif variance_type == "learned_range":
+            min_log = variance
+            max_log = self.betas[t]
+            frac = (predicted_variance + 1) / 2
+            variance = frac * max_log + (1 - frac) * min_log
 
         return variance
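
The "learned_range" branch follows the Improved DDPM idea: the raw model output in [-1, 1] becomes an interpolation weight that blends the lower variance bound with the upper bound derived from beta_t. A minimal sketch of the same arithmetic, with made-up toy values that are not taken from the repo:

```python
import torch

# toy stand-ins for the bounds used in the hunk above: the clipped posterior
# variance (lower bound) and beta_t at the current timestep (upper bound)
min_log = torch.tensor(1e-4)
max_log = torch.tensor(2e-2)

predicted_variance = torch.tensor(0.5)  # raw model output, expected in [-1, 1]

frac = (predicted_variance + 1) / 2     # map [-1, 1] -> [0, 1]
variance = frac * max_log + (1 - frac) * min_log
print(variance)                         # ~0.0150, three quarters of the way toward max_log
```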
@@ -125,6 +134,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         generator=None,
     ):
         t = timestep
+
+        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+        else:
+            predicted_variance = None
+
         # 1. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
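
The channel check means a model trained with learned variance (a UNet whose output has twice as many channels as the sample) needs no extra wiring: the doubled output is split into the noise prediction and the raw variance prediction. A minimal sketch with hypothetical shapes:

```python
import torch

sample = torch.randn(1, 3, 32, 32)        # current noisy sample, C = 3
model_output = torch.randn(1, 6, 32, 32)  # model emits 2*C channels

if model_output.shape[1] == sample.shape[1] * 2:
    # first C channels: noise prediction; remaining C channels: raw variance prediction
    model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
    predicted_variance = None

print(model_output.shape, predicted_variance.shape)  # torch.Size([1, 3, 32, 32]) twice
```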
@@ -155,7 +170,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         variance = 0
         if t > 0:
             noise = self.randn_like(model_output, generator=generator)
-            variance = (self._get_variance(t) ** 0.5) * noise
+            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
         pred_prev_sample = pred_prev_sample + variance
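
Taken together, a scheduler configured with variance_type="learned" or "learned_range" can consume such a model's output directly. A minimal usage sketch, assuming the public DDPMScheduler API; the model output is a random stand-in for a variance-predicting UNet, and step()'s return type is left aside since it differs across library versions:

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(variance_type="learned_range")  # or "learned"
scheduler.set_timesteps(1000)

sample = torch.randn(1, 3, 32, 32)        # noisy sample, C = 3
model_output = torch.randn(1, 6, 32, 32)  # stand-in for unet(sample, t) with learned variance
t = scheduler.timesteps[0]

# step() detects the doubled channel count, splits off predicted_variance,
# and passes it to _get_variance() when adding noise for t > 0
out = scheduler.step(model_output, t, sample)
```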