diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 9a7eaa2ecd..49c1564253 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -150,81 +150,6 @@ class Downsample(nn.Module): return self.op(x) -# class UNetUpsample(nn.Module): -# def __init__(self, in_channels, with_conv): -# super().__init__() -# self.with_conv = with_conv -# if self.with_conv: -# self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) -# -# def forward(self, x): -# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") -# if self.with_conv: -# x = self.conv(x) -# return x -# -# -# class GlideUpsample(nn.Module): -# """ -# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param # -use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If -# 3D, then # upsampling occurs in the inner-two dimensions. #""" -# -# def __init__(self, channels, use_conv, dims=2, out_channels=None): -# super().__init__() -# self.channels = channels -# self.out_channels = out_channels or channels -# self.use_conv = use_conv -# self.dims = dims -# if use_conv: -# self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) -# -# def forward(self, x): -# assert x.shape[1] == self.channels -# if self.dims == 3: -# x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") -# else: -# x = F.interpolate(x, scale_factor=2, mode="nearest") -# if self.use_conv: -# x = self.conv(x) -# return x -# -# -# class LDMUpsample(nn.Module): -# """ -# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # # -use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If -# 3D, then # upsampling occurs in the inner-two dimensions. #""" -# -# def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): -# super().__init__() -# self.channels = channels -# self.out_channels = out_channels or channels -# self.use_conv = use_conv -# self.dims = dims -# if use_conv: -# self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) -# -# def forward(self, x): -# assert x.shape[1] == self.channels -# if self.dims == 3: -# x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") -# else: -# x = F.interpolate(x, scale_factor=2, mode="nearest") -# if self.use_conv: -# x = self.conv(x) -# return x -# -# -# class GradTTSUpsample(torch.nn.Module): -# def __init__(self, dim): -# super(Upsample, self).__init__() -# self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) -# -# def forward(self, x): -# return self.conv(x) -# -# # TODO (patil-suraj): needs test # class Upsample1d(nn.Module): # def __init__(self, dim):