mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-29 05:02:09 +03:00
21 lines
704 B
Python
21 lines
704 B
Python
# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
|
|
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
|
|
from ...common import use_contiguous_mm # noqa: TID252
|
|
|
|
|
|
def check_mats(input: torch.Tensor, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
input = input.contiguous()
|
|
if use_contiguous_mm:
|
|
weight = weight.contiguous()
|
|
elif weight.is_contiguous():
|
|
weight = weight.t().contiguous().t()
|
|
return input, weight
|
|
|
|
|
|
def quantized_linear_forward(self, input: torch.FloatTensor) -> torch.FloatTensor:
|
|
return torch.nn.functional.linear(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias)
|