mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Allow DDPM scheduler to use model's predicated variance (#132)
* Extented the ability of ddpm scheduler to utilize model that also predict the variance. * Update src/diffusers/schedulers/scheduling_ddpm.py Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
This commit is contained in:
@@ -82,6 +82,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
self.variance_type = variance_type
|
||||
|
||||
def set_timesteps(self, num_inference_steps):
|
||||
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
|
||||
self.num_inference_steps = num_inference_steps
|
||||
@@ -90,7 +92,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
)[::-1].copy()
|
||||
self.set_format(tensor_format=self.tensor_format)
|
||||
|
||||
def _get_variance(self, t, variance_type=None):
|
||||
def _get_variance(self, t, predicted_variance=None, variance_type=None):
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
|
||||
|
||||
@@ -113,6 +115,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif variance_type == "fixed_large_log":
|
||||
# Glide max_log
|
||||
variance = self.log(self.betas[t])
|
||||
elif variance_type == "learned":
|
||||
return predicted_variance
|
||||
elif variance_type == "learned_range":
|
||||
min_log = variance
|
||||
max_log = self.betas[t]
|
||||
frac = (predicted_variance + 1) / 2
|
||||
variance = frac * max_log + (1 - frac) * min_log
|
||||
|
||||
return variance
|
||||
|
||||
@@ -125,6 +134,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
generator=None,
|
||||
):
|
||||
t = timestep
|
||||
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
|
||||
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
|
||||
else:
|
||||
predicted_variance = None
|
||||
|
||||
# 1. compute alphas, betas
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
|
||||
@@ -155,7 +170,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
variance = 0
|
||||
if t > 0:
|
||||
noise = self.randn_like(model_output, generator=generator)
|
||||
variance = (self._get_variance(t) ** 0.5) * noise
|
||||
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
|
||||
|
||||
pred_prev_sample = pred_prev_sample + variance
|
||||
|
||||
|
||||
Reference in New Issue
Block a user