From 321f9791d6a491ed140fd2cd26f56f45bbaa9f4a Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Sun, 3 Jul 2022 22:26:33 +0200
Subject: [PATCH] Downsample / Upsample - clean to 1D and 2D (#68)

* make unet rl work

* upload files / code

* upload files

* make style correct

* finish
---
 src/diffusers/models/resnet.py                | 283 ++++++++++++++----
 src/diffusers/models/unet.py                  |   6 +-
 src/diffusers/models/unet_glide.py            |   8 +-
 src/diffusers/models/unet_grad_tts.py         |   6 +-
 src/diffusers/models/unet_ldm.py              |   6 +-
 src/diffusers/models/unet_rl.py               |   6 +-
 .../models/unet_sde_score_estimation.py       | 141 +--------
 src/diffusers/models/vae.py                   |   6 +-
 src/diffusers/pipelines/bddm/pipeline_bddm.py |   2 +-
 tests/test_layers_utils.py                    |  22 +-
 10 files changed, 254 insertions(+), 232 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 5aac8d9510..f851754859 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -6,46 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-def avg_pool_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D average pooling module.
-    """
-    if dims == 1:
-        return nn.AvgPool1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.AvgPool2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.AvgPool3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def conv_transpose_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.ConvTranspose1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.ConvTranspose2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.ConvTranspose3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class Upsample(nn.Module):
+class Upsample2D(nn.Module):
     """
     An upsampling layer with an optional convolution.
 
     :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
     applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
         upsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
-        self.dims = dims
         self.use_conv_transpose = use_conv_transpose
         self.name = name
 
         conv = None
         if use_conv_transpose:
-            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
 
+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
             self.conv = conv
         else:
@@ -79,11 +40,9 @@ class Upsample(nn.Module):
         if self.use_conv_transpose:
             return self.conv(x)
 
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
 
+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
             if self.name == "conv":
                 x = self.conv(x)
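For reviewers, a minimal usage sketch of the renamed 2D layer (shapes chosen for illustration; this snippet is not part of the patch):

import torch
from diffusers.models.resnet import Upsample2D

# Nearest-neighbor interpolation path: [N, C, H, W] -> [N, C, 2H, 2W].
x = torch.randn(1, 32, 16, 16)
assert Upsample2D(channels=32, use_conv=False)(x).shape == (1, 32, 32, 32)

# Transposed-convolution path (kernel 4, stride 2, padding 1) doubles H and W as well.
assert Upsample2D(channels=32, use_conv=False, use_conv_transpose=True)(x).shape == (1, 32, 32, 32)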
""" - def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv - self.dims = dims self.use_conv_transpose = use_conv_transpose self.name = name conv = None if use_conv_transpose: - conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1) + conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) elif use_conv: - conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": self.conv = conv else: @@ -79,11 +40,9 @@ class Upsample(nn.Module): if self.use_conv_transpose: return self.conv(x) - if self.dims == 3: - x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") - else: - x = F.interpolate(x, scale_factor=2.0, mode="nearest") + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": x = self.conv(x) @@ -93,7 +52,7 @@ class Upsample(nn.Module): return x -class Downsample(nn.Module): +class Downsample2D(nn.Module): """ A downsampling layer with an optional convolution. @@ -102,22 +61,22 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv - self.dims = dims self.padding = padding - stride = 2 if dims != 3 else (1, 2, 2) + stride = 2 self.name = name if use_conv: - conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) + conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels - conv = avg_pool_nd(dims, kernel_size=stride, stride=stride) + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": self.conv = conv elif name == "Conv2d_0": @@ -127,10 +86,11 @@ class Downsample(nn.Module): def forward(self, x): assert x.shape[1] == self.channels - if self.use_conv and self.padding == 0 and self.dims == 2: + if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) x = F.pad(x, pad, mode="constant", value=0) + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.name == "conv": return self.conv(x) elif self.name == "Conv2d_0": @@ -139,8 +99,204 @@ class Downsample(nn.Module): return self.op(x) +class Upsample1D(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + if self.use_conv: + x = self.conv(x) + + return x + + +class Downsample1D(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.conv(x) + + +class FirUpsample2D(nn.Module): + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def forward(self, x): + if self.use_conv: + h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) + h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + h = upsample_2d(x, self.fir_kernel, factor=2) + + return h + + +class FirDownsample2D(nn.Module): + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def forward(self, x): + if self.use_conv: + x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) + x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + x = downsample_2d(x, self.fir_kernel, factor=2) + + return x + + +def _conv_downsample_2d(x, w, k=None, factor=2, gain=1): + """Fused `Conv2d()` followed by `downsample_2d()`. + + Args: + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary + order. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. 
+
+
+class FirUpsample2D(nn.Module):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+        super().__init__()
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.use_conv = use_conv
+        self.fir_kernel = fir_kernel
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        if self.use_conv:
+            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+        else:
+            h = upsample_2d(x, self.fir_kernel, factor=2)
+
+        return h
+
+
+class FirDownsample2D(nn.Module):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+        super().__init__()
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.fir_kernel = fir_kernel
+        self.use_conv = use_conv
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        if self.use_conv:
+            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+        else:
+            x = downsample_2d(x, self.fir_kernel, factor=2)
+
+        return x
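FirUpsample2D / FirDownsample2D wrap the FIR-filtered resampling used by the score-based models. With use_conv=False they reduce to the plain upsample_2d / downsample_2d ops with the default (1, 3, 3, 1) kernel; a hedged sketch of the expected shapes:

import torch
from diffusers.models.resnet import FirDownsample2D, FirUpsample2D

x = torch.randn(1, 32, 16, 16)

# use_conv=False: pure FIR resampling, channel count unchanged.
assert FirUpsample2D(channels=32, out_channels=32)(x).shape == (1, 32, 32, 32)
assert FirDownsample2D(channels=32, out_channels=32)(x).shape == (1, 32, 8, 8)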
+
+
+def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `Conv2d()` followed by `downsample_2d()`.
+
+    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+    efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary
+    order.
+
+    Args:
+        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        w: Weight tensor of the shape `[outChannels, inChannels, convH, convW]`. Grouped convolution can be
+            performed by `inChannels = x.shape[1] // numGroups`.
+        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+            corresponds to average pooling.
+        factor: Integer downsampling factor (default: 2).
+        gain: Scaling factor for signal magnitude (default: 1.0).
+
+    Returns:
+        Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
+        datatype as `x`.
+    """
+
+    assert isinstance(factor, int) and factor >= 1
+    _outC, _inC, convH, convW = w.shape
+    assert convW == convH
+    if k is None:
+        k = [1] * factor
+    k = _setup_kernel(k) * gain
+    p = (k.shape[0] - factor) + (convW - 1)
+    s = [factor, factor]
+    x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
+    return F.conv2d(x, w, stride=s, padding=0)
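The padding arithmetic is easy to misread, so here is the default case worked out by hand (a sketch; the 16x16 input size and 3x3 weight are assumed, not taken from the patch):

# Fused downsample under the defaults: fir_kernel (1, 3, 3, 1), factor 2, 3x3 conv.
factor, fir_n, conv_w = 2, 4, 3

p = (fir_n - factor) + (conv_w - 1)  # 4
pads = ((p + 1) // 2, p // 2)        # (2, 2)

# upfirdn2d with up=down=1: out = in + pad0 + pad1 - fir_n + 1
h_after_fir = 16 + pads[0] + pads[1] - fir_n + 1  # 17
h_out = (h_after_fir - conv_w) // factor + 1      # 8, i.e. 16 // factor
print(pads, h_after_fir, h_out)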
+
+
+def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `upsample_2d()` followed by `Conv2d()`.
+
+    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+    efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary
+    order.
+
+    Args:
+        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        w: Weight tensor of the shape `[outChannels, inChannels, convH, convW]`. Grouped convolution can be
+            performed by `inChannels = x.shape[1] // numGroups`.
+        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+            corresponds to nearest-neighbor upsampling.
+        factor: Integer upsampling factor (default: 2).
+        gain: Scaling factor for signal magnitude (default: 1.0).
+
+    Returns:
+        Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype
+        as `x`.
+    """
+
+    assert isinstance(factor, int) and factor >= 1
+
+    # Check weight shape.
+    assert len(w.shape) == 4
+    convH = w.shape[2]
+    convW = w.shape[3]
+    inC = w.shape[1]
+
+    assert convW == convH
+
+    # Setup filter kernel.
+    if k is None:
+        k = [1] * factor
+    k = _setup_kernel(k) * (gain * (factor**2))
+    p = (k.shape[0] - factor) - (convW - 1)
+
+    stride = (factor, factor)
+
+    # Determine data dimensions.
+    output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
+    output_padding = (
+        output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+        output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
+    )
+    assert output_padding[0] >= 0 and output_padding[1] >= 0
+    num_groups = x.shape[1] // inC
+
+    # Transpose weights: flip the kernel spatially, then swap the group and channel axes.
+    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
+    w = torch.flip(w, dims=[3, 4]).permute(0, 2, 1, 3, 4)
+    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
+
+    x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
+
+    return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
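The same kind of bookkeeping for the fused upsample, worked for an illustrative 16x16 input with the default 4-tap kernel, factor=2, and a 3x3 weight (values assumed for illustration):

factor, conv_w, h_in = 2, 3, 16

output_shape = (h_in - 1) * factor + conv_w     # 33
p = (4 - factor) - (conv_w - 1)                 # 0 for the default 4-tap kernel
pads = ((p + 1) // 2 + factor - 1, p // 2 + 1)  # (1, 1)

# After conv_transpose2d and the final upfirdn2d pass:
h_out = output_shape + pads[0] + pads[1] - 4 + 1  # 32 == h_in * factor
print(output_shape, pads, h_out)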
+
+
 # TODO (patil-suraj): needs test
 # class Upsample1d(nn.Module):
 #     def __init__(self, dim):
 #         super().__init__()
 #         self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
@@ -221,7 +377,7 @@ class ResnetBlock2D(nn.Module):
         elif kernel == "sde_vp":
             self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
         else:
-            self.upsample = Upsample(in_channels, use_conv=False, dims=2)
+            self.upsample = Upsample2D(in_channels, use_conv=False)
         elif self.down:
             if kernel == "fir":
                 fir_kernel = (1, 3, 3, 1)
@@ -229,7 +385,7 @@ class ResnetBlock2D(nn.Module):
         elif kernel == "sde_vp":
             self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
         else:
-            self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+            self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
 
         self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
 
@@ -257,7 +413,6 @@ class ResnetBlock2D(nn.Module):
             else:
                 self.res_conv = torch.nn.Identity()
         elif self.overwrite_for_ldm:
-            dims = 2
             channels = in_channels
             emb_channels = temb_channels
             use_scale_shift_norm = False
@@ -266,7 +421,7 @@ class ResnetBlock2D(nn.Module):
             self.in_layers = nn.Sequential(
                 normalization(channels, swish=1.0),
                 nn.Identity(),
-                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+                nn.Conv2d(channels, self.out_channels, 3, padding=1),
             )
             self.emb_layers = nn.Sequential(
                 nn.SiLU(),
@@ -279,12 +434,12 @@ class ResnetBlock2D(nn.Module):
                 normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
                 nn.SiLU() if use_scale_shift_norm else nn.Identity(),
                 nn.Dropout(p=dropout),
-                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+                zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
             )
             if self.out_channels == in_channels:
                 self.skip_connection = nn.Identity()
             else:
-                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+                self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
         elif self.overwrite_for_score_vde:
             in_ch = in_channels
             out_ch = out_channels
@@ -631,7 +786,7 @@ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
 
 
 def upsample_2d(x, k=None, factor=2, gain=1):
     r"""Upsample a batch of 2D images with the given filter.
 
     Args:
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
@@ -656,7 +811,7 @@
 
 
 def downsample_2d(x, k=None, factor=2, gain=1):
     r"""Downsample a batch of 2D images with the given filter.
 
     Args:
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
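Both module-level helpers stay importable from diffusers.models.resnet; with k=None the filter defaults to [1] * factor, i.e. nearest-neighbor upsampling on the way up and average pooling on the way down. A hedged sketch:

import torch
from diffusers.models.resnet import downsample_2d, upsample_2d

x = torch.randn(1, 3, 16, 16)

# Default filter: nearest-neighbor up, average pooling down.
assert upsample_2d(x, factor=2).shape == (1, 3, 32, 32)
assert downsample_2d(x, factor=2).shape == (1, 3, 8, 8)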
diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py
index fcd4ad540c..aebf9b610b 100644
--- a/src/diffusers/models/unet.py
+++ b/src/diffusers/models/unet.py
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
 
 
 def nonlinearity(x):
@@ -100,7 +100,7 @@ class UNetModel(ModelMixin, ConfigMixin):
             down.block = block
             down.attn = attn
             if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
+                down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
             self.down.append(down)
 
@@ -139,7 +139,7 @@ class UNetModel(ModelMixin, ConfigMixin):
             up.block = block
             up.attn = attn
             if i_level != 0:
-                up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
+                up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
                 curr_res = curr_res * 2
             self.up.insert(0, up)  # prepend to get consistent order
diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py
index ded61e5fc2..960d8416ec 100644
--- a/src/diffusers/models/unet_glide.py
+++ b/src/diffusers/models/unet_glide.py
@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
 
 
 def convert_module_to_f16(l):
@@ -218,9 +218,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                             down=True,
                         )
                         if resblock_updown
-                        else Downsample(
-                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
-                        )
+                        else Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
                     )
                 )
                 ch = out_ch
@@ -299,7 +297,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                             up=True,
                         )
                         if resblock_updown
-                        else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
+                        else Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch)
                     )
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
- """ - - assert isinstance(factor, int) and factor >= 1 - _outC, _inC, convH, convW = w.shape - assert convW == convH - if k is None: - k = [1] * factor - k = _setup_kernel(k) * gain - p = (k.shape[0] - factor) + (convW - 1) - s = [factor, factor] - x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2)) - return F.conv2d(x, w, stride=s, padding=0) - - def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"): """Ported from JAX.""" scale = 1e-10 if scale == 0 else scale @@ -183,46 +92,6 @@ class Combine(nn.Module): raise ValueError(f"Method {self.method} not recognized.") -class FirUpsample(nn.Module): - def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) - self.use_conv = use_conv - self.fir_kernel = fir_kernel - self.out_channels = out_channels - - def forward(self, x): - if self.use_conv: - h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) - h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - h = upsample_2d(x, self.fir_kernel, factor=2) - - return h - - -class FirDownsample(nn.Module): - def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) - self.fir_kernel = fir_kernel - self.use_conv = use_conv - self.out_channels = out_channels - - def forward(self, x): - if self.use_conv: - x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) - x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - x = downsample_2d(x, self.fir_kernel, factor=2) - - return x - - class NCSNpp(ModelMixin, ConfigMixin): """NCSN++ model""" @@ -313,9 +182,9 @@ class NCSNpp(ModelMixin, ConfigMixin): AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) if self.fir: - Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv) + Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv) else: - Up_sample = functools.partial(Upsample, name="Conv2d_0") + Up_sample = functools.partial(Upsample2D, name="Conv2d_0") if progressive == "output_skip": self.pyramid_upsample = Up_sample(channels=None, use_conv=False) @@ -323,9 +192,9 @@ class NCSNpp(ModelMixin, ConfigMixin): pyramid_upsample = functools.partial(Up_sample, use_conv=True) if self.fir: - Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv) + Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv) else: - Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0") + Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0") if progressive_input == "input_skip": self.pyramid_downsample = Down_sample(channels=None, use_conv=False) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index ff8addf15f..a2ec239d43 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -5,7 +5,7 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .attention import AttentionBlock -from .resnet import 
diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py
index d4a78101a7..facc2f9d6f 100644
--- a/src/diffusers/models/unet_sde_score_estimation.py
+++ b/src/diffusers/models/unet_sde_score_estimation.py
@@ -21,13 +21,12 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
 
 
 def _setup_kernel(k):
@@ -40,96 +39,6 @@ def _setup_kernel(k):
     return k
 
 
-def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `upsample_2d()` followed by `Conv2d()`.
-
-    Args:
-    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
-    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
-    order.
-    x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
-    C]`.
-    w: Weight tensor of the shape `[filterH, filterW, inChannels,
-    outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-    k: FIR filter of the shape `[firH, firW]` or `[firN]`
-    (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
-    factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-    Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
-    `x`.
-    """
-
-    assert isinstance(factor, int) and factor >= 1
-
-    # Check weight shape.
-    assert len(w.shape) == 4
-    convH = w.shape[2]
-    convW = w.shape[3]
-    inC = w.shape[1]
-
-    assert convW == convH
-
-    # Setup filter kernel.
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * (gain * (factor**2))
-    p = (k.shape[0] - factor) - (convW - 1)
-
-    stride = (factor, factor)
-
-    # Determine data dimensions.
-    stride = [1, 1, factor, factor]
-    output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
-    output_padding = (
-        output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
-        output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
-    )
-    assert output_padding[0] >= 0 and output_padding[1] >= 0
-    num_groups = x.shape[1] // inC
-
-    # Transpose weights.
-    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
-    w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
-    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
-
-    x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
-
-    return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
-
-
-def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `Conv2d()` followed by `downsample_2d()`.
-
-    Args:
-    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
-    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
-    order.
-    x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
-    C]`.
-    w: Weight tensor of the shape `[filterH, filterW, inChannels,
-    outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-    k: FIR filter of the shape `[firH, firW]` or `[firN]`
-    (separable). The default is `[1] * factor`, which corresponds to average pooling.
-    factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-    Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
-    as `x`.
-    """
-
-    assert isinstance(factor, int) and factor >= 1
-    _outC, _inC, convH, convW = w.shape
-    assert convW == convH
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * gain
-    p = (k.shape[0] - factor) + (convW - 1)
-    s = [factor, factor]
-    x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
-    return F.conv2d(x, w, stride=s, padding=0)
-
-
 def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
     scale = 1e-10 if scale == 0 else scale
@@ -183,46 +92,6 @@ class Combine(nn.Module):
         raise ValueError(f"Method {self.method} not recognized.")
 
 
-class FirUpsample(nn.Module):
-    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
-        super().__init__()
-        out_channels = out_channels if out_channels else channels
-        if use_conv:
-            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
-        self.use_conv = use_conv
-        self.fir_kernel = fir_kernel
-        self.out_channels = out_channels
-
-    def forward(self, x):
-        if self.use_conv:
-            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
-            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
-        else:
-            h = upsample_2d(x, self.fir_kernel, factor=2)
-
-        return h
-
-
-class FirDownsample(nn.Module):
-    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
-        super().__init__()
-        out_channels = out_channels if out_channels else channels
-        if use_conv:
-            self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
-        self.fir_kernel = fir_kernel
-        self.use_conv = use_conv
-        self.out_channels = out_channels
-
-    def forward(self, x):
-        if self.use_conv:
-            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
-            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
-        else:
-            x = downsample_2d(x, self.fir_kernel, factor=2)
-
-        return x
-
-
 class NCSNpp(ModelMixin, ConfigMixin):
     """NCSN++ model"""
 
@@ -313,9 +182,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
 
         if self.fir:
-            Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+            Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
         else:
-            Up_sample = functools.partial(Upsample, name="Conv2d_0")
+            Up_sample = functools.partial(Upsample2D, name="Conv2d_0")
 
         if progressive == "output_skip":
             self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
@@ -323,9 +192,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
             pyramid_upsample = functools.partial(Up_sample, use_conv=True)
 
         if self.fir:
-            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+            Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
         else:
-            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
+            Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")
 
         if progressive_input == "input_skip":
             self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index ff8addf15f..a2ec239d43 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -5,7 +5,7 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
 
 
 def nonlinearity(x):
@@ -65,7 +65,7 @@ class Encoder(nn.Module):
             down.block = block
             down.attn = attn
             if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
+                down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
             self.down.append(down)
 
@@ -179,7 +179,7 @@ class Decoder(nn.Module):
             up.block = block
             up.attn = attn
             if i_level != 0:
-                up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
+                up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
                 curr_res = curr_res * 2
             self.up.insert(0, up)  # prepend to get consistent order
diff --git a/src/diffusers/pipelines/bddm/pipeline_bddm.py b/src/diffusers/pipelines/bddm/pipeline_bddm.py
index 45b26c3127..a27b9e122f 100644
--- a/src/diffusers/pipelines/bddm/pipeline_bddm.py
+++ b/src/diffusers/pipelines/bddm/pipeline_bddm.py
@@ -137,7 +137,7 @@ class ResidualBlock(nn.Module):
         # Dilated conv layer
         h = self.dilated_conv_layer(h)
 
         # Upsample spectrogram to size of audio
         mel_spec = torch.unsqueeze(mel_spec, dim=1)
         mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
         mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)
diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py
index 86a7b88310..57ba310263 100755
--- a/tests/test_layers_utils.py
+++ b/tests/test_layers_utils.py
@@ -22,7 +22,7 @@ import numpy as np
 import torch
 
 from diffusers.models.embeddings import get_timestep_embedding
-from diffusers.models.resnet import Downsample, Upsample
+from diffusers.models.resnet import Downsample2D, Upsample2D
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 
 
@@ -116,11 +116,11 @@ class EmbeddingsTests(unittest.TestCase):
         )
 
 
-class UpsampleBlockTests(unittest.TestCase):
+class Upsample2DBlockTests(unittest.TestCase):
     def test_upsample_default(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=False)
+        upsample = Upsample2D(channels=32, use_conv=False)
         with torch.no_grad():
             upsampled = upsample(sample)
 
@@ -132,7 +132,7 @@ class UpsampleBlockTests(unittest.TestCase):
     def test_upsample_with_conv(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=True)
+        upsample = Upsample2D(channels=32, use_conv=True)
         with torch.no_grad():
             upsampled = upsample(sample)
 
@@ -144,7 +144,7 @@ class UpsampleBlockTests(unittest.TestCase):
     def test_upsample_with_conv_out_dim(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=True, out_channels=64)
+        upsample = Upsample2D(channels=32, use_conv=True, out_channels=64)
         with torch.no_grad():
             upsampled = upsample(sample)
 
@@ -156,7 +156,7 @@ class UpsampleBlockTests(unittest.TestCase):
     def test_upsample_with_transpose(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
+        upsample = Upsample2D(channels=32, use_conv=False, use_conv_transpose=True)
         with torch.no_grad():
             upsampled = upsample(sample)
 
@@ -166,11 +166,11 @@ class UpsampleBlockTests(unittest.TestCase):
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
 
 
-class DownsampleBlockTests(unittest.TestCase):
+class Downsample2DBlockTests(unittest.TestCase):
     def test_downsample_default(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=False)
+        downsample = Downsample2D(channels=32, use_conv=False)
         with torch.no_grad():
             downsampled = downsample(sample)
 
@@ -184,7 +184,7 @@ class DownsampleBlockTests(unittest.TestCase):
     def test_downsample_with_conv(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=True)
+        downsample = Downsample2D(channels=32, use_conv=True)
         with torch.no_grad():
             downsampled = downsample(sample)
 
@@ -199,7 +199,7 @@ class DownsampleBlockTests(unittest.TestCase):
     def test_downsample_with_conv_pad1(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=True, padding=1)
+        downsample = Downsample2D(channels=32, use_conv=True, padding=1)
         with torch.no_grad():
             downsampled = downsample(sample)
 
@@ -211,7 +211,7 @@ class DownsampleBlockTests(unittest.TestCase):
     def test_downsample_with_conv_out_dim(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=True, out_channels=16)
+        downsample = Downsample2D(channels=32, use_conv=True, out_channels=16)
         with torch.no_grad():
             downsampled = downsample(sample)
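The "# TODO (patil-suraj): needs test" left in resnet.py suggests the new 1D blocks are still untested. A sketch of what such tests could look like, following the conventions of the suite above (expected output slices omitted because they would have to be generated; Upsample1D and Downsample1D would also need to be added to the imports):

class Upsample1DBlockTests(unittest.TestCase):
    def test_upsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32)
        upsample = Upsample1D(channels=32, use_conv=False)
        with torch.no_grad():
            upsampled = upsample(sample)

        # Nearest-neighbor interpolation doubles the trailing length dimension.
        assert upsampled.shape == (1, 32, 64)


class Downsample1DBlockTests(unittest.TestCase):
    def test_downsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64)
        downsample = Downsample1D(channels=32, use_conv=False)
        with torch.no_grad():
            downsampled = downsample(sample)

        # Average pooling halves the trailing length dimension.
        assert downsampled.shape == (1, 32, 32)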