From e40526431ad9d1fbc36bf52aadb172dcd620dd4e Mon Sep 17 00:00:00 2001 From: FurryPotato <1169028312@qq.com> Date: Thu, 6 Apr 2023 21:55:33 +0800 Subject: [PATCH] [scheduler] fix some scheduler dtype error (#2992) Co-authored-by: wangguan Co-authored-by: Patrick von Platen --- .../schedulers/scheduling_k_dpm_2_ancestral_discrete.py | 2 +- src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index c8b1f2c3be..b8205455d6 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -201,7 +201,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): else: timesteps = torch.from_numpy(timesteps).to(device) - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) + timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 809da798f8..b49cc2e544 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -190,7 +190,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): timesteps = torch.from_numpy(timesteps).to(device) # interpolate timesteps - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) + timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten() self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])