1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix ddpm scheduler

This commit is contained in:
Patrick von Platen
2022-07-19 23:47:04 +00:00
parent 1f49a343b5
commit 7e11392dfd
2 changed files with 12 additions and 9 deletions

View File

@@ -51,13 +51,7 @@ class DDPMPipeline(DiffusionPipeline):
# 2. predict previous mean of image x_t-1
pred_prev_image = self.scheduler.step(model_output, t, image)["prev_sample"]
# 3. optionally sample variance
variance = 0
if t > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.scheduler.get_variance(t).sqrt() * noise
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance
# 3. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image
return {"sample": image}

View File

@@ -101,7 +101,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
)[::-1].copy()
self.set_format(tensor_format=self.tensor_format)
def get_variance(self, t, variance_type=None):
def _get_variance(self, t, variance_type=None):
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
@@ -133,6 +133,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
predict_epsilon=True,
generator=None,
):
t = timestep
# 1. compute alphas, betas
@@ -161,6 +162,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
variance = 0
if t > 0:
noise = torch.randn(model_output.shape, generator=generator).to(model_output.device)
variance = self._get_variance(t).sqrt() * noise
pred_prev_sample = pred_prev_sample + variance
return {"prev_sample": pred_prev_sample}
def add_noise(self, original_samples, noise, timesteps):