mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
SDNQ fix Loras
This commit is contained in:
@@ -46,7 +46,7 @@ def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n
|
||||
else:
|
||||
self.network_weights_backup = weight.clone().to(devices.cpu)
|
||||
if hasattr(self, "sdnq_dequantizer"):
|
||||
self.sdnq_dequantizer_backup = self.sdnq_dequantizer.to(devices.cpu)
|
||||
self.sdnq_dequantizer_backup = self.sdnq_dequantizer
|
||||
self.sdnq_scale_backup = self.scale.clone().to(devices.cpu)
|
||||
if self.zero_point is not None:
|
||||
self.sdnq_zero_point_backup = self.zero_point.clone().to(devices.cpu)
|
||||
@@ -91,7 +91,7 @@ def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.
|
||||
try:
|
||||
t0 = time.time()
|
||||
if hasattr(self, "sdnq_dequantizer_backup"):
|
||||
weight = self.sdnq_dequantizer_backup.to(devices.device)(
|
||||
weight = self.sdnq_dequantizer_backup(
|
||||
self.weight.to(devices.device),
|
||||
self.sdnq_scale_backup.to(devices.device),
|
||||
self.sdnq_zero_point_backup.to(devices.device) if self.sdnq_zero_point_backup is not None else None,
|
||||
@@ -100,7 +100,7 @@ def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.
|
||||
skip_quantized_matmul=self.sdnq_dequantizer_backup.use_quantized_matmul
|
||||
)
|
||||
elif hasattr(self, "sdnq_dequantizer"):
|
||||
weight = self.sdnq_dequantizer.to(devices.device)(
|
||||
weight = self.sdnq_dequantizer(
|
||||
self.weight.to(devices.device),
|
||||
self.scale.to(devices.device),
|
||||
self.zero_point.to(devices.device) if self.zero_point is not None else None,
|
||||
@@ -170,7 +170,7 @@ def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.G
|
||||
use_svd = bool(self.sdnq_svd_up_backup is not None)
|
||||
dequantize_fp32 = bool(self.sdnq_scale_backup.dtype == torch.float32)
|
||||
sdnq_dequantizer = self.sdnq_dequantizer_backup
|
||||
dequant_weight = self.sdnq_dequantizer_backup.to(devices.device)(
|
||||
dequant_weight = self.sdnq_dequantizer_backup(
|
||||
model_weights.to(devices.device),
|
||||
self.sdnq_scale_backup.to(devices.device),
|
||||
self.sdnq_zero_point_backup.to(devices.device) if self.sdnq_zero_point_backup is not None else None,
|
||||
@@ -182,7 +182,7 @@ def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.G
|
||||
use_svd = bool(self.svd_up is not None)
|
||||
dequantize_fp32 = bool(self.scale.dtype == torch.float32)
|
||||
sdnq_dequantizer = self.sdnq_dequantizer
|
||||
dequant_weight = self.sdnq_dequantizer.to(devices.device)(
|
||||
dequant_weight = self.sdnq_dequantizer(
|
||||
model_weights.to(devices.device),
|
||||
self.scale.to(devices.device),
|
||||
self.zero_point.to(devices.device) if self.zero_point is not None else None,
|
||||
@@ -278,7 +278,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(weights_backup.to(device), requires_grad=False)
|
||||
if hasattr(self, "sdnq_dequantizer_backup"):
|
||||
self.sdnq_dequantizer = self.sdnq_dequantizer_backup.to(device)
|
||||
self.sdnq_dequantizer = self.sdnq_dequantizer_backup
|
||||
self.scale = torch.nn.Parameter(self.sdnq_scale_backup.to(device), requires_grad=False)
|
||||
if self.sdnq_zero_point_backup is not None:
|
||||
self.zero_point = torch.nn.Parameter(self.sdnq_zero_point_backup.to(device), requires_grad=False)
|
||||
|
||||
Reference in New Issue
Block a user