1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-29 05:02:09 +03:00
Files
sdnext/modules/sdnq/layers/linear/forward.py
2025-10-05 22:50:30 +03:00

21 lines
704 B
Python

# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
from typing import Tuple
import torch
from ...common import use_contiguous_mm # noqa: TID252
def check_mats(input: torch.Tensor, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Normalize the memory layouts of a matmul operand pair.

    The activation tensor is always made row-contiguous. When the backend
    prefers fully contiguous matmul operands (``use_contiguous_mm``) the
    weight is made contiguous as well; otherwise a row-contiguous weight is
    rewritten into its column-major (transpose-contiguous) layout.
    """
    input = input.contiguous()
    if use_contiguous_mm:
        return input, weight.contiguous()
    if weight.is_contiguous():
        # t().contiguous().t() keeps the same values but leaves the tensor
        # laid out column-major, which the non-contiguous mm path expects.
        weight = weight.t().contiguous().t()
    return input, weight
def quantized_linear_forward(self, input: torch.FloatTensor) -> torch.FloatTensor:
    """Run a linear layer by dequantizing the stored weight on the fly.

    The quantized weight (plus scale / zero-point and optional SVD low-rank
    factors) is expanded back to a dense float weight via
    ``self.sdnq_dequantizer`` and fed to the standard linear op.
    """
    dense_weight = self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down)
    return torch.nn.functional.linear(input, dense_weight, self.bias)