From d87cc15977b87160c30abaace3894e802ad9e1e6 Mon Sep 17 00:00:00 2001
From: Emil Bogomolov
Date: Mon, 19 Dec 2022 16:41:37 -0800
Subject: [PATCH] expose polynomial:power and cosine_with_restarts:num_cycles
 params (#1737)

* expose polynomial:power and cosine_with_restarts:num_cycles using
  get_scheduler func, add it to train_dreambooth.py

* fix formatting

* fix style

* Update src/diffusers/optimization.py

Co-authored-by: Pedro Cuenca

Co-authored-by: Pedro Cuenca
---
 examples/dreambooth/train_dreambooth.py |  9 +++++++++
 src/diffusers/optimization.py           | 24 +++++++++++++++++++++---
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index de735141d5..122d346ff5 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -204,6 +204,13 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_num_cycles",
+        type=int,
+        default=1,
+        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+    )
+    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
@@ -588,6 +595,8 @@ def main(args):
         optimizer=optimizer,
         num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_cycles=args.lr_num_cycles,
+        power=args.lr_power,
     )
 
     if args.train_text_encoder:
diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py
index e7b836b4a6..a5eedc6803 100644
--- a/src/diffusers/optimization.py
+++ b/src/diffusers/optimization.py
@@ -121,9 +121,9 @@ def get_cosine_schedule_with_warmup(
             The number of steps for the warmup phase.
         num_training_steps (`int`):
             The total number of training steps.
-        num_cycles (`float`, *optional*, defaults to 0.5):
-            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
-            following a half-cosine).
+        num_periods (`float`, *optional*, defaults to 0.5):
+            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
+            value to 0 following a half-cosine).
         last_epoch (`int`, *optional*, defaults to -1):
             The index of the last epoch when resuming training.
 
@@ -240,6 +240,8 @@ def get_scheduler(
     optimizer: Optimizer,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
+    num_cycles: int = 1,
+    power: float = 1.0,
 ):
     """
     Unified API to get any scheduler from its name.
@@ -255,6 +257,12 @@
         num_training_steps (`int``, *optional*):
             The number of training steps to do. This is not required by all schedulers (hence the argument being
             optional), the function will raise an error if it's unset and the scheduler type requires it.
+        num_cycles (`int`, *optional*):
+            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
+        power (`float`, *optional*, defaults to 1.0):
+            Power factor. See `POLYNOMIAL` scheduler
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
     """
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
@@ -272,4 +280,14 @@
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
 
+    if name == SchedulerType.COSINE_WITH_RESTARTS:
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
+        )
+
+    if name == SchedulerType.POLYNOMIAL:
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
+        )
+
     return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)