1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[scheduler] fix some scheduler dtype error (#2992)

Co-authored-by: wangguan <dizhipeng.dzp@alibaba-inc.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
FurryPotato
2023-04-06 21:55:33 +08:00
committed by GitHub
parent 24947317a6
commit e40526431a
2 changed files with 2 additions and 2 deletions

View File

@@ -201,7 +201,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
else:
timesteps = torch.from_numpy(timesteps).to(device)
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])

View File

@@ -190,7 +190,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps = torch.from_numpy(timesteps).to(device)
# interpolate timesteps
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])