From ad310af0d65d5a008401ebde806bed413156cf82 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=2E=20Tolga=20Cang=C3=B6z?= <46008593+standardAI@users.noreply.github.com>
Date: Mon, 26 Feb 2024 23:39:57 +0300
Subject: [PATCH] Fix EMA in train_text_to_image_sdxl.py (#7048)

* Fix typos
---
 examples/text_to_image/train_text_to_image_sdxl.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 2d77e9c8bf..78021b5afe 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -951,6 +951,9 @@ def main(args):
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
+    if args.use_ema:
+        ema_unet.to(accelerator.device)
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
@@ -1126,6 +1129,8 @@ def main(args):
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_unet.step(unet.parameters())
                 progress_bar.update(1)
                 global_step += 1
                 accelerator.log({"train_loss": train_loss}, step=global_step)
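
Note, not part of the patch above: a minimal standalone sketch of the EMA
pattern the two hunks restore, using EMAModel from diffusers.training_utils
(the helper the training script instantiates as ema_unet). The toy model,
decay value, and step count below are illustrative assumptions, not values
taken from train_text_to_image_sdxl.py.

    import torch
    from diffusers.training_utils import EMAModel

    # Toy stand-in for the UNet; the real script trains a UNet2DConditionModel.
    unet = torch.nn.Linear(8, 8)
    ema_unet = EMAModel(unet.parameters(), decay=0.9999)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    unet.to(device)
    # First hunk: move the EMA shadow weights to the training device, so
    # ema_unet.step() does not mix tensors living on different devices.
    ema_unet.to(device)

    optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)
    for _ in range(10):  # illustrative number of optimizer steps
        loss = unet(torch.randn(4, 8, device=device)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Second hunk: advance the EMA shadow once per real optimizer step,
        # i.e. at the point where accelerator.sync_gradients is True.
        ema_unet.step(unet.parameters())

    # At checkpoint/inference time the script copies the EMA weights back
    # into the model before saving.
    ema_unet.copy_to(unet.parameters())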