From dc7c49e4e419ef0888647873b0fb2e233fea6dc2 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 27 Jun 2022 15:50:54 +0200 Subject: [PATCH 1/5] add tests for upsample blocks --- src/diffusers/models/resnet.py | 14 ++++++---- tests/test_layers_utils.py | 51 ++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 04e3735d60..2abb5ce6e1 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -1,4 +1,3 @@ - import torch import torch.nn as nn import torch.nn.functional as F @@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs): return nn.Conv3d(*args, **kwargs) raise ValueError(f"unsupported dimensions: {dims}") + def conv_transpose_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. @@ -73,7 +73,7 @@ class Upsample(nn.Module): self.use_conv_transpose = use_conv_transpose if use_conv_transpose: - self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1) + self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1) elif use_conv: self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) @@ -81,15 +81,15 @@ class Upsample(nn.Module): assert x.shape[1] == self.channels if self.use_conv_transpose: return self.conv(x) - + if self.dims == 3: x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") else: x = F.interpolate(x, scale_factor=2.0, mode="nearest") - + if self.use_conv: x = self.conv(x) - + return x @@ -138,6 +138,7 @@ class UNetUpsample(nn.Module): x = self.conv(x) return x + class GlideUpsample(nn.Module): """ An upsampling layer with an optional convolution. @@ -199,13 +200,14 @@ class LDMUpsample(nn.Module): class GradTTSUpsample(torch.nn.Module): def __init__(self, dim): - super(Upsample, self).__init__() + super(GradTTSUpsample, self).__init__() self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) def forward(self, x): return self.conv(x) +# TODO (patil-suraj): needs test class Upsample1d(nn.Module): def __init__(self, dim): super().__init__() diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index 42a4261081..cde7fc6be0 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -22,6 +22,7 @@ import numpy as np import torch from diffusers.models.embeddings import get_timestep_embedding +from diffusers.models.resnet import Upsample from diffusers.testing_utils import floats_tensor, slow, torch_device @@ -113,3 +114,53 @@ class EmbeddingsTests(unittest.TestCase): torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]), 1e-3, ) + + +class UpsampleBlockTests(unittest.TestCase): + def test_upsample_default(self): + torch.manual_seed(0) + sample = torch.randn(1, 32, 32, 32) + upsample = Upsample(channels=32, use_conv=False) + with torch.no_grad(): + upsampled = upsample(sample) + + assert upsampled.shape == (1, 32, 64, 64) + output_slice = upsampled[0, -1, -3:, -3:] + expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254]) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + def test_upsample_with_conv(self): + torch.manual_seed(0) + sample = torch.randn(1, 32, 32, 32) + upsample = Upsample(channels=32, use_conv=True) + with torch.no_grad(): + upsampled = upsample(sample) + + assert upsampled.shape == (1, 32, 64, 64) + output_slice = upsampled[0, -1, -3:, -3:] + expected_slice = torch.tensor([0.7145, 1.3773, 0.3492, 0.8448, 1.0839, -0.3341, 0.5956, 0.1250, -0.4841]) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + def test_upsample_with_conv_out_dim(self): + torch.manual_seed(0) + sample = torch.randn(1, 32, 32, 32) + upsample = Upsample(channels=32, use_conv=True, out_channels=64) + with torch.no_grad(): + upsampled = upsample(sample) + + assert upsampled.shape == (1, 64, 64, 64) + output_slice = upsampled[0, -1, -3:, -3:] + expected_slice = torch.tensor([0.2703, 0.1656, -0.2538, -0.0553, -0.2984, 0.1044, 0.1155, 0.2579, 0.7755]) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + def test_upsample_with_transpose(self): + torch.manual_seed(0) + sample = torch.randn(1, 32, 32, 32) + upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True) + with torch.no_grad(): + upsampled = upsample(sample) + + assert upsampled.shape == (1, 32, 64, 64) + output_slice = upsampled[0, -1, -3:, -3:] + expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046]) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) From 183056f24311f32f69351ad9d5f748dd5627650a Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 27 Jun 2022 16:25:47 +0200 Subject: [PATCH 2/5] consolidate Upsample --- src/diffusers/models/resnet.py | 2 +- src/diffusers/models/unet.py | 17 ++---------- src/diffusers/models/unet_glide.py | 37 +++------------------------ src/diffusers/models/unet_grad_tts.py | 12 ++------- src/diffusers/models/unet_ldm.py | 36 +++----------------------- tests/test_modeling_utils.py | 3 +-- 6 files changed, 14 insertions(+), 93 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 2abb5ce6e1..4e96221bfe 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -64,7 +64,7 @@ class Upsample(nn.Module): upsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None): super().__init__() self.channels = channels self.out_channels = out_channels or channels diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index 1749def9b1..fe8802cc7a 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -31,6 +31,7 @@ from tqdm import tqdm from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding +from .resnet import Upsample def nonlinearity(x): @@ -42,20 +43,6 @@ def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - class Downsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() @@ -259,7 +246,7 @@ class UNetModel(ModelMixin, ConfigMixin): up.block = block up.attn = attn if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) + up.upsample = Upsample(block_in, use_conv=resamp_with_conv) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index c154db9210..9a50b9cb52 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding +from .resnet import Upsample def convert_module_to_f16(l): @@ -125,36 +126,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): return x -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - class Downsample(nn.Module): """ A downsampling layer with an optional convolution. @@ -231,8 +202,8 @@ class ResBlock(TimestepBlock): self.updown = up or down if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) + self.h_upd = Upsample(channels, use_conv=False, dims=dims) + self.x_upd = Upsample(channels, use_conv=False, dims=dims) elif down: self.h_upd = Downsample(channels, False, dims) self.x_upd = Downsample(channels, False, dims) @@ -567,7 +538,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): up=True, ) if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 36bcce53e9..e9666f7456 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -10,6 +10,7 @@ except: from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding +from .resnet import Upsample class Mish(torch.nn.Module): @@ -17,15 +18,6 @@ class Mish(torch.nn.Module): return x * torch.tanh(torch.nn.functional.softplus(x)) -class Upsample(torch.nn.Module): - def __init__(self, dim): - super(Upsample, self).__init__() - self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) - - def forward(self, x): - return self.conv(x) - - class Downsample(torch.nn.Module): def __init__(self, dim): super(Downsample, self).__init__() @@ -166,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), ResnetBlock(dim_in, dim_in, time_emb_dim=dim), Residual(Rezero(LinearAttention(dim_in))), - Upsample(dim_in), + Upsample(dim_in, use_conv_transpose=True), ] ) ) diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index da84391a36..7812e8e4fe 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -17,6 +17,7 @@ except: from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding +from .resnet import Upsample def exists(val): @@ -377,35 +378,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): return x -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - class Downsample(nn.Module): """ A downsampling layer with an optional convolution. @@ -480,8 +452,8 @@ class ResBlock(TimestepBlock): self.updown = up or down if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) + self.h_upd = Upsample(channels, use_conv=False, dims=dims) + self.x_upd = Upsample(channels, use_conv=False, dims=dims) elif down: self.h_upd = Downsample(channels, False, dims) self.x_upd = Downsample(channels, False, dims) @@ -948,7 +920,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): up=True, ) if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 697a377f8c..8af1196a0b 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -21,7 +21,7 @@ import unittest import numpy as np import torch -from diffusers import ( +from diffusers import ( # GradTTSPipeline, BDDMPipeline, DDIMPipeline, DDIMScheduler, @@ -30,7 +30,6 @@ from diffusers import ( GlidePipeline, GlideSuperResUNetModel, GlideTextToImageUNetModel, - GradTTSPipeline, GradTTSScheduler, LatentDiffusionPipeline, NCSNpp, From ee010726ab20ef93a193cdef7a5cdb3478a2df2c Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 27 Jun 2022 16:27:24 +0200 Subject: [PATCH 3/5] cleanup --- src/diffusers/models/resnet.py | 82 -------------------------------- src/diffusers/models/unet_ldm.py | 9 ++-- 2 files changed, 5 insertions(+), 86 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 4e96221bfe..8d87786991 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -125,88 +125,6 @@ class Downsample(nn.Module): return self.down(x) -class UNetUpsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class GlideUpsample(nn.Module): - """ - An upsampling layer with an optional convolution. - - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class LDMUpsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class GradTTSUpsample(torch.nn.Module): - def __init__(self, dim): - super(GradTTSUpsample, self).__init__() - self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) - - def forward(self, x): - return self.conv(x) - - # TODO (patil-suraj): needs test class Upsample1d(nn.Module): def __init__(self, dim): diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 9d17ea3c9b..26aab77570 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -82,7 +82,7 @@ def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) -#class LinearAttention(nn.Module): +# class LinearAttention(nn.Module): # def __init__(self, dim, heads=4, dim_head=32): # super().__init__() # self.heads = heads @@ -102,7 +102,7 @@ def Normalize(in_channels): # return self.to_out(out) # -#class SpatialSelfAttention(nn.Module): +# class SpatialSelfAttention(nn.Module): # def __init__(self, in_channels): # super().__init__() # self.in_channels = in_channels @@ -120,7 +120,7 @@ def Normalize(in_channels): # k = self.k(h_) # v = self.v(h_) # - # compute attention +# compute attention # b, c, h, w = q.shape # q = rearrange(q, "b c h w -> b (h w) c") # k = rearrange(k, "b c h w -> b c (h w)") @@ -129,7 +129,7 @@ def Normalize(in_channels): # w_ = w_ * (int(c) ** (-0.5)) # w_ = torch.nn.functional.softmax(w_, dim=2) # - # attend to values +# attend to values # v = rearrange(v, "b c h w -> b c (h w)") # w_ = rearrange(w_, "b i j -> b j i") # h_ = torch.einsum("bij,bjk->bik", v, w_) @@ -139,6 +139,7 @@ def Normalize(in_channels): # return x + h_ # + class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() From 1cf7933ea234b9aa0ba5b13fbe60740fa855e838 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 27 Jun 2022 17:11:01 +0200 Subject: [PATCH 4/5] Framework-agnostic timestep broadcasting --- examples/train_unconditional.py | 9 ++++--- src/diffusers/schedulers/scheduling_ddpm.py | 12 +++------ src/diffusers/schedulers/scheduling_utils.py | 28 ++++++++++++++++++++ 3 files changed, 38 insertions(+), 11 deletions(-) diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py index 846dd3eda4..fe45f2a5fa 100644 --- a/examples/train_unconditional.py +++ b/examples/train_unconditional.py @@ -7,7 +7,7 @@ import torch.nn.functional as F import PIL.Image from accelerate import Accelerator from datasets import load_dataset -from diffusers import DDPM, DDPMScheduler, UNetModel +from diffusers import DDPMPipeline, DDPMScheduler, UNetModel from diffusers.hub_utils import init_git_repo, push_to_hub from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel @@ -71,7 +71,7 @@ def main(args): model, optimizer, train_dataloader, lr_scheduler ) - ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4) + ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay) if args.push_to_hub: repo = init_git_repo(args, at_init=True) @@ -133,7 +133,7 @@ def main(args): # Generate a sample image for visual inspection if accelerator.is_main_process: with torch.no_grad(): - pipeline = DDPM( + pipeline = DDPMPipeline( unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler ) @@ -172,6 +172,9 @@ if __name__ == "__main__": parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--warmup_steps", type=int, default=500) + parser.add_argument("--ema_inv_gamma", type=float, default=1.0) + parser.add_argument("--ema_power", type=float, default=3/4) + parser.add_argument("--ema_max_decay", type=float, default=0.999) parser.add_argument("--push_to_hub", action="store_true") parser.add_argument("--hub_token", type=str, default=None) parser.add_argument("--hub_model_id", type=str, default=None) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 5dea0b22b3..d908850dfe 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -144,16 +144,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): return pred_prev_sample def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor): - if timesteps.dim() != 1: - raise ValueError("`timesteps` must be a 1D tensor") - - device = original_samples.device - batch_size = original_samples.shape[0] - timesteps = timesteps.reshape(batch_size, 1, 1, 1) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index a6f317852d..4cfbc5e59d 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -14,6 +14,8 @@ import numpy as np import torch +from typing import Union + SCHEDULER_CONFIG_NAME = "scheduler_config.json" @@ -50,3 +52,29 @@ class SchedulerMixin: return torch.log(tensor) raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def match_shape( + self, + values: Union[np.ndarray, torch.Tensor], + broadcast_array: Union[np.ndarray, torch.Tensor] + ): + """ + Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims. + + Args: + timesteps: an array or tensor of values to extract. + broadcast_array: an array with a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + Returns: + a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + + tensor_format = getattr(self, "tensor_format", "pt") + values = values.flatten() + + while len(values.shape) < len(broadcast_array.shape): + values = values[..., None] + if tensor_format == "pt": + values = values.to(broadcast_array.device) + + return values From 07ff0abff4484aad441ceb64c11e60887aac4522 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 27 Jun 2022 17:25:59 +0200 Subject: [PATCH 5/5] Glide and LDM training experiments --- .../experimental/train_glide_text_to_image.py | 201 ++++++++++++++++++ examples/train_latent_text_to_image.py | 76 ++++--- 2 files changed, 246 insertions(+), 31 deletions(-) create mode 100644 examples/experimental/train_glide_text_to_image.py diff --git a/examples/experimental/train_glide_text_to_image.py b/examples/experimental/train_glide_text_to_image.py new file mode 100644 index 0000000000..9b1f28d680 --- /dev/null +++ b/examples/experimental/train_glide_text_to_image.py @@ -0,0 +1,201 @@ +import argparse +import os + +import torch +import torch.nn.functional as F + +import bitsandbytes as bnb +import PIL.Image +from accelerate import Accelerator +from datasets import load_dataset +from diffusers import DDPMScheduler, Glide, GlideUNetModel +from diffusers.hub_utils import init_git_repo, push_to_hub +from diffusers.optimization import get_scheduler +from diffusers.utils import logging +from torchvision.transforms import ( + CenterCrop, + Compose, + InterpolationMode, + Normalize, + RandomHorizontalFlip, + Resize, + ToTensor, +) +from tqdm.auto import tqdm + + +logger = logging.get_logger(__name__) + + +def main(args): + accelerator = Accelerator(mixed_precision=args.mixed_precision) + + pipeline = Glide.from_pretrained("fusing/glide-base") + model = pipeline.text_unet + noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt") + optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr) + + augmentations = Compose( + [ + Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), + CenterCrop(args.resolution), + RandomHorizontalFlip(), + ToTensor(), + Normalize([0.5], [0.5]), + ] + ) + dataset = load_dataset(args.dataset, split="train") + + text_encoder = pipeline.text_encoder.eval() + + def transforms(examples): + images = [augmentations(image.convert("RGB")) for image in examples["image"]] + text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt") + text_inputs = text_inputs.input_ids.to(accelerator.device) + with torch.no_grad(): + text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs).last_hidden_state + return {"images": images, "text_embeddings": text_embeddings} + + dataset.set_transform(transforms) + train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True) + + lr_scheduler = get_scheduler( + "linear", + optimizer=optimizer, + num_warmup_steps=args.warmup_steps, + num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, + ) + + model, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + + if args.push_to_hub: + repo = init_git_repo(args, at_init=True) + + # Train! + is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + world_size = torch.distributed.get_world_size() if is_distributed else 1 + total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size + max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataloader.dataset)}") + logger.info(f" Num Epochs = {args.num_epochs}") + logger.info(f" Instantaneous batch size per device = {args.batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps}") + + for epoch in range(args.num_epochs): + model.train() + with tqdm(total=len(train_dataloader), unit="ba") as pbar: + pbar.set_description(f"Epoch {epoch}") + for step, batch in enumerate(train_dataloader): + clean_images = batch["images"] + batch_size, n_channels, height, width = clean_images.shape + noise_samples = torch.randn(clean_images.shape).to(clean_images.device) + timesteps = torch.randint( + 0, noise_scheduler.timesteps, (batch_size,), device=clean_images.device + ).long() + + # add noise onto the clean images according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps) + + if step % args.gradient_accumulation_steps != 0: + with accelerator.no_sync(model): + model_output = model(noisy_images, timesteps, batch["text_embeddings"]) + model_output, model_var_values = torch.split(model_output, n_channels, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1) + + # predict the noise residual + loss = F.mse_loss(model_output, noise_samples) + + loss = loss / args.gradient_accumulation_steps + + accelerator.backward(loss) + optimizer.step() + else: + model_output = model(noisy_images, timesteps, batch["text_embeddings"]) + model_output, model_var_values = torch.split(model_output, n_channels, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1) + + # predict the noise residual + loss = F.mse_loss(model_output, noise_samples) + loss = loss / args.gradient_accumulation_steps + accelerator.backward(loss) + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + pbar.update(1) + pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"]) + + accelerator.wait_for_everyone() + + # Generate a sample image for visual inspection + if accelerator.is_main_process: + model.eval() + with torch.no_grad(): + pipeline.unet = accelerator.unwrap_model(model) + + generator = torch.manual_seed(0) + # run pipeline in inference (sample random noise and denoise) + image = pipeline("a clip art of a corgi", generator=generator, num_upscale_inference_steps=50) + + # process image to PIL + image_processed = image.squeeze(0) + image_processed = ((image_processed + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy() + image_pil = PIL.Image.fromarray(image_processed) + + # save image + test_dir = os.path.join(args.output_dir, "test_samples") + os.makedirs(test_dir, exist_ok=True) + image_pil.save(f"{test_dir}/{epoch:04d}.png") + + # save the model + if args.push_to_hub: + push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) + else: + pipeline.save_pretrained(args.output_dir) + accelerator.wait_for_everyone() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument("--local_rank", type=int, default=-1) + parser.add_argument("--dataset", type=str, default="fusing/dog_captions") + parser.add_argument("--output_dir", type=str, default="glide-text2image") + parser.add_argument("--overwrite_output_dir", action="store_true") + parser.add_argument("--resolution", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--num_epochs", type=int, default=100) + parser.add_argument("--gradient_accumulation_steps", type=int, default=4) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--warmup_steps", type=int, default=500) + parser.add_argument("--push_to_hub", action="store_true") + parser.add_argument("--hub_token", type=str, default=None) + parser.add_argument("--hub_model_id", type=str, default=None) + parser.add_argument("--hub_private_repo", action="store_true") + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + main(args) diff --git a/examples/train_latent_text_to_image.py b/examples/train_latent_text_to_image.py index fd823fdad9..7cbfa2c49d 100644 --- a/examples/train_latent_text_to_image.py +++ b/examples/train_latent_text_to_image.py @@ -4,19 +4,19 @@ import os import torch import torch.nn.functional as F +import bitsandbytes as bnb import PIL.Image from accelerate import Accelerator from datasets import load_dataset -from diffusers import DDPM, DDPMScheduler, UNetLDMModel +from diffusers import DDPMScheduler, LatentDiffusion, UNetLDMModel from diffusers.hub_utils import init_git_repo, push_to_hub -from diffusers.modeling_utils import unwrap_model from diffusers.optimization import get_scheduler from diffusers.utils import logging from torchvision.transforms import ( CenterCrop, Compose, InterpolationMode, - Lambda, + Normalize, RandomHorizontalFlip, Resize, ToTensor, @@ -30,6 +30,8 @@ logger = logging.get_logger(__name__) def main(args): accelerator = Accelerator(mixed_precision=args.mixed_precision) + pipeline = LatentDiffusion.from_pretrained("fusing/latent-diffusion-text2im-large") + pipeline.unet = None # this model will be trained from scratch now model = UNetLDMModel( attention_resolutions=[4, 2, 1], channel_mult=[1, 2, 4, 4], @@ -37,7 +39,7 @@ def main(args): conv_resample=True, dims=2, dropout=0, - image_size=32, + image_size=8, in_channels=4, model_channels=320, num_heads=8, @@ -51,7 +53,7 @@ def main(args): legacy=False, ) noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt") - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr) augmentations = Compose( [ @@ -59,14 +61,22 @@ def main(args): CenterCrop(args.resolution), RandomHorizontalFlip(), ToTensor(), - Lambda(lambda x: x * 2 - 1), + Normalize([0.5], [0.5]), ] ) dataset = load_dataset(args.dataset, split="train") + text_encoder = pipeline.bert.eval() + vqvae = pipeline.vqvae.eval() + def transforms(examples): images = [augmentations(image.convert("RGB")) for image in examples["image"]] - return {"input": images} + text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt") + with torch.no_grad(): + text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs.input_ids.cpu()).last_hidden_state + images = 1 / 0.18215 * torch.stack(images, dim=0) + latents = accelerator.unwrap_model(vqvae).encode(images.cpu()).mode() + return {"images": images, "text_embeddings": text_embeddings, "latents": latents} dataset.set_transform(transforms) train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True) @@ -78,9 +88,11 @@ def main(args): num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, ) - model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - model, optimizer, train_dataloader, lr_scheduler + model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler ) + text_encoder = text_encoder.cpu() + vqvae = vqvae.cpu() if args.push_to_hub: repo = init_git_repo(args, at_init=True) @@ -98,29 +110,31 @@ def main(args): logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {max_steps}") + global_step = 0 for epoch in range(args.num_epochs): model.train() with tqdm(total=len(train_dataloader), unit="ba") as pbar: pbar.set_description(f"Epoch {epoch}") for step, batch in enumerate(train_dataloader): - clean_images = batch["input"] - noise_samples = torch.randn(clean_images.shape).to(clean_images.device) - bsz = clean_images.shape[0] - timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long() + clean_latents = batch["latents"] + noise_samples = torch.randn(clean_latents.shape).to(clean_latents.device) + bsz = clean_latents.shape[0] + timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_latents.device).long() - # add noise onto the clean images according to the noise magnitude at each timestep + # add noise onto the clean latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps) + noisy_latents = noise_scheduler.training_step(clean_latents, noise_samples, timesteps) if step % args.gradient_accumulation_steps != 0: with accelerator.no_sync(model): - output = model(noisy_images, timesteps) + output = model(noisy_latents, timesteps, context=batch["text_embeddings"]) # predict the noise residual loss = F.mse_loss(output, noise_samples) loss = loss / args.gradient_accumulation_steps accelerator.backward(loss) + optimizer.step() else: - output = model(noisy_images, timesteps) + output = model(noisy_latents, timesteps, context=batch["text_embeddings"]) # predict the noise residual loss = F.mse_loss(output, noise_samples) loss = loss / args.gradient_accumulation_steps @@ -131,24 +145,25 @@ def main(args): optimizer.zero_grad() pbar.update(1) pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"]) + global_step += 1 - optimizer.step() - if is_distributed: - torch.distributed.barrier() + accelerator.wait_for_everyone() # Generate a sample image for visual inspection - if args.local_rank in [-1, 0]: + if accelerator.is_main_process: model.eval() with torch.no_grad(): - pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler) + pipeline.unet = accelerator.unwrap_model(model) generator = torch.manual_seed(0) # run pipeline in inference (sample random noise and denoise) - image = pipeline(generator=generator) + image = pipeline( + ["a clip art of a corgi"], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50 + ) # process image to PIL image_processed = image.cpu().permute(0, 2, 3, 1) - image_processed = (image_processed + 1.0) * 127.5 + image_processed = image_processed * 255.0 image_processed = image_processed.type(torch.uint8).numpy() image_pil = PIL.Image.fromarray(image_processed[0]) @@ -162,20 +177,19 @@ def main(args): push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) else: pipeline.save_pretrained(args.output_dir) - if is_distributed: - torch.distributed.barrier() + accelerator.wait_for_everyone() if __name__ == "__main__": parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument("--local_rank", type=int, default=-1) - parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories") - parser.add_argument("--output_dir", type=str, default="ddpm-model") + parser.add_argument("--dataset", type=str, default="fusing/dog_captions") + parser.add_argument("--output_dir", type=str, default="ldm-text2image") parser.add_argument("--overwrite_output_dir", action="store_true") - parser.add_argument("--resolution", type=int, default=64) - parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--resolution", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--num_epochs", type=int, default=100) - parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--gradient_accumulation_steps", type=int, default=16) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--warmup_steps", type=int, default=500) parser.add_argument("--push_to_hub", action="store_true")