mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Allow setting num_cycles for cosine_with_restarts lr scheduler (#3606)
Expose num_cycles kwarg of get_schedule() through args.lr_num_cycles.
This commit is contained in:
@@ -285,6 +285,12 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_num_cycles",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
@@ -739,6 +745,7 @@ def main():
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
num_cycles=args.lr_num_cycles * args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
|
||||
Reference in New Issue
Block a user