Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
add constant learning rate with custom rule (#3133)
* add constant lr with rules
* add constant with rules in TYPE_TO_SCHEDULER_FUNCTION
* add constant lr rate with rule
* hotfix code quality
* fix doc style
* change name constant_with_rules to piecewise constant
@@ -34,6 +34,7 @@ class SchedulerType(Enum):
     POLYNOMIAL = "polynomial"
     CONSTANT = "constant"
     CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    PIECEWISE_CONSTANT = "piecewise_constant"
 
 
 def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
@@ -77,6 +78,48 @@ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: in
     return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
 
 
+def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
"""
|
||||
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
|
||||
|
||||
Args:
|
||||
optimizer ([`~torch.optim.Optimizer`]):
|
||||
The optimizer for which to schedule the learning rate.
|
||||
step_rules (`string`):
|
||||
The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
|
||||
if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
|
||||
steps and multiple 0.005 for the other steps.
|
||||
last_epoch (`int`, *optional*, defaults to -1):
|
||||
The index of the last epoch when resuming training.
|
||||
|
||||
Return:
|
||||
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||
"""
|
||||
+
+    rules_dict = {}
+    rule_list = step_rules.split(",")
+    for rule_str in rule_list[:-1]:
+        value_str, steps_str = rule_str.split(":")
+        steps = int(steps_str)
+        value = float(value_str)
+        rules_dict[steps] = value
+    last_lr_multiple = float(rule_list[-1])
+
+    def create_rules_function(rules_dict, last_lr_multiple):
+        def rule_func(steps: int) -> float:
+            sorted_steps = sorted(rules_dict.keys())
+            for i, sorted_step in enumerate(sorted_steps):
+                if steps < sorted_step:
+                    return rules_dict[sorted_steps[i]]
+            return last_lr_multiple
+
+        return rule_func
+
+    rules_func = create_rules_function(rules_dict, last_lr_multiple)
+
+    return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
+
+
 def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
     """
     Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
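For illustration, here is a minimal usage sketch of the new scheduler (not part of the commit). It assumes the function lands in `diffusers.optimization` alongside the other schedulers and uses a toy SGD optimizer with an illustrative base learning rate of 0.1:

import torch
from diffusers.optimization import get_piecewise_constant_schedule  # assumed import path

# Toy model and optimizer; base learning rate 0.1 (illustrative values).
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Multiplier 1 until step 10, 0.1 until step 20, 0.01 until step 30, then 0.005.
scheduler = get_piecewise_constant_schedule(optimizer, step_rules="1:10,0.1:20,0.01:30,0.005")

for step in range(40):
    optimizer.step()
    scheduler.step()

# After 40 steps the final multiplier 0.005 applies, so the lr is 0.1 * 0.005, i.e. about 5e-4.
print(scheduler.get_last_lr())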
@@ -232,12 +275,14 @@ TYPE_TO_SCHEDULER_FUNCTION = {
     SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
     SchedulerType.CONSTANT: get_constant_schedule,
     SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
 }
 
 
 def get_scheduler(
     name: Union[str, SchedulerType],
     optimizer: Optimizer,
+    step_rules: Optional[str] = None,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
     num_cycles: int = 1,
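For context, a brief sketch (not part of the commit) of how a string scheduler name is routed through the enum and the mapping above; the import path is assumed to be `diffusers.optimization`:

from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION  # assumed import path

name = SchedulerType("piecewise_constant")        # -> SchedulerType.PIECEWISE_CONSTANT
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]  # -> the new get_piecewise_constant_schedule
print(schedule_func.__name__)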
@@ -252,6 +297,8 @@ def get_scheduler(
             The name of the scheduler to use.
         optimizer (`torch.optim.Optimizer`):
             The optimizer that will be used during training.
+        step_rules (`str`, *optional*):
+            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
         num_warmup_steps (`int`, *optional*):
             The number of warmup steps to do. This is not required by all schedulers (hence the argument being
             optional), the function will raise an error if it's unset and the scheduler type requires it.
@@ -270,6 +317,9 @@ def get_scheduler(
     if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer, last_epoch=last_epoch)
 
+    if name == SchedulerType.PIECEWISE_CONSTANT:
+        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)
+
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
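Putting the pieces together, a minimal end-to-end sketch (not part of the commit) of selecting the new scheduler through `get_scheduler`; the toy optimizer and the rule string are illustrative:

import torch
from diffusers.optimization import get_scheduler  # assumed import path

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# "piecewise_constant" resolves to SchedulerType.PIECEWISE_CONSTANT, which dispatches to
# get_piecewise_constant_schedule with the given step_rules.
lr_scheduler = get_scheduler(
    "piecewise_constant",
    optimizer=optimizer,
    step_rules="1:1000,0.5:2000,0.1",  # illustrative rule string
)

for _ in range(10):
    optimizer.step()
    lr_scheduler.step()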