1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix(DDIM scheduler): use correct dtype for noise (#742)

Otherwise, it crashes when eta > 0 with float16.
This commit is contained in:
Kevin Turner
2022-10-07 07:02:32 -07:00
committed by GitHub
parent e0fece2b26
commit cb0bf0bd0b

View File

@@ -283,8 +283,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
device = model_output.device if torch.is_tensor(model_output) else "cpu"
noise = torch.randn(model_output.shape, generator=generator).to(device)
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
prev_sample = prev_sample + variance