1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-29 05:02:09 +03:00
Files
sdnext/modules/sdnq/layers/conv/forward.py
2025-10-05 22:50:30 +03:00

94 lines
4.7 KiB
Python

# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
from typing import Optional
import torch
def get_conv_args(input_ndim: int, stride, padding, dilation):
    """Normalize convolution hyper-parameters for the matmul-based conv path.

    The number of spatial dims (``conv_type``) is derived from the input rank:
    rank 3 -> 1, rank 4 -> 2, anything else -> 3 (presumably batched inputs
    only — unbatched tensors would be misclassified; confirm with callers).

    Integer ``stride`` / ``padding`` / ``dilation`` values are broadcast to one
    entry per spatial dim; non-integer values are passed through untouched.
    For the 1D case the values are rewritten as 2-element tuples so a Conv1d
    can be handled as a Conv2d over a height-1 image.

    Returns:
        ``(conv_type, stride, padding, dilation)``
    """
    conv_type = {3: 1, 4: 2}.get(input_ndim, 3)

    def _expand(value):
        # Broadcast a scalar hyper-parameter to every spatial dimension.
        return (value,) * conv_type if isinstance(value, int) else value

    stride, padding, dilation = _expand(stride), _expand(padding), _expand(dilation)

    if conv_type == 1:
        # Height dim of the emulated 2D conv: unit stride, no padding, no dilation.
        stride, padding, dilation = (1, stride[0]), (0, padding[0]), (1, dilation[0])

    return conv_type, stride, padding, dilation
def process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation):
    """Unfold a convolution input into im2col rows for a matmul-based convolution.

    Args:
        conv_type: Number of spatial dims (1, 2 or 3), as produced by ``get_conv_args``.
        input: Batched input of shape (N, C, L), (N, C, H, W) or (N, C, D, H, W).
        reversed_padding_repeated_twice: Padding in ``torch.nn.functional.pad`` order
            (last spatial dim first, each amount given for both sides).
        padding_mode: ``"zeros"`` or any mode accepted by ``torch.nn.functional.pad``.
        result_shape: Shape of the (dequantized) weight, unpacked as ``(C_out, _, *kernel)``.
        stride, padding, dilation: Per-dim tuples. For ``conv_type == 1`` these must be
            the 2-element ``(1, s)`` / ``(0, p)`` / ``(1, d)`` forms from ``get_conv_args``;
            index 1 holds the length-dim value.

    Returns:
        ``(columns, mm_output_shape)``: ``columns`` has shape
        ``(N, num_output_positions, C * prod(kernel))``; ``mm_output_shape`` is the
        channels-last ``(N, *spatial_out, C_out)`` shape intended for the matmul result.

    NOTE(review): ``groups`` is not considered here; the column width only matches a
    flattened weight when groups == 1 — confirm callers only route such convs here.
    """
    # Output sizes follow the standard conv formula:
    # out = (in + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1
    if conv_type == 1:
        batch_size, _, L_in = input.shape
        C_out, _, K_l = result_shape
        L_out = (L_in + 2 * padding[1] - dilation[1] * (K_l - 1) - 1) // stride[1] + 1
        mm_output_shape = (batch_size, L_out, C_out)
        kernel_size = (1, K_l)
    elif conv_type == 2:  # must be elif: otherwise conv_type 1 falls into the 3D branch
        batch_size, _, H_in, W_in = input.shape
        C_out, _, K_h, K_w = result_shape
        H_out = (H_in + 2 * padding[0] - dilation[0] * (K_h - 1) - 1) // stride[0] + 1
        W_out = (W_in + 2 * padding[1] - dilation[1] * (K_w - 1) - 1) // stride[1] + 1
        mm_output_shape = (batch_size, H_out, W_out, C_out)
        kernel_size = (K_h, K_w)
    else:
        batch_size, _, D_in, H_in, W_in = input.shape
        C_out, _, K_d, K_h, K_w = result_shape
        D_out = (D_in + 2 * padding[0] - dilation[0] * (K_d - 1) - 1) // stride[0] + 1
        H_out = (H_in + 2 * padding[1] - dilation[1] * (K_h - 1) - 1) // stride[1] + 1
        W_out = (W_in + 2 * padding[2] - dilation[2] * (K_w - 1) - 1) // stride[2] + 1
        mm_output_shape = (batch_size, D_out, H_out, W_out, C_out)
        kernel_size = (K_d, K_h, K_w)

    if padding_mode != "zeros":
        # Non-zero padding modes cannot use unfold's implicit zero padding:
        # pad explicitly and disable the implicit padding afterwards.
        input = torch.nn.functional.pad(input, reversed_padding_repeated_twice, mode=padding_mode)
        padding = (0,) * max(conv_type, 2)
    elif conv_type == 3:
        # The 3D path uses Tensor.unfold, which has no padding argument.
        input = torch.nn.functional.pad(input, reversed_padding_repeated_twice)

    if conv_type == 3:
        # Extent of each dilated kernel window, per spatial dim (each with its own dilation).
        K_D_eff = dilation[0] * (K_d - 1) + 1
        K_H_eff = dilation[1] * (K_h - 1) + 1
        K_W_eff = dilation[2] * (K_w - 1) + 1
        # -> (N, C, D_out, H_out, W_out, K_D_eff, K_H_eff, K_W_eff)
        input = input.unfold(2, K_D_eff, stride[0]).unfold(3, K_H_eff, stride[1]).unfold(4, K_W_eff, stride[2])
        # Keep only the taps the dilated kernel actually touches -> (..., K_d, K_h, K_w).
        input = input[..., ::dilation[0], ::dilation[1], ::dilation[2]]
        # Flatten to (N, D_out * H_out * W_out, C * K_d * K_h * K_w), channel-major like the weight.
        input = input.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(batch_size, D_out * H_out * W_out, -1)
        return input, mm_output_shape

    if conv_type == 1:
        # Treat (N, C, L) as a height-1 image (N, C, 1, L) so the 2D unfold applies.
        input = input.unsqueeze(2)
    input = torch.nn.functional.unfold(input, kernel_size=kernel_size, padding=padding, stride=stride, dilation=dilation).transpose(1, 2)
    return input, mm_output_shape
def quantized_conv_forward(self, input) -> torch.FloatTensor:
    """Forward for a quantized ConvNd module.

    Rebuilds the weight with ``self.sdnq_dequantizer`` from the stored quantized
    weight, scale, zero point and SVD factors, then delegates to the module's own
    ``_conv_forward`` together with ``self.bias``.
    """
    dequantized_weight = self.sdnq_dequantizer(
        self.weight,
        self.scale,
        self.zero_point,
        self.svd_up,
        self.svd_down,
    )
    return self._conv_forward(input, dequantized_weight, self.bias)
def _quantized_conv_transpose_forward(self, input, output_size, num_spatial_dims, conv_transpose_fn):
    """Shared forward for quantized ConvTranspose{1,2,3}d modules.

    Computes ``output_padding`` with the module's ``_output_padding`` (so an explicit
    ``output_size`` is honoured), dequantizes the weight with ``self.sdnq_dequantizer``,
    and runs ``conv_transpose_fn`` with the module's stride/padding/groups/dilation.
    """
    output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
    weight = self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down)
    return conv_transpose_fn(input, weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)


def quantized_conv_transpose_1d_forward(self, input: torch.FloatTensor, output_size: Optional[list[int]] = None) -> torch.FloatTensor:
    """Forward for a quantized ConvTranspose1d; ``output_size`` behaves as in torch."""
    return _quantized_conv_transpose_forward(self, input, output_size, 1, torch.nn.functional.conv_transpose1d)


def quantized_conv_transpose_2d_forward(self, input: torch.FloatTensor, output_size: Optional[list[int]] = None) -> torch.FloatTensor:
    """Forward for a quantized ConvTranspose2d; ``output_size`` behaves as in torch."""
    return _quantized_conv_transpose_forward(self, input, output_size, 2, torch.nn.functional.conv_transpose2d)


def quantized_conv_transpose_3d_forward(self, input: torch.FloatTensor, output_size: Optional[list[int]] = None) -> torch.FloatTensor:
    """Forward for a quantized ConvTranspose3d; ``output_size`` behaves as in torch."""
    return _quantized_conv_transpose_forward(self, input, output_size, 3, torch.nn.functional.conv_transpose3d)