1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/sdnq/layers/__init__.py
vladmandic a4671045b6 lint and crlf
Signed-off-by: vladmandic <mandic00@live.com>
2026-01-24 10:28:46 +01:00

80 lines
3.4 KiB
Python

import torch
class SDNQLayer(torch.nn.Module):
    """Wrap an existing layer so that its forward pass is routed through ``forward_func``.

    The wrapper adopts the original layer's instance state by re-assigning every entry
    of its ``__dict__``. The ``_parameters`` / ``_buffers`` / ``_modules`` dicts are
    therefore shared by reference with the original layer, not copied. The wrapper also
    remembers the original class so that :meth:`dequantize` can turn the object back
    into that class in place.
    """

    def __init__(self, original_layer, forward_func):
        # Initialise through Module directly. For the multiple-inheritance subclasses in
        # this module, super().__init__() would dispatch to the concrete torch layer's
        # __init__ (e.g. Linear), which expects constructor arguments.
        torch.nn.Module.__init__(self)
        # Instance-level bindings of these names on the original layer would otherwise
        # shadow the wrapper's own methods/attributes.
        skipped_keys = {"forward", "forward_func", "original_class", "state_dict", "load_state_dict"}
        for attr_name, attr_value in original_layer.__dict__.items():
            if attr_name in skipped_keys:
                continue
            setattr(self, attr_name, attr_value)
        self.original_class = original_layer.__class__
        self.forward_func = forward_func

    def dequantize(self: torch.nn.Module):
        """Materialise a dense weight and turn this wrapper back into the original layer class.

        Two quantized layouts are recognised:

        * ``weight`` is an ``SDNQTensor`` (matched by class name): its own ``dequantize()``
          produces the dense weight.
        * an ``sdnq_dequantizer`` attribute exists: it is called with the weight and the
          ``scale`` / ``zero_point`` / ``svd_up`` / ``svd_down`` attributes, and those
          attributes are removed afterwards.

        In both cases the result is stored as a new ``torch.nn.Parameter`` with
        ``requires_grad=True``. If neither layout applies, the weight is left untouched.
        The instance's class is then reset to ``original_class``, and the wrapper-only
        attributes ``original_class`` and ``forward_func`` are deleted.

        Returns:
            ``self``, now an instance of the original layer class.
        """
        current_weight = self.weight  # pylint: disable=access-member-before-definition
        if current_weight.__class__.__name__ == "SDNQTensor":
            self.weight = torch.nn.Parameter(current_weight.dequantize(), requires_grad=True)  # pylint: disable=attribute-defined-outside-init
        elif hasattr(self, "sdnq_dequantizer"):
            dequantizer = self.sdnq_dequantizer
            dense_weight = dequantizer(
                current_weight,
                self.scale,
                self.zero_point,
                self.svd_up,
                self.svd_down,
                skip_quantized_matmul=dequantizer.use_quantized_matmul,
            )
            self.weight = torch.nn.Parameter(dense_weight, requires_grad=True)  # pylint: disable=attribute-defined-outside-init
            # The dequantizer and its inputs are no longer needed once the dense weight is stored.
            del self.sdnq_dequantizer, self.scale, self.zero_point, self.svd_up, self.svd_down
        restored_class = self.original_class
        self.__class__ = restored_class
        del self.original_class, self.forward_func
        return self

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Delegate to ``forward_func``, passing this layer as its first argument."""
        compute = self.forward_func
        return compute(self, *args, **kwargs)

    def __repr__(self):
        dequantizer = getattr(self, 'sdnq_dequantizer', None)
        return (
            f"{self.__class__.__name__}(original_class={self.original_class}"
            f" forward_func={self.forward_func}"
            f" sdnq_dequantizer={dequantizer!r})"
        )
# Concrete wrappers. SDNQLayer comes first in each MRO, so its __init__/forward/__repr__/dequantize
# take precedence over the torch layer's methods, while isinstance() checks against the torch layer
# class still succeed. `original_class` holds the class object (not an instance) that dequantize()
# restores, hence the `type[...]` annotations.
class SDNQLinear(SDNQLayer, torch.nn.Linear):
    """SDNQ wrapper for ``torch.nn.Linear`` layers."""
    original_class: type[torch.nn.Linear]


class SDNQConv1d(SDNQLayer, torch.nn.Conv1d):
    """SDNQ wrapper for ``torch.nn.Conv1d`` layers."""
    original_class: type[torch.nn.Conv1d]


class SDNQConv2d(SDNQLayer, torch.nn.Conv2d):
    """SDNQ wrapper for ``torch.nn.Conv2d`` layers."""
    original_class: type[torch.nn.Conv2d]


class SDNQConv3d(SDNQLayer, torch.nn.Conv3d):
    """SDNQ wrapper for ``torch.nn.Conv3d`` layers."""
    original_class: type[torch.nn.Conv3d]


class SDNQConvTranspose1d(SDNQLayer, torch.nn.ConvTranspose1d):
    """SDNQ wrapper for ``torch.nn.ConvTranspose1d`` layers."""
    original_class: type[torch.nn.ConvTranspose1d]


class SDNQConvTranspose2d(SDNQLayer, torch.nn.ConvTranspose2d):
    """SDNQ wrapper for ``torch.nn.ConvTranspose2d`` layers."""
    original_class: type[torch.nn.ConvTranspose2d]


class SDNQConvTranspose3d(SDNQLayer, torch.nn.ConvTranspose3d):
    """SDNQ wrapper for ``torch.nn.ConvTranspose3d`` layers."""
    original_class: type[torch.nn.ConvTranspose3d]
# Allow the wrapper classes to be reconstructed by torch.load(..., weights_only=True).
torch.serialization.add_safe_globals([
    SDNQLayer,
    SDNQLinear,
    SDNQConv1d,
    SDNQConv2d,
    SDNQConv3d,
    SDNQConvTranspose1d,
    SDNQConvTranspose2d,
    SDNQConvTranspose3d,
])
def get_sdnq_wrapper_class(original_layer, forward_func):
    """Wrap ``original_layer`` in the SDNQ wrapper that matches its class name.

    Despite its name, this returns a wrapper *instance*, not a class.

    Selection is by exact class name (``original_layer.__class__.__name__``), not by
    ``isinstance``:

    * a subclass of a torch layer with a different name falls back to the generic
      :class:`SDNQLayer`;
    * any class named e.g. ``"Linear"`` gets the matching specialised wrapper.

    Args:
        original_layer: the module whose state the wrapper adopts.
        forward_func: callable invoked as ``forward_func(wrapper, *args, **kwargs)``.
    """
    wrappers_by_name = {
        "Linear": SDNQLinear,
        "Conv1d": SDNQConv1d,
        "Conv2d": SDNQConv2d,
        "Conv3d": SDNQConv3d,
        "ConvTranspose1d": SDNQConvTranspose1d,
        "ConvTranspose2d": SDNQConvTranspose2d,
        "ConvTranspose3d": SDNQConvTranspose3d,
    }
    wrapper_cls = wrappers_by_name.get(original_layer.__class__.__name__, SDNQLayer)
    return wrapper_cls(original_layer, forward_func)