From 93df5bb67016a176cab4b58405e4daf5bd1828d9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 16 Oct 2023 14:11:35 +0530 Subject: [PATCH] [Examples] fix unconditioning generation training example for mixed-precision training (#5407) * fix: unconditional generation example * fix: float in loss. * apply styling. --- .../train_unconditional.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index a3baa3b85b..12b63439fa 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -413,6 +413,14 @@ def main(args): model_config=model.config, ) + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers @@ -559,11 +567,9 @@ def main(args): progress_bar.update(1) continue - clean_images = batch["input"] + clean_images = batch["input"].to(weight_dtype) # Sample noise that we'll add to the images - noise = torch.randn( - clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16) - ).to(clean_images.device) + noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device) bsz = clean_images.shape[0] # Sample a random timestep for each image timesteps = torch.randint( @@ -579,15 +585,14 @@ def main(args): model_output = model(noisy_images, timesteps).sample if args.prediction_type == "epsilon": - loss = F.mse_loss(model_output, noise) # this could have different weights! + loss = F.mse_loss(model_output.float(), noise.float()) # this could have different weights! elif args.prediction_type == "sample": alpha_t = _extract_into_tensor( noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1) ) snr_weights = alpha_t / (1 - alpha_t) - loss = snr_weights * F.mse_loss( - model_output, clean_images, reduction="none" - ) # use SNR weighting from distillation paper + # use SNR weighting from distillation paper + loss = snr_weights * F.mse_loss(model_output.float(), clean_images.float(), reduction="none") loss = loss.mean() else: raise ValueError(f"Unsupported prediction type: {args.prediction_type}")