From f861cde14f168ee3da391ed423d411968b456b9d Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 17 Jan 2023 10:11:46 +0100 Subject: [PATCH] [train_unconditional] fix LR scheduler init (#2010) fix lr scheduler --- .../unconditional_image_generation/train_unconditional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index af3d0ddc22..25d4b87983 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -338,8 +338,8 @@ def main(args): lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps, - num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=(len(train_dataloader) * args.num_epochs), ) model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(