mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
patil-suraj committed 2022-06-27 16:27:24 +02:00
commit ee010726ab, parent abcb25978a
2 changed files with 5 additions and 86 deletions

View File

@@ -125,88 +125,6 @@ class Downsample(nn.Module):
         return self.down(x)
 
 
-class UNetUpsample(nn.Module):
-    def __init__(self, in_channels, with_conv):
-        super().__init__()
-        self.with_conv = with_conv
-        if self.with_conv:
-            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-
-    def forward(self, x):
-        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
-        if self.with_conv:
-            x = self.conv(x)
-        return x
-
-
-class GlideUpsample(nn.Module):
-    """
-    An upsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        upsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
-
-class LDMUpsample(nn.Module):
-    """
-    An upsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        upsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
-
-class GradTTSUpsample(torch.nn.Module):
-    def __init__(self, dim):
-        super(GradTTSUpsample, self).__init__()
-        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 # TODO (patil-suraj): needs test
 class Upsample1d(nn.Module):
     def __init__(self, dim):
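The three nearest-neighbor variants removed above (UNetUpsample, GlideUpsample, LDMUpsample) all reduce to interpolate-then-optional-3x3-conv, which is presumably why they could be collapsed into a single implementation; GradTTSUpsample instead uses a strided transposed convolution. A quick standalone check of the shared pattern (shapes and channel counts here are illustrative, not taken from the diff):

import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 16, 16)  # (batch, channels, height, width)

# 2D case: nearest-neighbor interpolation doubles the spatial dims ...
y = F.interpolate(x, scale_factor=2.0, mode="nearest")
assert y.shape == (1, 64, 32, 32)

# ... and the optional 3x3 conv (stride 1, padding 1) preserves them.
conv = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
assert conv(y).shape == (1, 64, 32, 32)

# 3D case from the removed docstrings: only the inner two dims are upsampled.
vol = torch.randn(1, 8, 4, 16, 16)  # (batch, channels, depth, height, width)
vol_up = F.interpolate(vol, (vol.shape[2], vol.shape[3] * 2, vol.shape[4] * 2), mode="nearest")
assert vol_up.shape == (1, 8, 4, 32, 32)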

View File

@@ -82,7 +82,7 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 
 
-#class LinearAttention(nn.Module):
+# class LinearAttention(nn.Module):
 #    def __init__(self, dim, heads=4, dim_head=32):
 #        super().__init__()
 #        self.heads = heads
@@ -102,7 +102,7 @@ def Normalize(in_channels):
 #        return self.to_out(out)
 #
 
-#class SpatialSelfAttention(nn.Module):
+# class SpatialSelfAttention(nn.Module):
 #    def __init__(self, in_channels):
 #        super().__init__()
 #        self.in_channels = in_channels
@@ -120,7 +120,7 @@ def Normalize(in_channels):
 #        k = self.k(h_)
 #        v = self.v(h_)
 #
-#        compute attention
+#         compute attention
 #        b, c, h, w = q.shape
 #        q = rearrange(q, "b c h w -> b (h w) c")
 #        k = rearrange(k, "b c h w -> b c (h w)")
@@ -129,7 +129,7 @@ def Normalize(in_channels):
 #        w_ = w_ * (int(c) ** (-0.5))
 #        w_ = torch.nn.functional.softmax(w_, dim=2)
 #
-#        attend to values
+#         attend to values
 #        v = rearrange(v, "b c h w -> b c (h w)")
 #        w_ = rearrange(w_, "b i j -> b j i")
 #        h_ = torch.einsum("bij,bjk->bik", v, w_)
@@ -139,6 +139,7 @@ def Normalize(in_channels):
 #        return x + h_
 #
 
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
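For orientation, the CrossAttention constructor kept as context above (query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0) is the usual multi-head cross-attention parameterization. A minimal sketch of how a module with that signature is typically wired; the to_q/to_k/to_v/to_out layer names and the forward internals follow the common convention and are assumptions, not code from this commit:

import torch
from torch import nn

class CrossAttentionSketch(nn.Module):
    # Hypothetical re-implementation matching the signature above; not the diffusers code.
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = context_dim if context_dim is not None else query_dim
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None):
        b, n, _ = x.shape
        context = context if context is not None else x  # falls back to self-attention
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)

        # split heads: (b, seq, heads * dim_head) -> (b, heads, seq, dim_head)
        def split(t):
            return t.reshape(b, t.shape[1], self.heads, -1).transpose(1, 2)

        q, k, v = split(q), split(k), split(v)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)  # merge heads back
        return self.to_out(out)

For example, CrossAttentionSketch(320, context_dim=768)(torch.randn(2, 64, 320), torch.randn(2, 77, 768)) returns a (2, 64, 320) tensor: queries keep their sequence length and query_dim, while keys and values come from the context.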