From ee010726ab20ef93a193cdef7a5cdb3478a2df2c Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 27 Jun 2022 16:27:24 +0200 Subject: [PATCH] cleanup --- src/diffusers/models/resnet.py | 82 -------------------------------- src/diffusers/models/unet_ldm.py | 9 ++-- 2 files changed, 5 insertions(+), 86 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 4e96221bfe..8d87786991 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -125,88 +125,6 @@ class Downsample(nn.Module): return self.down(x) -class UNetUpsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class GlideUpsample(nn.Module): - """ - An upsampling layer with an optional convolution. - - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class LDMUpsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class GradTTSUpsample(torch.nn.Module): - def __init__(self, dim): - super(GradTTSUpsample, self).__init__() - self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) - - def forward(self, x): - return self.conv(x) - - # TODO (patil-suraj): needs test class Upsample1d(nn.Module): def __init__(self, dim): diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 9d17ea3c9b..26aab77570 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -82,7 +82,7 @@ def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) -#class LinearAttention(nn.Module): +# class LinearAttention(nn.Module): # def __init__(self, dim, heads=4, dim_head=32): # super().__init__() # self.heads = heads @@ -102,7 +102,7 @@ def Normalize(in_channels): # return self.to_out(out) # -#class SpatialSelfAttention(nn.Module): +# class SpatialSelfAttention(nn.Module): # def __init__(self, in_channels): # super().__init__() # self.in_channels = in_channels @@ -120,7 +120,7 @@ def Normalize(in_channels): # k = self.k(h_) # v = self.v(h_) # - # compute attention +# compute attention # b, c, h, w = q.shape # q = rearrange(q, "b c h w -> b (h w) c") # k = rearrange(k, "b c h w -> b c (h w)") @@ -129,7 +129,7 @@ def Normalize(in_channels): # w_ = w_ * (int(c) ** (-0.5)) # w_ = torch.nn.functional.softmax(w_, dim=2) # - # attend to values +# attend to values # v = rearrange(v, "b c h w -> b c (h w)") # w_ = rearrange(w_, "b i j -> b j i") # h_ = torch.einsum("bij,bjk->bik", v, w_) @@ -139,6 +139,7 @@ def Normalize(in_channels): # return x + h_ # + class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__()