1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Allow DDPMPipeline half precision (#9222)

Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
Seongbin Lim
2024-09-24 08:28:14 +09:00
committed by GitHub
parent 65f9439b56
commit 3e69e241f7

View File

@@ -101,10 +101,10 @@ class DDPMPipeline(DiffusionPipeline):
if self.device.type == "mps":
# randn does not work reproducibly on mps
image = randn_tensor(image_shape, generator=generator)
image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)
image = image.to(self.device)
else:
image = randn_tensor(image_shape, generator=generator, device=self.device)
image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
# set step values
self.scheduler.set_timesteps(num_inference_steps)