From 7e11392dfdae9fbb0d9b55742ac07bb0b3075c22 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 19 Jul 2022 23:47:04 +0000 Subject: [PATCH] fix ddpm scheduler --- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 10 ++-------- src/diffusers/schedulers/scheduling_ddpm.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 058f1c53f6..e72b05cf89 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -51,13 +51,7 @@ class DDPMPipeline(DiffusionPipeline): # 2. predict previous mean of image x_t-1 pred_prev_image = self.scheduler.step(model_output, t, image)["prev_sample"] - # 3. optionally sample variance - variance = 0 - if t > 0: - noise = torch.randn(image.shape, generator=generator).to(image.device) - variance = self.scheduler.get_variance(t).sqrt() * noise - - # 4. set current image to prev_image: x_t -> x_t-1 - image = pred_prev_image + variance + # 3. set current image to prev_image: x_t -> x_t-1 + image = pred_prev_image return {"sample": image} diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d5e3264038..25eae068a9 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -101,7 +101,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): )[::-1].copy() self.set_format(tensor_format=self.tensor_format) - def get_variance(self, t, variance_type=None): + def _get_variance(self, t, variance_type=None): alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one @@ -133,6 +133,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): timestep: int, sample: Union[torch.FloatTensor, np.ndarray], predict_epsilon=True, + generator=None, ): t = timestep # 1. compute alphas, betas @@ -161,6 +162,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + # 6. Add noise + variance = 0 + if t > 0: + noise = torch.randn(model_output.shape, generator=generator).to(model_output.device) + variance = self._get_variance(t).sqrt() * noise + + pred_prev_sample = pred_prev_sample + variance + return {"prev_sample": pred_prev_sample} def add_noise(self, original_samples, noise, timesteps):