From 1997614aa9525ef0f49858ac409540fdf2f02e9d Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 3 Jul 2023 18:49:49 -0700 Subject: [PATCH] avoid upcasting by assigning dtype to noise tensor (#3713) * avoid upcasting by assigning dtype to noise tensor * make style * Update train_unconditional.py * Update train_unconditional.py * make style * add unit test for pickle * revert change --------- Co-authored-by: root Co-authored-by: Patrick von Platen Co-authored-by: Prathik Rao --- .../unconditional_image_generation/train_unconditional.py | 4 +++- .../unconditional_image_generation/train_unconditional.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py index a42187fade..12ff40bbd6 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py @@ -568,7 +568,9 @@ def main(args): clean_images = batch["input"] # Sample noise that we'll add to the images - noise = torch.randn(clean_images.shape).to(clean_images.device) + noise = torch.randn( + clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16) + ).to(clean_images.device) bsz = clean_images.shape[0] # Sample a random timestep for each image timesteps = torch.randint( diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index d6e4b17ba8..e10e6d3024 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -557,7 +557,9 @@ def main(args): clean_images = batch["input"] # Sample noise that we'll add to the images - noise = torch.randn(clean_images.shape).to(clean_images.device) + noise = torch.randn( + clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16) + ).to(clean_images.device) bsz = clean_images.shape[0] # Sample a random timestep for each image timesteps = torch.randint(