1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

avoid upcasting by assigning dtype to noise tensor (#3713)

* avoid upcasting by assigning dtype to noise tensor

* make style

* Update train_unconditional.py

* Update train_unconditional.py

* make style

* add unit test for pickle

* revert change

---------

Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
Prathik Rao
2023-07-03 18:49:49 -07:00
committed by GitHub
parent 4e898560ce
commit 1997614aa9
2 changed files with 6 additions and 2 deletions

View File

@@ -568,7 +568,9 @@ def main(args):
clean_images = batch["input"]
# Sample noise that we'll add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
noise = torch.randn(
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
).to(clean_images.device)
bsz = clean_images.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(

View File

@@ -557,7 +557,9 @@ def main(args):
clean_images = batch["input"]
# Sample noise that we'll add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
noise = torch.randn(
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
).to(clean_images.device)
bsz = clean_images.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(