diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py index 92f5d8f462..476ae95293 100644 --- a/examples/train_unconditional.py +++ b/examples/train_unconditional.py @@ -21,7 +21,7 @@ from torchvision.transforms import ( ToTensor, ) from tqdm.auto import tqdm -from transformers import get_linear_schedule_with_warmup +from diffusers.optimization import get_scheduler logger = logging.get_logger(__name__) @@ -60,7 +60,8 @@ def main(args): dataset.set_transform(transforms) train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True) - lr_scheduler = get_linear_schedule_with_warmup( + lr_scheduler = get_scheduler( + "linear", optimizer=optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, @@ -107,11 +108,13 @@ def main(args): output = model(noisy_images, timesteps) # predict the noise residual loss = F.mse_loss(output, noise_samples) + loss = loss / args.gradient_accumulation_steps accelerator.backward(loss) else: output = model(noisy_images, timesteps) # predict the noise residual loss = F.mse_loss(output, noise_samples) + loss = loss / args.gradient_accumulation_steps accelerator.backward(loss) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() diff --git a/src/diffusers/hub_utils.py b/src/diffusers/hub_utils.py index 4315d6307f..c2d1e34f3e 100644 --- a/src/diffusers/hub_utils.py +++ b/src/diffusers/hub_utils.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import os import shutil from pathlib import Path diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py new file mode 100644 index 0000000000..70101aec81 --- /dev/null +++ b/src/diffusers/optimization.py @@ -0,0 +1,276 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch optimization for diffusion models.""" + +import math +from enum import Enum +from typing import Optional, Union + +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from .utils import logging + + +logger = logging.get_logger(__name__) + + +class SchedulerType(Enum): + LINEAR = "linear" + COSINE = "cosine" + COSINE_WITH_RESTARTS = "cosine_with_restarts" + POLYNOMIAL = "polynomial" + CONSTANT = "constant" + CONSTANT_WITH_WARMUP = "constant_with_warmup" + + +def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) + + +def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate + increases linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 + + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) + + +def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases + linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`int`, *optional*, defaults to 1): + The number of hard restarts to use. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + if progress >= 1.0: + return 0.0 + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_polynomial_decay_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 +): + """ + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT + implementation at + https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + + """ + + lr_init = optimizer.defaults["lr"] + if not (lr_init > lr_end): + raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +TYPE_TO_SCHEDULER_FUNCTION = { + SchedulerType.LINEAR: get_linear_schedule_with_warmup, + SchedulerType.COSINE: get_cosine_schedule_with_warmup, + SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, + SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, + SchedulerType.CONSTANT: get_constant_schedule, + SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, +} + + +def get_scheduler( + name: Union[str, SchedulerType], + optimizer: Optimizer, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, +): + """ + Unified API to get any scheduler from its name. + + Args: + name (`str` or `SchedulerType`): + The name of the scheduler to use. + optimizer (`torch.optim.Optimizer`): + The optimizer that will be used during training. + num_warmup_steps (`int`, *optional*): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (`int``, *optional*): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) \ No newline at end of file