From cb0bf0bd0b8d4ab41855dc687392a7a80ccd8af7 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 7 Oct 2022 07:02:32 -0700 Subject: [PATCH] fix(DDIM scheduler): use correct dtype for noise (#742) Otherwise, it crashes when eta > 0 with float16. --- src/diffusers/schedulers/scheduling_ddim.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 1aebb27a35..8e2f5d9298 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -283,8 +283,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 device = model_output.device if torch.is_tensor(model_output) else "cpu" - noise = torch.randn(model_output.shape, generator=generator).to(device) + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise prev_sample = prev_sample + variance