1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00

SDNQ: fix LoRAs

This commit is contained in:
Disty0
2025-11-18 01:47:35 +03:00
parent 1745ed53f8
commit 524e92eee2

View File

@@ -46,7 +46,7 @@ def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n
else:
self.network_weights_backup = weight.clone().to(devices.cpu)
if hasattr(self, "sdnq_dequantizer"):
self.sdnq_dequantizer_backup = self.sdnq_dequantizer.to(devices.cpu)
self.sdnq_dequantizer_backup = self.sdnq_dequantizer
self.sdnq_scale_backup = self.scale.clone().to(devices.cpu)
if self.zero_point is not None:
self.sdnq_zero_point_backup = self.zero_point.clone().to(devices.cpu)
@@ -91,7 +91,7 @@ def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.
try:
t0 = time.time()
if hasattr(self, "sdnq_dequantizer_backup"):
weight = self.sdnq_dequantizer_backup.to(devices.device)(
weight = self.sdnq_dequantizer_backup(
self.weight.to(devices.device),
self.sdnq_scale_backup.to(devices.device),
self.sdnq_zero_point_backup.to(devices.device) if self.sdnq_zero_point_backup is not None else None,
@@ -100,7 +100,7 @@ def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.
skip_quantized_matmul=self.sdnq_dequantizer_backup.use_quantized_matmul
)
elif hasattr(self, "sdnq_dequantizer"):
weight = self.sdnq_dequantizer.to(devices.device)(
weight = self.sdnq_dequantizer(
self.weight.to(devices.device),
self.scale.to(devices.device),
self.zero_point.to(devices.device) if self.zero_point is not None else None,
@@ -170,7 +170,7 @@ def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.G
use_svd = bool(self.sdnq_svd_up_backup is not None)
dequantize_fp32 = bool(self.sdnq_scale_backup.dtype == torch.float32)
sdnq_dequantizer = self.sdnq_dequantizer_backup
dequant_weight = self.sdnq_dequantizer_backup.to(devices.device)(
dequant_weight = self.sdnq_dequantizer_backup(
model_weights.to(devices.device),
self.sdnq_scale_backup.to(devices.device),
self.sdnq_zero_point_backup.to(devices.device) if self.sdnq_zero_point_backup is not None else None,
@@ -182,7 +182,7 @@ def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.G
use_svd = bool(self.svd_up is not None)
dequantize_fp32 = bool(self.scale.dtype == torch.float32)
sdnq_dequantizer = self.sdnq_dequantizer
dequant_weight = self.sdnq_dequantizer.to(devices.device)(
dequant_weight = self.sdnq_dequantizer(
model_weights.to(devices.device),
self.scale.to(devices.device),
self.zero_point.to(devices.device) if self.zero_point is not None else None,
@@ -278,7 +278,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
else:
self.weight = torch.nn.Parameter(weights_backup.to(device), requires_grad=False)
if hasattr(self, "sdnq_dequantizer_backup"):
self.sdnq_dequantizer = self.sdnq_dequantizer_backup.to(device)
self.sdnq_dequantizer = self.sdnq_dequantizer_backup
self.scale = torch.nn.Parameter(self.sdnq_scale_backup.to(device), requires_grad=False)
if self.sdnq_zero_point_backup is not None:
self.zero_point = torch.nn.Parameter(self.sdnq_zero_point_backup.to(device), requires_grad=False)