1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add missing restore() EMA call in train SDXL script (#7599)

* Restore unet params back to normal from EMA when validation call is finished

* empty commit

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Christopher Beckham
2024-04-09 08:37:55 -04:00
committed by GitHub
parent 7e808e768a
commit 8e46d97cd8

View File

@@ -1252,6 +1252,10 @@ def main(args):
del pipeline
torch.cuda.empty_cache()
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unwrap_model(unet)