From 5321f3e2035675ee3f749fae2298a2bc6a6f012a Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 22 Aug 2022 12:08:07 +0530 Subject: [PATCH] add add_noise method in LMSDiscreteScheduler, PNDMScheduler (#227) add add_noise method in more schedulers --- src/diffusers/schedulers/scheduling_lms_discrete.py | 9 +++++++++ src/diffusers/schedulers/scheduling_pndm.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 55dd3dbec8..0e1ed2049b 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -130,5 +130,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): noisy_samples = (alpha_prod**0.5) * original_samples + ((1 - alpha_prod) ** 0.5) * noise return noisy_samples + def add_noise(self, original_samples, noise, timesteps): + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + def __len__(self): return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 1d4f70d4d2..cd1d2bb2a7 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -250,5 +250,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): return prev_sample + def add_noise(self, original_samples, noise, timesteps): + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + def __len__(self): return self.config.num_train_timesteps