# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access

from typing import Optional

import torch


def get_conv_args(input_ndim: int, stride, padding, dilation):
    # map input rank to conv dimensionality: (N, C, L) -> 1d, (N, C, H, W) -> 2d, (N, C, D, H, W) -> 3d
    if input_ndim == 3:
        conv_type = 1
    elif input_ndim == 4:
        conv_type = 2
    else:
        conv_type = 3
    # normalize scalar stride/padding/dilation into per-dimension tuples
    if isinstance(stride, int):
        stride = (stride,) * conv_type
    if isinstance(padding, int):
        padding = (padding,) * conv_type
    if isinstance(dilation, int):
        dilation = (dilation,) * conv_type
    # conv1d is handled as conv2d with a unit-sized leading spatial dimension
    if conv_type == 1:
        stride = (1, stride[0])
        padding = (0, padding[0])
        dilation = (1, dilation[0])
    return conv_type, stride, padding, dilation
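

# Illustrative examples (not part of the original module): scalar arguments are
# broadcast per spatial dimension, and conv1d arguments are lifted to their
# conv2d equivalents:
#   get_conv_args(3, 2, 1, 1) -> (1, (1, 2), (0, 1), (1, 1))
#   get_conv_args(4, 2, 1, 1) -> (2, (2, 2), (1, 1), (1, 1))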


def process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation):
    # compute the matmul ("im2col") output shape for each conv dimensionality
    if conv_type == 1:
        batch_size, _, L_in = input.shape
        C_out, _, K_l = result_shape
        L_out = (L_in + 2 * padding[1] - dilation[1] * (K_l - 1) - 1) // stride[1] + 1
        mm_output_shape = (batch_size, L_out, C_out)
        kernel_size = (1, K_l)
    elif conv_type == 2:
        batch_size, _, H_in, W_in = input.shape
        C_out, _, K_h, K_w = result_shape
        H_out = (H_in + 2 * padding[0] - dilation[0] * (K_h - 1) - 1) // stride[0] + 1
        W_out = (W_in + 2 * padding[1] - dilation[1] * (K_w - 1) - 1) // stride[1] + 1
        mm_output_shape = (batch_size, H_out, W_out, C_out)
        kernel_size = (K_h, K_w)
    else:
        batch_size, _, D_in, H_in, W_in = input.shape
        C_out, _, K_d, K_h, K_w = result_shape
        D_out = (D_in + 2 * padding[0] - dilation[0] * (K_d - 1) - 1) // stride[0] + 1
        H_out = (H_in + 2 * padding[1] - dilation[1] * (K_h - 1) - 1) // stride[1] + 1
        W_out = (W_in + 2 * padding[2] - dilation[2] * (K_w - 1) - 1) // stride[2] + 1
        mm_output_shape = (batch_size, D_out, H_out, W_out, C_out)
        kernel_size = (K_d, K_h, K_w)

    if padding_mode != "zeros":
        input = torch.nn.functional.pad(input, reversed_padding_repeated_twice, mode=padding_mode)
        padding = (0,) * (conv_type if conv_type != 1 else 2)
    elif conv_type == 3:
        # the manual 3d unfold below has no padding argument, so zero-pad explicitly here
        input = torch.nn.functional.pad(input, reversed_padding_repeated_twice)

    if conv_type == 1:
        input = input.unsqueeze(2)  # (N, C, L) -> (N, C, 1, L) so the 2d unfold path applies

    if conv_type == 3:
        # torch.nn.functional.unfold only supports 4d inputs, so build the 3d im2col
        # manually: unfold each spatial dim over the dilated kernel extent, then drop
        # the in-between elements of every dilated window
        K_D_eff = K_d + (K_d - 1) * (dilation[0] - 1)
        K_H_eff = K_h + (K_h - 1) * (dilation[1] - 1)
        K_W_eff = K_w + (K_w - 1) * (dilation[2] - 1)
        input = input.unfold(2, K_D_eff, stride[0]).unfold(3, K_H_eff, stride[1]).unfold(4, K_W_eff, stride[2])
        if dilation[0] > 1:
            input = input[..., ::dilation[0], :, :]
        if dilation[1] > 1:
            input = input[..., ::dilation[1], :]
        if dilation[2] > 1:
            input = input[..., ::dilation[2]]
        input = input.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(batch_size, D_out * H_out * W_out, -1)
    else:
        input = torch.nn.functional.unfold(input, kernel_size=kernel_size, padding=padding, stride=stride, dilation=dilation).transpose(1, 2)
    return input, mm_output_shape
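

# Usage sketch (illustrative; the consuming matmul lives in this helper's callers):
# for conv_type == 2 the returned tensor has shape (N, H_out * W_out, C_in * K_h * K_w),
# so with the dequantized weight reshaped to a hypothetical weight_2d of shape
# (C_out, C_in * K_h * K_w):
#   out = torch.matmul(cols, weight_2d.t()).reshape(mm_output_shape)
# yields a channels-last result to permute back to (N, C_out, H_out, W_out).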


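# The forwards below assume they are bound to conv modules prepared by SDNQ: each
# module carries a quantized `weight` plus `scale`, `zero_point`, `svd_up`, `svd_down`
# tensors and an `sdnq_dequantizer` callable (an interface inferred from the attribute
# accesses here, since this file does not define them).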
def quantized_conv_forward(self, input) -> torch.FloatTensor:
    # dequantize the stored weight on the fly, then run the module's regular conv path
    return self._conv_forward(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias)


def quantized_conv_transpose_1d_forward(self, input: torch.FloatTensor, output_size: Optional[list[int]] = None) -> torch.FloatTensor:
    output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 1, self.dilation)
    return torch.nn.functional.conv_transpose1d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)


def quantized_conv_transpose_2d_forward(self, input: torch.FloatTensor, output_size: Optional[list[int]] = None) -> torch.FloatTensor:
    output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 2, self.dilation)
    return torch.nn.functional.conv_transpose2d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)


def quantized_conv_transpose_3d_forward(self, input: torch.FloatTensor, output_size: Optional[list[int]] = None) -> torch.FloatTensor:
    output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 3, self.dilation)
    return torch.nn.functional.conv_transpose3d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
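

# Minimal smoke test (illustrative sketch, not a definitive harness): exercises the
# im2col path with a plain float conv2d weight and checks it against the reference
# torch.nn.functional.conv2d result.
if __name__ == "__main__":
    x = torch.randn(2, 3, 8, 8)
    w = torch.randn(4, 3, 3, 3)
    conv_type, stride, padding, dilation = get_conv_args(x.ndim, 1, 1, 1)
    cols, mm_shape = process_conv_input(conv_type, x, (1, 1, 1, 1), "zeros", w.shape, stride, padding, dilation)
    out = torch.matmul(cols, w.reshape(w.shape[0], -1).t()).reshape(mm_shape)  # channels-last output
    ref = torch.nn.functional.conv2d(x, w, stride=1, padding=1)
    print(torch.allclose(out.permute(0, 3, 1, 2), ref, atol=1e-4))  # expected: True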