mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[scheduler] fix some scheduler dtype error (#2992)
Co-authored-by: wangguan <dizhipeng.dzp@alibaba-inc.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -201,7 +201,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
|
||||
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
|
||||
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
|
||||
|
||||
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
|
||||
|
||||
@@ -190,7 +190,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
# interpolate timesteps
|
||||
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
|
||||
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
|
||||
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
|
||||
|
||||
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
|
||||
|
||||
Reference in New Issue
Block a user