From 22b45304bf85a3c5281753d6b3259ccaf96e5085 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 20 Dec 2023 21:01:33 +0530 Subject: [PATCH] [Refactor upsamplers and downsamplers] separate out upsamplers and downsamplers. (#6128) * separate out upsamplers and downsamplers. * import all the necessary blocks in resnet for backward comp. * move upsample2d and downsample2d to utils. * move downsample_2d to downsamplers.py * apply feedback * fix import * samplers -> sampling --- src/diffusers/models/downsampling.py | 318 ++++++++++++ src/diffusers/models/resnet.py | 714 +-------------------------- src/diffusers/models/upsampling.py | 426 ++++++++++++++++ 3 files changed, 759 insertions(+), 699 deletions(-) create mode 100644 src/diffusers/models/downsampling.py create mode 100644 src/diffusers/models/upsampling.py diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py new file mode 100644 index 0000000000..d39bae22e8 --- /dev/null +++ b/src/diffusers/models/downsampling.py @@ -0,0 +1,318 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import USE_PEFT_BACKEND +from .lora import LoRACompatibleConv +from .upsampling import upfirdn2d_native + + +class Downsample1D(nn.Module): + """A 1D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 1D layer. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + assert inputs.shape[1] == self.channels + return self.conv(inputs) + + +class Downsample2D(nn.Module): + """A 2D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 2D layer. 
+ """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + + if use_conv: + conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + + if not USE_PEFT_BACKEND: + if isinstance(self.conv, LoRACompatibleConv): + hidden_states = self.conv(hidden_states, scale) + else: + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class FirDownsample2D(nn.Module): + """A 2D FIR downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + + def __init__( + self, + channels: Optional[int] = None, + out_channels: Optional[int] = None, + use_conv: bool = False, + fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), + ): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d( + self, + hidden_states: torch.FloatTensor, + weight: Optional[torch.FloatTensor] = None, + kernel: Optional[torch.FloatTensor] = None, + factor: int = 2, + gain: float = 1, + ) -> torch.FloatTensor: + """Fused `Conv2d()` followed by `downsample_2d()`. + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states (`torch.FloatTensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight (`torch.FloatTensor`, *optional*): + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = x.shape[0] // numGroups`. + kernel (`torch.FloatTensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to average pooling. + factor (`int`, *optional*, default to `2`): + Integer downsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude. 
+
+        Returns:
+            output (`torch.FloatTensor`):
+                Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
+                datatype as `hidden_states`.
+        """
+
+        assert isinstance(factor, int) and factor >= 1
+        if kernel is None:
+            kernel = [1] * factor
+
+        # setup kernel
+        kernel = torch.tensor(kernel, dtype=torch.float32)
+        if kernel.ndim == 1:
+            kernel = torch.outer(kernel, kernel)
+        kernel /= torch.sum(kernel)
+
+        kernel = kernel * gain
+
+        if self.use_conv:
+            _, _, convH, convW = weight.shape
+            pad_value = (kernel.shape[0] - factor) + (convW - 1)
+            stride_value = [factor, factor]
+            upfirdn_input = upfirdn2d_native(
+                hidden_states,
+                torch.tensor(kernel, device=hidden_states.device),
+                pad=((pad_value + 1) // 2, pad_value // 2),
+            )
+            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
+        else:
+            pad_value = kernel.shape[0] - factor
+            output = upfirdn2d_native(
+                hidden_states,
+                torch.tensor(kernel, device=hidden_states.device),
+                down=factor,
+                pad=((pad_value + 1) // 2, pad_value // 2),
+            )
+
+        return output
+
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+        if self.use_conv:
+            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+        else:
+            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
+
+        return hidden_states
+
+
+# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
+class KDownsample2D(nn.Module):
+    r"""A 2D K-downsampling layer.
+
+    Parameters:
+        pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
+    """
+
+    def __init__(self, pad_mode: str = "reflect"):
+        super().__init__()
+        self.pad_mode = pad_mode
+        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
+        self.pad = kernel_1d.shape[1] // 2 - 1
+        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
+        weight = inputs.new_zeros(
+            [
+                inputs.shape[1],
+                inputs.shape[1],
+                self.kernel.shape[0],
+                self.kernel.shape[1],
+            ]
+        )
+        indices = torch.arange(inputs.shape[1], device=inputs.device)
+        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
+        weight[indices, indices] = kernel
+        return F.conv2d(inputs, weight, stride=2)
+
+
+def downsample_2d(
+    hidden_states: torch.FloatTensor,
+    kernel: Optional[torch.FloatTensor] = None,
+    factor: int = 2,
+    gain: float = 1,
+) -> torch.FloatTensor:
+    r"""Downsample a batch of 2D images with the given filter.
+    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
+    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
+    shape is a multiple of the downsampling factor.
+
+    Args:
+        hidden_states (`torch.FloatTensor`):
+            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        kernel (`torch.FloatTensor`, *optional*):
+            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+            corresponds to average pooling.
+        factor (`int`, *optional*, default to `2`):
+            Integer downsampling factor.
+ gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude. + + Returns: + output (`torch.FloatTensor`): + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * gain + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel.to(device=hidden_states.device), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + return output diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 970d2be05b..bbfb71ca3f 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -23,562 +23,23 @@ import torch.nn.functional as F from ..utils import USE_PEFT_BACKEND from .activations import get_activation from .attention_processor import SpatialNorm +from .downsampling import ( # noqa + Downsample1D, + Downsample2D, + FirDownsample2D, + KDownsample2D, + downsample_2d, +) from .lora import LoRACompatibleConv, LoRACompatibleLinear from .normalization import AdaGroupNorm - - -class Upsample1D(nn.Module): - """A 1D upsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - use_conv_transpose (`bool`, default `False`): - option to use a convolution transpose. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - name (`str`, default `conv`): - name of the upsampling 1D layer. - """ - - def __init__( - self, - channels: int, - use_conv: bool = False, - use_conv_transpose: bool = False, - out_channels: Optional[int] = None, - name: str = "conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - self.conv = None - if use_conv_transpose: - self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - assert inputs.shape[1] == self.channels - if self.use_conv_transpose: - return self.conv(inputs) - - outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") - - if self.use_conv: - outputs = self.conv(outputs) - - return outputs - - -class Downsample1D(nn.Module): - """A 1D downsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - padding (`int`, default `1`): - padding for the convolution. - name (`str`, default `conv`): - name of the downsampling 1D layer. 
- """ - - def __init__( - self, - channels: int, - use_conv: bool = False, - out_channels: Optional[int] = None, - padding: int = 1, - name: str = "conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - - if use_conv: - self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) - else: - assert self.channels == self.out_channels - self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - assert inputs.shape[1] == self.channels - return self.conv(inputs) - - -class Upsample2D(nn.Module): - """A 2D upsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - use_conv_transpose (`bool`, default `False`): - option to use a convolution transpose. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - name (`str`, default `conv`): - name of the upsampling 2D layer. - """ - - def __init__( - self, - channels: int, - use_conv: bool = False, - use_conv_transpose: bool = False, - out_channels: Optional[int] = None, - name: str = "conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv - - conv = None - if use_conv_transpose: - conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - conv = conv_cls(self.channels, self.out_channels, 3, padding=1) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - def forward( - self, - hidden_states: torch.FloatTensor, - output_size: Optional[int] = None, - scale: float = 1.0, - ) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels - - if self.use_conv_transpose: - return self.conv(hidden_states) - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch - # https://github.com/pytorch/pytorch/issues/86679 - dtype = hidden_states.dtype - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(torch.float32) - - # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - # if `output_size` is passed we force the interpolation output - # size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") - else: - hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(dtype) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if self.use_conv: - if self.name == "conv": - if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND: - hidden_states = self.conv(hidden_states, scale) - else: - hidden_states = self.conv(hidden_states) - else: - if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND: - hidden_states = self.Conv2d_0(hidden_states, scale) - else: - hidden_states = self.Conv2d_0(hidden_states) - - return hidden_states - - -class Downsample2D(nn.Module): - """A 2D downsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - padding (`int`, default `1`): - padding for the convolution. - name (`str`, default `conv`): - name of the downsampling 2D layer. - """ - - def __init__( - self, - channels: int, - use_conv: bool = False, - out_channels: Optional[int] = None, - padding: int = 1, - name: str = "conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv - - if use_conv: - conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding) - else: - assert self.channels == self.out_channels - conv = nn.AvgPool2d(kernel_size=stride, stride=stride) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.Conv2d_0 = conv - self.conv = conv - elif name == "Conv2d_0": - self.conv = conv - else: - self.conv = conv - - def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels - - if self.use_conv and self.padding == 0: - pad = (0, 1, 0, 1) - hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) - - assert hidden_states.shape[1] == self.channels - - if not USE_PEFT_BACKEND: - if isinstance(self.conv, LoRACompatibleConv): - hidden_states = self.conv(hidden_states, scale) - else: - hidden_states = self.conv(hidden_states) - else: - hidden_states = self.conv(hidden_states) - - return hidden_states - - -class FirUpsample2D(nn.Module): - """A 2D FIR upsampling layer with an optional convolution. - - Parameters: - channels (`int`, optional): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - fir_kernel (`tuple`, default `(1, 3, 3, 1)`): - kernel for the FIR filter. 
- """ - - def __init__( - self, - channels: Optional[int] = None, - out_channels: Optional[int] = None, - use_conv: bool = False, - fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), - ): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) - self.use_conv = use_conv - self.fir_kernel = fir_kernel - self.out_channels = out_channels - - def _upsample_2d( - self, - hidden_states: torch.FloatTensor, - weight: Optional[torch.FloatTensor] = None, - kernel: Optional[torch.FloatTensor] = None, - factor: int = 2, - gain: float = 1, - ) -> torch.FloatTensor: - """Fused `upsample_2d()` followed by `Conv2d()`. - - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of - arbitrary order. - - Args: - hidden_states (`torch.FloatTensor`): - Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight (`torch.FloatTensor`, *optional*): - Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be - performed by `inChannels = x.shape[0] // numGroups`. - kernel (`torch.FloatTensor`, *optional*): - FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which - corresponds to nearest-neighbor upsampling. - factor (`int`, *optional*): Integer upsampling factor (default: 2). - gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0). - - Returns: - output (`torch.FloatTensor`): - Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same - datatype as `hidden_states`. - """ - - assert isinstance(factor, int) and factor >= 1 - - # Setup filter kernel. - if kernel is None: - kernel = [1] * factor - - # setup kernel - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - - kernel = kernel * (gain * (factor**2)) - - if self.use_conv: - convH = weight.shape[2] - convW = weight.shape[3] - inC = weight.shape[1] - - pad_value = (kernel.shape[0] - factor) - (convW - 1) - - stride = (factor, factor) - # Determine data dimensions. - output_shape = ( - (hidden_states.shape[2] - 1) * factor + convH, - (hidden_states.shape[3] - 1) * factor + convW, - ) - output_padding = ( - output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, - output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, - ) - assert output_padding[0] >= 0 and output_padding[1] >= 0 - num_groups = hidden_states.shape[1] // inC - - # Transpose weights. 
- weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) - weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) - weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) - - inverse_conv = F.conv_transpose2d( - hidden_states, - weight, - stride=stride, - output_padding=output_padding, - padding=0, - ) - - output = upfirdn2d_native( - inverse_conv, - torch.tensor(kernel, device=inverse_conv.device), - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), - ) - else: - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - torch.tensor(kernel, device=hidden_states.device), - up=factor, - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), - ) - - return output - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - if self.use_conv: - height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) - height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) - - return height - - -class FirDownsample2D(nn.Module): - """A 2D FIR downsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - fir_kernel (`tuple`, default `(1, 3, 3, 1)`): - kernel for the FIR filter. - """ - - def __init__( - self, - channels: Optional[int] = None, - out_channels: Optional[int] = None, - use_conv: bool = False, - fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), - ): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) - self.fir_kernel = fir_kernel - self.use_conv = use_conv - self.out_channels = out_channels - - def _downsample_2d( - self, - hidden_states: torch.FloatTensor, - weight: Optional[torch.FloatTensor] = None, - kernel: Optional[torch.FloatTensor] = None, - factor: int = 2, - gain: float = 1, - ) -> torch.FloatTensor: - """Fused `Conv2d()` followed by `downsample_2d()`. - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of - arbitrary order. - - Args: - hidden_states (`torch.FloatTensor`): - Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight (`torch.FloatTensor`, *optional*): - Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be - performed by `inChannels = x.shape[0] // numGroups`. - kernel (`torch.FloatTensor`, *optional*): - FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which - corresponds to average pooling. - factor (`int`, *optional*, default to `2`): - Integer downsampling factor. - gain (`float`, *optional*, default to `1.0`): - Scaling factor for signal magnitude. - - Returns: - output (`torch.FloatTensor`): - Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same - datatype as `x`. 
- """ - - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - # setup kernel - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - - kernel = kernel * gain - - if self.use_conv: - _, _, convH, convW = weight.shape - pad_value = (kernel.shape[0] - factor) + (convW - 1) - stride_value = [factor, factor] - upfirdn_input = upfirdn2d_native( - hidden_states, - torch.tensor(kernel, device=hidden_states.device), - pad=((pad_value + 1) // 2, pad_value // 2), - ) - output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) - else: - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - torch.tensor(kernel, device=hidden_states.device), - down=factor, - pad=((pad_value + 1) // 2, pad_value // 2), - ) - - return output - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - if self.use_conv: - downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) - hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) - - return hidden_states - - -# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead -class KDownsample2D(nn.Module): - r"""A 2D K-downsampling layer. - - Parameters: - pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. - """ - - def __init__(self, pad_mode: str = "reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) - self.pad = kernel_1d.shape[1] // 2 - 1 - self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode) - weight = inputs.new_zeros( - [ - inputs.shape[1], - inputs.shape[1], - self.kernel.shape[0], - self.kernel.shape[1], - ] - ) - indices = torch.arange(inputs.shape[1], device=inputs.device) - kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) - weight[indices, indices] = kernel - return F.conv2d(inputs, weight, stride=2) - - -class KUpsample2D(nn.Module): - r"""A 2D K-upsampling layer. - - Parameters: - pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. 
- """ - - def __init__(self, pad_mode: str = "reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 - self.pad = kernel_1d.shape[1] // 2 - 1 - self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode) - weight = inputs.new_zeros( - [ - inputs.shape[1], - inputs.shape[1], - self.kernel.shape[0], - self.kernel.shape[1], - ] - ) - indices = torch.arange(inputs.shape[1], device=inputs.device) - kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) - weight[indices, indices] = kernel - return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) +from .upsampling import ( # noqa + FirUpsample2D, + KUpsample2D, + Upsample1D, + Upsample2D, + upfirdn2d_native, + upsample_2d, +) class ResnetBlock2D(nn.Module): @@ -894,151 +355,6 @@ class ResidualTemporalBlock1D(nn.Module): return out + self.residual_conv(inputs) -def upsample_2d( - hidden_states: torch.FloatTensor, - kernel: Optional[torch.FloatTensor] = None, - factor: int = 2, - gain: float = 1, -) -> torch.FloatTensor: - r"""Upsample2D a batch of 2D images with the given filter. - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given - filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified - `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is - a: multiple of the upsampling factor. - - Args: - hidden_states (`torch.FloatTensor`): - Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel (`torch.FloatTensor`, *optional*): - FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which - corresponds to nearest-neighbor upsampling. - factor (`int`, *optional*, default to `2`): - Integer upsampling factor. - gain (`float`, *optional*, default to `1.0`): - Scaling factor for signal magnitude (default: 1.0). - - Returns: - output (`torch.FloatTensor`): - Tensor of the shape `[N, C, H * factor, W * factor]` - """ - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - - kernel = kernel * (gain * (factor**2)) - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - kernel.to(device=hidden_states.device), - up=factor, - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), - ) - return output - - -def downsample_2d( - hidden_states: torch.FloatTensor, - kernel: Optional[torch.FloatTensor] = None, - factor: int = 2, - gain: float = 1, -) -> torch.FloatTensor: - r"""Downsample2D a batch of 2D images with the given filter. - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the - given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the - specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its - shape is a multiple of the downsampling factor. - - Args: - hidden_states (`torch.FloatTensor`) - Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. 
- kernel (`torch.FloatTensor`, *optional*): - FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which - corresponds to average pooling. - factor (`int`, *optional*, default to `2`): - Integer downsampling factor. - gain (`float`, *optional*, default to `1.0`): - Scaling factor for signal magnitude. - - Returns: - output (`torch.FloatTensor`): - Tensor of the shape `[N, C, H // factor, W // factor]` - """ - - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - - kernel = kernel * gain - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - kernel.to(device=hidden_states.device), - down=factor, - pad=((pad_value + 1) // 2, pad_value // 2), - ) - return output - - -def upfirdn2d_native( - tensor: torch.Tensor, - kernel: torch.Tensor, - up: int = 1, - down: int = 1, - pad: Tuple[int, int] = (0, 0), -) -> torch.Tensor: - up_x = up_y = up - down_x = down_y = down - pad_x0 = pad_y0 = pad[0] - pad_x1 = pad_y1 = pad[1] - - _, channel, in_h, in_w = tensor.shape - tensor = tensor.reshape(-1, in_h, in_w, 1) - - _, in_h, in_w, minor = tensor.shape - kernel_h, kernel_w = kernel.shape - - out = tensor.view(-1, in_h, 1, in_w, 1, minor) - out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) - out = out.view(-1, in_h * up_y, in_w * up_x, minor) - - out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) - out = out.to(tensor.device) # Move back to mps if necessary - out = out[ - :, - max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), - max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), - :, - ] - - out = out.permute(0, 3, 1, 2) - out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) - w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) - out = F.conv2d(out, w) - out = out.reshape( - -1, - minor, - in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, - in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, - ) - out = out.permute(0, 2, 3, 1) - out = out[:, ::down_y, ::down_x, :] - - out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 - out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 - - return out.view(-1, channel, out_h, out_w) - - class TemporalConvLayer(nn.Module): """ Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py new file mode 100644 index 0000000000..542a5d9d1e --- /dev/null +++ b/src/diffusers/models/upsampling.py @@ -0,0 +1,426 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import USE_PEFT_BACKEND
+from .lora import LoRACompatibleConv
+
+
+class Upsample1D(nn.Module):
+    """A 1D upsampling layer with an optional convolution.
+
+    Parameters:
+        channels (`int`):
+            number of channels in the inputs and outputs.
+        use_conv (`bool`, default `False`):
+            option to use a convolution.
+        use_conv_transpose (`bool`, default `False`):
+            option to use a convolution transpose.
+        out_channels (`int`, optional):
+            number of output channels. Defaults to `channels`.
+        name (`str`, default `conv`):
+            name of the upsampling 1D layer.
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        use_conv: bool = False,
+        use_conv_transpose: bool = False,
+        out_channels: Optional[int] = None,
+        name: str = "conv",
+    ):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_conv_transpose = use_conv_transpose
+        self.name = name
+
+        self.conv = None
+        if use_conv_transpose:
+            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+        elif use_conv:
+            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        assert inputs.shape[1] == self.channels
+        if self.use_conv_transpose:
+            return self.conv(inputs)
+
+        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
+
+        if self.use_conv:
+            outputs = self.conv(outputs)
+
+        return outputs
+
+
+class Upsample2D(nn.Module):
+    """A 2D upsampling layer with an optional convolution.
+
+    Parameters:
+        channels (`int`):
+            number of channels in the inputs and outputs.
+        use_conv (`bool`, default `False`):
+            option to use a convolution.
+        use_conv_transpose (`bool`, default `False`):
+            option to use a convolution transpose.
+        out_channels (`int`, optional):
+            number of output channels. Defaults to `channels`.
+        name (`str`, default `conv`):
+            name of the upsampling 2D layer.
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        use_conv: bool = False,
+        use_conv_transpose: bool = False,
+        out_channels: Optional[int] = None,
+        name: str = "conv",
+    ):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_conv_transpose = use_conv_transpose
+        self.name = name
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+
+        conv = None
+        if use_conv_transpose:
+            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
+        elif use_conv:
+            conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
+
+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.Conv2d_0 = conv
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        output_size: Optional[int] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        assert hidden_states.shape[1] == self.channels
+
+        if self.use_conv_transpose:
+            return self.conv(hidden_states)
+
+        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
+        # https://github.com/pytorch/pytorch/issues/86679
+        dtype = hidden_states.dtype
+        if dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(torch.float32)
+
+        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+        if hidden_states.shape[0] >= 64:
+            hidden_states = hidden_states.contiguous()
+
+        # if `output_size` is passed we force the interpolation output
+        # size and do not make use of `scale_factor=2`
+        if output_size is None:
+            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        else:
+            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+        # If the input is bfloat16, we cast back to bfloat16
+        if dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(dtype)
+
+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+        if self.use_conv:
+            if self.name == "conv":
+                if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
+                    hidden_states = self.conv(hidden_states, scale)
+                else:
+                    hidden_states = self.conv(hidden_states)
+            else:
+                if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
+                    hidden_states = self.Conv2d_0(hidden_states, scale)
+                else:
+                    hidden_states = self.Conv2d_0(hidden_states)
+
+        return hidden_states
+
+
+class FirUpsample2D(nn.Module):
+    """A 2D FIR upsampling layer with an optional convolution.
+
+    Parameters:
+        channels (`int`, optional):
+            number of channels in the inputs and outputs.
+        use_conv (`bool`, default `False`):
+            option to use a convolution.
+        out_channels (`int`, optional):
+            number of output channels. Defaults to `channels`.
+        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
+            kernel for the FIR filter.
+    """
+
+    def __init__(
+        self,
+        channels: Optional[int] = None,
+        out_channels: Optional[int] = None,
+        use_conv: bool = False,
+        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
+    ):
+        super().__init__()
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.use_conv = use_conv
+        self.fir_kernel = fir_kernel
+        self.out_channels = out_channels
+
+    def _upsample_2d(
+        self,
+        hidden_states: torch.FloatTensor,
+        weight: Optional[torch.FloatTensor] = None,
+        kernel: Optional[torch.FloatTensor] = None,
+        factor: int = 2,
+        gain: float = 1,
+    ) -> torch.FloatTensor:
+        """Fused `upsample_2d()` followed by `Conv2d()`.
+
+        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+        arbitrary order.
+
+        Args:
+            hidden_states (`torch.FloatTensor`):
+                Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+            weight (`torch.FloatTensor`, *optional*):
+                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+                performed by `inChannels = x.shape[0] // numGroups`.
+            kernel (`torch.FloatTensor`, *optional*):
+                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+                corresponds to nearest-neighbor upsampling.
+            factor (`int`, *optional*): Integer upsampling factor (default: 2).
+            gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
+
+        Returns:
+            output (`torch.FloatTensor`):
+                Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
+                datatype as `hidden_states`.
+        """
+
+        assert isinstance(factor, int) and factor >= 1
+
+        # Setup filter kernel.
+ if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + + if self.use_conv: + convH = weight.shape[2] + convW = weight.shape[3] + inC = weight.shape[1] + + pad_value = (kernel.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + # Determine data dimensions. + output_shape = ( + (hidden_states.shape[2] - 1) * factor + convH, + (hidden_states.shape[3] - 1) * factor + convW, + ) + output_padding = ( + output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + num_groups = hidden_states.shape[1] // inC + + # Transpose weights. + weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) + weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) + weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) + + inverse_conv = F.conv_transpose2d( + hidden_states, + weight, + stride=stride, + output_padding=output_padding, + padding=0, + ) + + output = upfirdn2d_native( + inverse_conv, + torch.tensor(kernel, device=inverse_conv.device), + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), + ) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + + return output + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + if self.use_conv: + height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) + height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) + + return height + + +class KUpsample2D(nn.Module): + r"""A 2D K-upsampling layer. + + Parameters: + pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. 
+    """
+
+    def __init__(self, pad_mode: str = "reflect"):
+        super().__init__()
+        self.pad_mode = pad_mode
+        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
+        self.pad = kernel_1d.shape[1] // 2 - 1
+        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
+        weight = inputs.new_zeros(
+            [
+                inputs.shape[1],
+                inputs.shape[1],
+                self.kernel.shape[0],
+                self.kernel.shape[1],
+            ]
+        )
+        indices = torch.arange(inputs.shape[1], device=inputs.device)
+        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
+        weight[indices, indices] = kernel
+        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
+
+
+def upfirdn2d_native(
+    tensor: torch.Tensor,
+    kernel: torch.Tensor,
+    up: int = 1,
+    down: int = 1,
+    pad: Tuple[int, int] = (0, 0),
+) -> torch.Tensor:
+    up_x = up_y = up
+    down_x = down_y = down
+    pad_x0 = pad_y0 = pad[0]
+    pad_x1 = pad_y1 = pad[1]
+
+    _, channel, in_h, in_w = tensor.shape
+    tensor = tensor.reshape(-1, in_h, in_w, 1)
+
+    _, in_h, in_w, minor = tensor.shape
+    kernel_h, kernel_w = kernel.shape
+
+    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
+    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+    out = out.to(tensor.device)  # Move back to mps if necessary
+    out = out[
+        :,
+        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+        :,
+    ]
+
+    out = out.permute(0, 3, 1, 2)
+    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+    out = F.conv2d(out, w)
+    out = out.reshape(
+        -1,
+        minor,
+        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+    )
+    out = out.permute(0, 2, 3, 1)
+    out = out[:, ::down_y, ::down_x, :]
+
+    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+    return out.view(-1, channel, out_h, out_w)
+
+
+def upsample_2d(
+    hidden_states: torch.FloatTensor,
+    kernel: Optional[torch.FloatTensor] = None,
+    factor: int = 2,
+    gain: float = 1,
+) -> torch.FloatTensor:
+    r"""Upsample a batch of 2D images with the given filter.
+    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
+    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
+    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
+    a multiple of the upsampling factor.
+
+    Args:
+        hidden_states (`torch.FloatTensor`):
+            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        kernel (`torch.FloatTensor`, *optional*):
+            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+            corresponds to nearest-neighbor upsampling.
+        factor (`int`, *optional*, default to `2`):
+            Integer upsampling factor.
+        gain (`float`, *optional*, default to `1.0`):
+            Scaling factor for signal magnitude.
+ + Returns: + output (`torch.FloatTensor`): + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel.to(device=hidden_states.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + return output
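
Review note (not part of the patch): below is a minimal smoke test for this refactor, assuming a diffusers checkout with the change applied. It checks that the moved blocks are importable from their new `diffusers.models.downsampling` and `diffusers.models.upsampling` homes, that the backward-compatibility re-imports left in `diffusers.models.resnet` still resolve to the same objects, and that the layers produce the expected shapes. The tensor sizes are arbitrary illustrative values.

    import torch

    # New canonical locations introduced by this patch.
    from diffusers.models.downsampling import Downsample2D, downsample_2d
    from diffusers.models.upsampling import Upsample2D, upsample_2d

    # Backward-compatible re-exports kept in resnet.py (the `# noqa` import lists).
    from diffusers.models.resnet import Downsample2D as LegacyDownsample2D
    from diffusers.models.resnet import Upsample2D as LegacyUpsample2D

    # The re-exports point at the very same class objects, not copies.
    assert Downsample2D is LegacyDownsample2D
    assert Upsample2D is LegacyUpsample2D

    x = torch.randn(1, 32, 64, 64)

    # Stride-2 3x3 conv with padding=1 halves the spatial resolution.
    down = Downsample2D(channels=32, use_conv=True)
    assert down(x).shape == (1, 32, 32, 32)

    # Nearest-neighbor interpolation (scale_factor=2) plus a 3x3 conv doubles it.
    up = Upsample2D(channels=32, use_conv=True)
    assert up(x).shape == (1, 32, 128, 128)

    # Functional FIR helpers: the default kernel `[1] * factor` amounts to
    # average pooling on the way down and nearest-neighbor on the way up.
    assert downsample_2d(x, factor=2).shape == (1, 32, 32, 32)
    assert upsample_2d(x, factor=2).shape == (1, 32, 128, 128)

Because resnet.py re-imports the moved blocks rather than copying them, downstream code that does `from diffusers.models.resnet import Downsample2D` keeps working unchanged, which is exactly what the `# noqa` import lists are for.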