mirror of https://github.com/huggingface/diffusers.git
@@ -951,6 +951,9 @@ def main(args):
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
+    if args.use_ema:
+        ema_unet.to(accelerator.device)
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
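The hunk above re-derives the step schedule after accelerator.prepare(), since the prepared dataloader may be sharded across processes and change length. A minimal sketch of the arithmetic, with illustrative values that are not from the source:

import math

num_batches = 1000                 # stand-in for len(train_dataloader) after prepare()
gradient_accumulation_steps = 4    # stand-in for args.gradient_accumulation_steps

# One optimizer update happens every `gradient_accumulation_steps` batches;
# a trailing partial group still counts as one update, hence the ceil.
num_update_steps_per_epoch = math.ceil(num_batches / gradient_accumulation_steps)
print(num_update_steps_per_epoch)  # 250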
@@ -1126,6 +1129,8 @@ def main(args):
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_unet.step(unet.parameters())
                 progress_bar.update(1)
                 global_step += 1
                 accelerator.log({"train_loss": train_loss}, step=global_step)
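Taken together, the two hunks keep an exponential moving average (EMA) of the UNet weights: the EMA copy is moved to the accelerator device once, then updated only on real optimizer steps (when accelerator.sync_gradients is True). Below is a minimal, self-contained sketch of what a call like ema_unet.step(unet.parameters()) does conceptually; the SimpleEMA class and its fixed decay are assumptions for illustration, not diffusers' actual EMAModel, which also warms the decay up over training:

import torch

class SimpleEMA:
    def __init__(self, parameters, decay=0.9999):  # fixed decay is an assumed simplification
        self.decay = decay
        # Keep a detached "shadow" copy of every parameter.
        self.shadow = [p.detach().clone() for p in parameters]

    @torch.no_grad()
    def step(self, parameters):
        # shadow <- decay * shadow + (1 - decay) * current weights
        for s, p in zip(self.shadow, parameters):
            s.mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

# Illustrative usage: update once per optimizer step, mirroring the
# `if accelerator.sync_gradients:` guard in the hunk above.
model = torch.nn.Linear(4, 4)
ema = SimpleEMA(model.parameters())
# ... forward / backward / optimizer.step() ...
ema.step(model.parameters())

Gating the EMA update on sync_gradients matters under gradient accumulation: only every Nth batch actually changes the weights, so updating the shadow on intermediate batches would just re-average identical parameters.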