mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix(DDIM scheduler): use correct dtype for noise (#742)
Otherwise, it crashes when eta > 0 with float16.
This commit is contained in:
@@ -283,8 +283,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||
|
||||
if eta > 0:
|
||||
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
|
||||
device = model_output.device if torch.is_tensor(model_output) else "cpu"
|
||||
noise = torch.randn(model_output.shape, generator=generator).to(device)
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
|
||||
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
|
||||
|
||||
prev_sample = prev_sample + variance
|
||||
|
||||
Reference in New Issue
Block a user