diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 79c448f05e..5edb245722 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -1,4 +1,5 @@
 import argparse
+import math
 import os
 
 import torch
@@ -29,6 +30,7 @@ logger = get_logger(__name__)
 def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
@@ -105,6 +107,8 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )
 
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+
     ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
 
     if args.push_to_hub:
@@ -117,7 +121,7 @@ def main(args):
     global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
-        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
@@ -146,13 +150,16 @@ def main(args):
                 ema_model.step(model)
             optimizer.zero_grad()
 
-            progress_bar.update(1)
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
             if args.use_ema:
                 logs["ema_decay"] = ema_model.decay
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
-            global_step += 1
         progress_bar.close()
 
         accelerator.wait_for_everyone()
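For context, here is a minimal, self-contained sketch of the Accelerate gradient-accumulation pattern this diff relies on. The diff itself only shows the `gradient_accumulation_steps` constructor argument and the `accelerator.sync_gradients` check; the sketch assumes the surrounding training loop wraps its body in `with accelerator.accumulate(model):`, which is Accelerate's documented companion to that flag. All model, optimizer, and dataloader names below are illustrative, not taken from `train_unconditional.py`.

import math

import torch
import torch.nn.functional as F
from accelerate import Accelerator

# Only sync and apply gradients every `grad_accum_steps` micro-batches.
grad_accum_steps = 4
accelerator = Accelerator(gradient_accumulation_steps=grad_accum_steps)

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(64, 10), torch.randn(64, 10)),
    batch_size=8,
)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# One progress-bar tick / global step per *optimizer* step, not per micro-batch;
# this is why the diff sizes the bar with ceil(len(dataloader) / accumulation steps).
num_update_steps_per_epoch = math.ceil(len(dataloader) / grad_accum_steps)

global_step = 0
for inputs, targets in dataloader:
    # Inside accumulate(), backward() accumulates scaled gradients and the
    # prepared optimizer actually steps only every `grad_accum_steps` iterations.
    with accelerator.accumulate(model):
        loss = F.mse_loss(model(inputs), targets)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    # True only on iterations where an optimizer step really happened, i.e. the
    # same condition the diff uses to gate global_step and the progress bar.
    if accelerator.sync_gradients:
        global_step += 1

# On a single process: 8 micro-batches / 4 accumulation steps = 2 updates.
assert global_step == num_update_steps_per_epoch

Note that `optimizer.step()` is issued every iteration, but under `accumulate()` Accelerate's wrapped optimizer skips it while gradients are still accumulating, which is what makes the `sync_gradients` gate the right place to advance `global_step` instead of the unconditional increment the diff removes.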