diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index a341db3aa3..88adbb9955 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -1252,6 +1252,10 @@ def main(args): del pipeline torch.cuda.empty_cache() + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unwrap_model(unet)