Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)

cleanup
@@ -125,88 +125,6 @@ class Downsample(nn.Module):
         return self.down(x)
 
 
-class UNetUpsample(nn.Module):
-    def __init__(self, in_channels, with_conv):
-        super().__init__()
-        self.with_conv = with_conv
-        if self.with_conv:
-            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-
-    def forward(self, x):
-        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
-        if self.with_conv:
-            x = self.conv(x)
-        return x
-
-
-class GlideUpsample(nn.Module):
-    """
-    An upsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        upsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
-
-class LDMUpsample(nn.Module):
-    """
-    An upsampling layer with an optional convolution.
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        upsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
-
-class GradTTSUpsample(torch.nn.Module):
-    def __init__(self, dim):
-        super(GradTTSUpsample, self).__init__()
-        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 # TODO (patil-suraj): needs test
 class Upsample1d(nn.Module):
     def __init__(self, dim):
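All three interpolation-based classes removed in this hunk (UNetUpsample, GlideUpsample, LDMUpsample) share one pattern: nearest-neighbor interpolation doubles the spatial resolution, and an optional 3x3, stride-1, padding-1 convolution refines the result without changing the shape. A minimal standalone sketch of that pattern follows; the channel count and input shapes are arbitrary illustration values, not anything from this commit.

import torch
import torch.nn.functional as F

# Optional refinement conv: kernel 3 with stride 1 and padding 1 preserves H and W.
conv = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

x = torch.randn(2, 64, 16, 16)                          # (batch, channels, H, W)
x = F.interpolate(x, scale_factor=2.0, mode="nearest")  # -> (2, 64, 32, 32)
x = conv(x)                                             # -> (2, 64, 32, 32)

# The dims == 3 branch in GlideUpsample/LDMUpsample upsamples only the
# inner two dimensions of a 5D (batch, channels, D, H, W) tensor:
y = torch.randn(2, 64, 4, 16, 16)
y = F.interpolate(y, (y.shape[2], y.shape[3] * 2, y.shape[4] * 2), mode="nearest")
# -> (2, 64, 4, 32, 32): depth unchanged, H and W doubled.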
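GradTTSUpsample, also removed here, reaches the same 2x factor with a learned transposed convolution instead of fixed interpolation: with kernel 4, stride 2, padding 1, the output size is (H - 1) * 2 - 2 * 1 + 4 = 2H. A quick shape check, with sizes again chosen arbitrarily:

import torch

# Learned upsampling: kernel 4, stride 2, padding 1 doubles H and W exactly.
up = torch.nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)

x = torch.randn(2, 64, 16, 16)
print(up(x).shape)  # torch.Size([2, 64, 32, 32])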
@@ -82,7 +82,7 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 
 
-#class LinearAttention(nn.Module):
+# class LinearAttention(nn.Module):
 #     def __init__(self, dim, heads=4, dim_head=32):
 #         super().__init__()
 #         self.heads = heads
@@ -102,7 +102,7 @@ def Normalize(in_channels):
 #         return self.to_out(out)
 #
 
-#class SpatialSelfAttention(nn.Module):
+# class SpatialSelfAttention(nn.Module):
 #     def __init__(self, in_channels):
 #         super().__init__()
 #         self.in_channels = in_channels
@@ -120,7 +120,7 @@ def Normalize(in_channels):
 #         k = self.k(h_)
 #         v = self.v(h_)
 #
-#         # compute attention
+#         # compute attention
 #         b, c, h, w = q.shape
 #         q = rearrange(q, "b c h w -> b (h w) c")
 #         k = rearrange(k, "b c h w -> b c (h w)")
@@ -129,7 +129,7 @@ def Normalize(in_channels):
 #         w_ = w_ * (int(c) ** (-0.5))
 #         w_ = torch.nn.functional.softmax(w_, dim=2)
 #
-#         # attend to values
+#         # attend to values
 #         v = rearrange(v, "b c h w -> b c (h w)")
 #         w_ = rearrange(w_, "b i j -> b j i")
 #         h_ = torch.einsum("bij,bjk->bik", v, w_)
@@ -139,6 +139,7 @@ def Normalize(in_channels):
 #         return x + h_
 #
 
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
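The commented-out SpatialSelfAttention above is single-head attention over spatial positions: q, k, v are flattened from (b, c, h, w) into length h*w sequences, scored with a scaled dot product, softmaxed over the key axis, and used to mix v. Below is a self-contained sketch of just that math, using plain reshape/permute in place of einops.rearrange; the shapes are illustrative, and this is a reading aid, not the CrossAttention class the file actually keeps.

import torch

b, c, h, w = 2, 64, 8, 8
q = torch.randn(b, c, h, w)
k = torch.randn(b, c, h, w)
v = torch.randn(b, c, h, w)

# q: (b, hw, c) and k: (b, c, hw), matching the rearrange calls in the comment.
q_ = q.reshape(b, c, h * w).permute(0, 2, 1)
k_ = k.reshape(b, c, h * w)

w_ = torch.einsum("bij,bjk->bik", q_, k_)    # (b, hw, hw) raw scores
w_ = w_ * (int(c) ** (-0.5))                 # scale by 1/sqrt(c)
w_ = torch.nn.functional.softmax(w_, dim=2)  # normalize over keys

# Attend to values: w_[b, i, j] weights value position j for query position i.
v_ = v.reshape(b, c, h * w)                  # (b, c, hw)
w_ = w_.permute(0, 2, 1)                     # (b, j, i), same as "b i j -> b j i"
h_ = torch.einsum("bij,bjk->bik", v_, w_)    # (b, c, hw)
h_ = h_.reshape(b, c, h, w)                  # back to the spatial layout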