diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py
index 79144e9c02..846dd3eda4 100644
--- a/examples/train_unconditional.py
+++ b/examples/train_unconditional.py
@@ -9,14 +9,14 @@ from accelerate import Accelerator
 from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
-from diffusers.modeling_utils import unwrap_model
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from diffusers.utils import logging
 from torchvision.transforms import (
     CenterCrop,
     Compose,
     InterpolationMode,
-    Lambda,
+    Normalize,
     RandomHorizontalFlip,
     Resize,
     ToTensor,
@@ -48,7 +48,7 @@ def main(args):
             CenterCrop(args.resolution),
             RandomHorizontalFlip(),
             ToTensor(),
-            Lambda(lambda x: x * 2 - 1),
+            Normalize([0.5], [0.5]),
         ]
     )
     dataset = load_dataset(args.dataset, split="train")
@@ -71,6 +71,8 @@
         model, optimizer, train_dataloader, lr_scheduler
     )
 
+    ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
+
     if args.push_to_hub:
         repo = init_git_repo(args, at_init=True)
 
@@ -87,6 +89,7 @@
     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f"  Total optimization steps = {max_steps}")
 
+    global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
         with tqdm(total=len(train_dataloader), unit="ba") as pbar:
@@ -117,19 +120,22 @@
                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
+                ema_model.step(model, global_step)
                 optimizer.zero_grad()
                 pbar.update(1)
-                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
+                pbar.set_postfix(
+                    loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
+                )
+                global_step += 1
 
-        optimizer.step()
-        if is_distributed:
-            torch.distributed.barrier()
+        accelerator.wait_for_everyone()
 
         # Generate a sample image for visual inspection
-        if args.local_rank in [-1, 0]:
-            model.eval()
+        if accelerator.is_main_process:
             with torch.no_grad():
-                pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
+                pipeline = DDPM(
+                    unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler
+                )
                 generator = torch.manual_seed(0)
 
                 # run pipeline in inference (sample random noise and denoise)
@@ -151,8 +157,7 @@
                 push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
             else:
                 pipeline.save_pretrained(args.output_dir)
-        if is_distributed:
-            torch.distributed.barrier()
+        accelerator.wait_for_everyone()
 
 
 if __name__ == "__main__":
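Side note on the transform swap above: ToTensor() scales pixels to [0, 1], and Normalize([0.5], [0.5]) computes (x - 0.5) / 0.5, which is exactly the x * 2 - 1 mapping of the removed Lambda; a single-element mean/std broadcasts across all channels. A minimal sketch to check the equivalence (the random tensor stands in for a ToTensor() output):

import torch
from torchvision.transforms import Normalize

x = torch.rand(3, 8, 8)  # fake image in [0, 1], as produced by ToTensor()
old = x * 2 - 1  # the removed Lambda transform
new = Normalize([0.5], [0.5])(x)  # (x - 0.5) / 0.5, broadcast over channels
assert torch.allclose(old, new)  # both map [0, 1] to [-1, 1]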
+ """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + +def conv_transpose_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.ConvTranspose1d(*args, **kwargs) + elif dims == 2: + return nn.ConvTranspose2d(*args, **kwargs) + elif dims == 3: + return nn.ConvTranspose3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +def nonlinearity(x, swish=1.0): + # swish + if swish == 1.0: + return F.silu(x) + else: + return x * F.sigmoid(x * float(swish)) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.use_conv_transpose = use_conv_transpose + + if use_conv_transpose: + self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1) + elif use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") + else: + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + if self.use_conv: + x = self.conv(x) + + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.padding = padding + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv and self.padding == 0 and self.dims == 2: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + return self.down(x) + + +class UNetUpsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + +class GlideUpsample(nn.Module): + """ + An upsampling layer with an optional convolution. 
+# class ResnetBlock(nn.Module):
+#     def __init__(
+#         self,
+#         *,
+#         in_channels,
+#         out_channels=None,
+#         conv_shortcut=False,
+#         dropout,
+#         temb_channels=512,
+#         use_scale_shift_norm=False,
+#     ):
+#         super().__init__()
+#         self.in_channels = in_channels
+#         out_channels = in_channels if out_channels is None else out_channels
+#         self.out_channels = out_channels
+#         self.use_conv_shortcut = conv_shortcut
+#         self.use_scale_shift_norm = use_scale_shift_norm
+
+#         self.norm1 = Normalize(in_channels)
+#         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+#         temb_out_channels = 2 * out_channels if use_scale_shift_norm else out_channels
+#         self.temb_proj = torch.nn.Linear(temb_channels, temb_out_channels)
+
+#         self.norm2 = Normalize(out_channels)
+#         self.dropout = torch.nn.Dropout(dropout)
+#         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+#         if self.in_channels != self.out_channels:
+#             if self.use_conv_shortcut:
+#                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+#             else:
+#                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+#     def forward(self, x, temb):
+#         h = x
+#         h = self.norm1(h)
+#         h = nonlinearity(h)
+#         h = self.conv1(h)
+
+#         # TODO: check if this broadcasting works correctly for 1D and 3D
+#         temb = self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+#         if self.use_scale_shift_norm:
+#             scale, shift = torch.chunk(temb, 2, dim=1)
+#             h = self.norm2(h) * (1 + scale) + shift
+#             h = self.dropout(nonlinearity(h))
+#             h = self.conv2(h)
+#         else:
+#             h = h + temb
+#             h = self.norm2(h)
+#             h = nonlinearity(h)
+#             h = self.dropout(h)
+#             h = self.conv2(h)
+
+#         if self.in_channels != self.out_channels:
+#             if self.use_conv_shortcut:
+#                 x = self.conv_shortcut(x)
+#             else:
+#                 x = self.nin_shortcut(x)
+
+#         return x + h
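Shape sanity check for the new up/down blocks: nearest-neighbor interpolation (or the 4x4 transposed conv) doubles the spatial size, while the strided 3x3 conv or 2x2 average pool halves it. A sketch, assuming the classes are importable from diffusers.models.resnet as laid out above:

import torch
from diffusers.models.resnet import Downsample, Upsample

x = torch.randn(1, 64, 32, 32)
assert Upsample(64, use_conv=True, out_channels=128)(x).shape == (1, 128, 64, 64)
assert Upsample(64, use_conv=False, use_conv_transpose=True)(x).shape == (1, 64, 64, 64)
assert Downsample(64, use_conv=True, out_channels=128)(x).shape == (1, 128, 16, 16)
assert Downsample(64, use_conv=False)(x).shape == (1, 64, 16, 16)  # average pooling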
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
new file mode 100644
index 0000000000..99fecaa07f
--- /dev/null
+++ b/src/diffusers/training_utils.py
@@ -0,0 +1,90 @@
+import copy
+
+import torch
+
+
+class EMAModel:
+    """
+    Exponential Moving Average of model weights
+    """
+
+    def __init__(
+        self,
+        model,
+        update_after_step=0,
+        inv_gamma=1.0,
+        power=2 / 3,
+        min_value=0.0,
+        max_value=0.9999,
+        device=None,
+    ):
+        """
+        @crowsonkb's notes on EMA Warmup:
+            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
+            good values for models you plan to train for a million or more steps (reaches decay
+            factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
+            you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
+            215.4k steps).
+        Args:
+            update_after_step (int): Start averaging only after this optimization step. Default: 0.
+            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+            power (float): Exponential factor of EMA warmup. Default: 2/3.
+            min_value (float): The minimum EMA decay rate. Default: 0.
+            max_value (float): The maximum EMA decay rate. Default: 0.9999.
+        """
+
+        self.averaged_model = copy.deepcopy(model)
+        self.averaged_model.requires_grad_(False)
+
+        self.update_after_step = update_after_step
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+        self.max_value = max_value
+
+        if device is not None:
+            self.averaged_model = self.averaged_model.to(device=device)
+
+        self.decay = 0.0
+
+    def get_decay(self, optimization_step):
+        """
+        Compute the decay factor for the exponential moving average.
+        """
+        step = max(0, optimization_step - self.update_after_step - 1)
+        value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+        if step <= 0:
+            return 0.0
+
+        return max(self.min_value, min(value, self.max_value))
+
+    @torch.no_grad()
+    def step(self, new_model, optimization_step):
+        ema_state_dict = {}
+        ema_params = self.averaged_model.state_dict()
+
+        self.decay = self.get_decay(optimization_step)
+
+        for key, param in new_model.named_parameters():
+            if isinstance(param, dict):
+                continue
+            try:
+                ema_param = ema_params[key]
+            except KeyError:
+                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
+                ema_params[key] = ema_param
+
+            if not param.requires_grad:
+                ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
+                ema_param = ema_params[key]
+            else:
+                ema_param.mul_(self.decay)
+                ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
+
+            ema_state_dict[key] = ema_param
+
+        for key, param in new_model.named_buffers():
+            ema_state_dict[key] = param
+
+        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
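For intuition on the EMA warmup numbers quoted in the docstring: get_decay computes 1 - (1 + step / inv_gamma) ** -power and clamps it to [min_value, max_value]. The following standalone sketch (ema_decay is a hypothetical reimplementation, not part of the diff; step here is the effective step after the update_after_step offset) reproduces those figures:

def ema_decay(step, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999):
    # hypothetical standalone copy of EMAModel.get_decay's schedule
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_value, min(value, max_value))

print(ema_decay(31_600))                # ~0.999  with the power=2/3 default
print(ema_decay(1_000_000))             # ~0.9999 (hits the max_value clamp)
print(ema_decay(10_000, power=3 / 4))   # ~0.999, the setting train_unconditional.py passes
print(ema_decay(215_400, power=3 / 4))  # ~0.9999

With inv_gamma=1.0 and power=3 / 4, as passed in train_unconditional.py, the averaged model tracks the raw weights closely early in training and only hardens into a long-horizon average later.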