From fd47d7c0ef46e5a949e53cb4283faf4ca9c33117 Mon Sep 17 00:00:00 2001 From: Jason Kuan Date: Fri, 28 Apr 2023 18:59:56 +0800 Subject: [PATCH] add constant learning rate with custom rule (#3133) * add constant lr with rules * add constant with rules in TYPE_TO_SCHEDULER_FUNCTION * add constant lr rate with rule * hotfix code quality * fix doc style * change name constant_with_rules to piecewise constant --- src/diffusers/optimization.py | 50 +++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py index 657e085062..78d68b7978 100644 --- a/src/diffusers/optimization.py +++ b/src/diffusers/optimization.py @@ -34,6 +34,7 @@ class SchedulerType(Enum): POLYNOMIAL = "polynomial" CONSTANT = "constant" CONSTANT_WITH_WARMUP = "constant_with_warmup" + PIECEWISE_CONSTANT = "piecewise_constant" def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): @@ -77,6 +78,48 @@ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: in return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) +def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + step_rules (`string`): + The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate + if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 + steps and multiple 0.005 for the other steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + rules_dict = {} + rule_list = step_rules.split(",") + for rule_str in rule_list[:-1]: + value_str, steps_str = rule_str.split(":") + steps = int(steps_str) + value = float(value_str) + rules_dict[steps] = value + last_lr_multiple = float(rule_list[-1]) + + def create_rules_function(rules_dict, last_lr_multiple): + def rule_func(steps: int) -> float: + sorted_steps = sorted(rules_dict.keys()) + for i, sorted_step in enumerate(sorted_steps): + if steps < sorted_step: + return rules_dict[sorted_steps[i]] + return last_lr_multiple + + return rule_func + + rules_func = create_rules_function(rules_dict, last_lr_multiple) + + return LambdaLR(optimizer, rules_func, last_epoch=last_epoch) + + def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): """ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after @@ -232,12 +275,14 @@ TYPE_TO_SCHEDULER_FUNCTION = { SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, SchedulerType.CONSTANT: get_constant_schedule, SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, + SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule, } def get_scheduler( name: Union[str, SchedulerType], optimizer: Optimizer, + step_rules: Optional[str] = None, num_warmup_steps: Optional[int] = None, num_training_steps: Optional[int] = None, num_cycles: int = 1, @@ -252,6 +297,8 @@ def get_scheduler( The name of the scheduler to use. optimizer (`torch.optim.Optimizer`): The optimizer that will be used during training. + step_rules (`str`, *optional*): + A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler. num_warmup_steps (`int`, *optional*): The number of warmup steps to do. This is not required by all schedulers (hence the argument being optional), the function will raise an error if it's unset and the scheduler type requires it. @@ -270,6 +317,9 @@ def get_scheduler( if name == SchedulerType.CONSTANT: return schedule_func(optimizer, last_epoch=last_epoch) + if name == SchedulerType.PIECEWISE_CONSTANT: + return schedule_func(optimizer, rules=step_rules, last_epoch=last_epoch) + # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")