diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index af3d0ddc22..25d4b87983 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -338,8 +338,8 @@ def main(args): lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps, - num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=(len(train_dataloader) * args.num_epochs), ) model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(