From 8e46d97cd894998fb944b5021cdd4f7d6bfcd39c Mon Sep 17 00:00:00 2001 From: Christopher Beckham Date: Tue, 9 Apr 2024 08:37:55 -0400 Subject: [PATCH] Add missing restore() EMA call in train SDXL script (#7599) * Restore unet params back to normal from EMA when validation call is finished * empty commit --------- Co-authored-by: Sayak Paul --- examples/text_to_image/train_text_to_image_sdxl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index a341db3aa3..88adbb9955 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -1252,6 +1252,10 @@ def main(args): del pipeline torch.cuda.empty_cache() + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unwrap_model(unet)