mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add missing restore() EMA call in train SDXL script (#7599)
* Restore unet params back to normal from EMA when validation call is finished * empty commit --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
committed by
GitHub
parent
7e808e768a
commit
8e46d97cd8
@@ -1252,6 +1252,10 @@ def main(args):
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if args.use_ema:
|
||||
# Switch back to the original UNet parameters.
|
||||
ema_unet.restore(unet.parameters())
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = unwrap_model(unet)
|
||||
|
||||
Reference in New Issue
Block a user