diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d8f75f4bdd..3e20d70663 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -82,6 +82,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) + self.variance_type = variance_type + def set_timesteps(self, num_inference_steps): num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) self.num_inference_steps = num_inference_steps @@ -90,7 +92,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): )[::-1].copy() self.set_format(tensor_format=self.tensor_format) - def _get_variance(self, t, variance_type=None): + def _get_variance(self, t, predicted_variance=None, variance_type=None): alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one @@ -113,6 +115,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): elif variance_type == "fixed_large_log": # Glide max_log variance = self.log(self.betas[t]) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = variance + max_log = self.betas[t] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log return variance @@ -125,6 +134,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): generator=None, ): t = timestep + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + # 1. compute alphas, betas alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one @@ -155,7 +170,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): variance = 0 if t > 0: noise = self.randn_like(model_output, generator=generator) - variance = (self._get_variance(t) ** 0.5) * noise + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise pred_prev_sample = pred_prev_sample + variance