From 2bdde4dd834f1055bcfe37c62e877f1337ffa8b1 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 7 Oct 2022 16:44:19 +0200 Subject: [PATCH] [schedulers] handle dtype in add_noise (#767) * handle dtype in vae and image2image pipeline * handle dtype in add noise * don't modify vae and pipeline * remove the if --- src/diffusers/schedulers/scheduling_ddim.py | 8 +++----- src/diffusers/schedulers/scheduling_ddpm.py | 8 +++----- src/diffusers/schedulers/scheduling_lms_discrete.py | 10 +++++++--- src/diffusers/schedulers/scheduling_pndm.py | 8 +++----- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 2dc85a93ad..2d24ecac1d 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -300,11 +300,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: - if self.alphas_cumprod.device != original_samples.device: - self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) - - if timesteps.device != original_samples.device: - timesteps = timesteps.to(original_samples.device) + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index e1db9079d1..77ed981377 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -294,11 +294,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: - if self.alphas_cumprod.device 
!= original_samples.device: - self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) - - if timesteps.device != original_samples.device: - timesteps = timesteps.to(original_samples.device) + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 802da468cd..1f6187c727 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -257,9 +257,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): noise: torch.FloatTensor, timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - sigmas = self.sigmas.to(original_samples.device) - schedule_timesteps = self.timesteps.to(original_samples.device) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + self.timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) + + schedule_timesteps = self.timesteps + if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor): deprecate( "timesteps as indices", @@ -273,7 +277,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): else: step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sigma = sigmas[step_indices].flatten() + sigma = self.sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 
f6a6d6153b..b29712e1e7 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -400,11 +400,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.Tensor: - if self.alphas_cumprod.device != original_samples.device: - self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) - - if timesteps.device != original_samples.device: - timesteps = timesteps.to(original_samples.device) + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten()