# sdnext/modules/lora/lora_apply.py
# mirror of https://github.com/vladmandic/sdnext.git
from typing import Union
import re
import time
import torch
import diffusers.models.lora
from modules.lora import lora_common as l
from modules import shared, devices, errors, model_quant
bnb = None
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
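

# creates a one-time cpu backup of the layer's original weight/bias (plus sdnq scale/zero-point/svd
# state when present) before any lora deltas are applied; with opts.lora_fuse_native only a boolean
# marker is stored instead of a tensor copy; returns the backup size in bytes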
def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, wanted_names: tuple):
    global bnb # pylint: disable=W0603
    backup_size = 0
    if len(l.loaded_networks) > 0 and network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in l.loaded_networks]): # noqa: C419 # pylint: disable=R1729
        t0 = time.time()
        weights_backup = getattr(self, "network_weights_backup", None)
        bias_backup = getattr(self, "network_bias_backup", None)
        if weights_backup is not None or bias_backup is not None:
            if (shared.opts.lora_fuse_native and not isinstance(weights_backup, bool)) or (not shared.opts.lora_fuse_native and isinstance(weights_backup, bool)): # invalidate so we can change direct/backup on-the-fly
                weights_backup = None
                bias_backup = None
                self.network_weights_backup = weights_backup
                self.network_bias_backup = bias_backup
        if weights_backup is None and wanted_names != (): # pylint: disable=C1803
            weight = getattr(self, 'weight', None)
            self.network_weights_backup = None
            if getattr(weight, "quant_type", None) in ['nf4', 'fp4']:
                if bnb is None:
                    bnb = model_quant.load_bnb('Network load: type=LoRA', silent=True)
                if bnb is not None:
                    if shared.opts.lora_fuse_native:
                        self.network_weights_backup = True
                    else:
                        self.network_weights_backup = bnb.functional.dequantize_4bit(weight, quant_state=weight.quant_state, quant_type=weight.quant_type, blocksize=weight.blocksize,)
                    self.quant_state, self.quant_type, self.blocksize = weight.quant_state, weight.quant_type, weight.blocksize
                else:
                    self.network_weights_backup = weight.clone().to(devices.cpu) if not shared.opts.lora_fuse_native else True
            else:
                if shared.opts.lora_fuse_native:
                    self.network_weights_backup = True
                else:
                    self.network_weights_backup = weight.clone().to(devices.cpu)
                    if hasattr(self, "sdnq_dequantizer"):
                        self.sdnq_dequantizer_backup = self.sdnq_dequantizer
                        self.sdnq_scale_backup = self.scale.clone().to(devices.cpu)
                        if self.zero_point is not None:
                            self.sdnq_zero_point_backup = self.zero_point.clone().to(devices.cpu)
                        else:
                            self.sdnq_zero_point_backup = None
                        if self.svd_up is not None:
                            self.sdnq_svd_up_backup = self.svd_up.clone().to(devices.cpu)
                            self.sdnq_svd_down_backup = self.svd_down.clone().to(devices.cpu)
                        else:
                            self.sdnq_svd_up_backup = None
                            self.sdnq_svd_down_backup = None
        if bias_backup is None:
            if getattr(self, 'bias', None) is not None:
                if shared.opts.lora_fuse_native:
                    self.network_bias_backup = True
                else:
                    bias_backup = self.bias.clone()
                    bias_backup = bias_backup.to(devices.cpu)
                    self.network_bias_backup = bias_backup # store the cpu copy so the restore path can find it
        if getattr(self, 'network_weights_backup', None) is not None:
            backup_size += self.network_weights_backup.numel() * self.network_weights_backup.element_size() if isinstance(self.network_weights_backup, torch.Tensor) else 0
        if getattr(self, 'network_bias_backup', None) is not None:
            backup_size += self.network_bias_backup.numel() * self.network_bias_backup.element_size() if isinstance(self.network_bias_backup, torch.Tensor) else 0
        l.timer.backup += time.time() - t0
    return backup_size
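

# accumulates the combined weight delta (updown) and bias delta (ex_bias) for this layer across
# all loaded (or previously loaded) networks by calling each network module's calc_updown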
def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, use_previous: bool = False):
    if shared.opts.diffusers_offload_mode == "none":
        try:
            self.to(devices.device)
        except Exception:
            pass
    batch_updown = None
    batch_ex_bias = None
    loaded = l.loaded_networks if not use_previous else l.previously_loaded_networks
    for net in loaded:
        module = net.modules.get(network_layer_name, None)
        if module is None:
            continue
        try:
            t0 = time.time()
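            # sdnq-quantized layers are dequantized first so calc_updown runs against the full-precision weight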
if hasattr(self, "sdnq_dequantizer_backup"):
weight = self.sdnq_dequantizer_backup(
self.weight.to(devices.device),
self.sdnq_scale_backup.to(devices.device),
self.sdnq_zero_point_backup.to(devices.device) if self.sdnq_zero_point_backup is not None else None,
self.sdnq_svd_up_backup.to(devices.device) if self.sdnq_svd_up_backup is not None else None,
self.sdnq_svd_down_backup.to(devices.device) if self.sdnq_svd_down_backup is not None else None,
skip_quantized_matmul=self.sdnq_dequantizer_backup.use_quantized_matmul
)
elif hasattr(self, "sdnq_dequantizer"):
weight = self.sdnq_dequantizer(
self.weight.to(devices.device),
self.scale.to(devices.device),
self.zero_point.to(devices.device) if self.zero_point is not None else None,
self.svd_up.to(devices.device) if self.svd_up is not None else None,
self.svd_down.to(devices.device) if self.svd_down is not None else None,
skip_quantized_matmul=self.sdnq_dequantizer.use_quantized_matmul
)
else:
weight = self.weight.to(devices.device) # must perform calc on gpu due to performance
updown, ex_bias = module.calc_updown(weight)
weight = None
del weight
if updown is not None:
if batch_updown is not None:
batch_updown += updown.to(batch_updown.device)
else:
batch_updown = updown.to(devices.device)
if ex_bias is not None:
if batch_ex_bias:
batch_ex_bias += ex_bias.to(batch_ex_bias.device)
else:
batch_ex_bias = ex_bias.to(devices.device)
l.timer.calc += time.time() - t0
if shared.opts.diffusers_offload_mode == "sequential":
t0 = time.time()
if batch_updown is not None:
batch_updown = batch_updown.to(devices.cpu)
if batch_ex_bias is not None:
batch_ex_bias = batch_ex_bias.to(devices.cpu)
t1 = time.time()
l.timer.move += t1 - t0
except RuntimeError as e:
l.extra_network_lora.errors[net.name] = l.extra_network_lora.errors.get(net.name, 0) + 1
module_name = net.modules.get(network_layer_name, None)
shared.log.error(f'Network: type=LoRA name="{net.name}" module="{module_name}" layer="{network_layer_name}" apply weight: {e}')
if l.debug:
errors.display(e, 'LoRA')
raise RuntimeError('LoRA apply weight') from e
continue
return batch_updown, batch_ex_bias
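

# adds (or subtracts, when deactivate=True) the lora delta to the layer weight or bias:
# bnb 4-bit and sdnq quantized layers go through dequantize -> add -> requantize, plain tensors are added directly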
def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], model_weights: Union[None, torch.Tensor] = None, lora_weights: torch.Tensor = None, deactivate: bool = False, device: torch.device = None, bias: bool = False):
    if lora_weights is None:
        return
    if deactivate:
        lora_weights *= -1
    if model_weights is None: # use provided weights (typically from backup) when given, else fall back to self.weight
        model_weights = self.weight
    weight, new_weight = None, None
    # TODO lora: add other quantization types
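    # bitsandbytes 4-bit path: dequantize the packed weight, add the lora delta, repack as Params4bit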
    if self.__class__.__name__ == 'Linear4bit' and bnb is not None:
        try:
            dequant_weight = bnb.functional.dequantize_4bit(model_weights.to(devices.device), quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize)
            new_weight = dequant_weight.to(devices.device) + lora_weights.to(devices.device)
            weight = bnb.nn.Params4bit(new_weight.to(device), quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize, requires_grad=False)
            # TODO lora: maybe force immediate quantization
            # weight._quantize(devices.device) / weight.to(device=device)
        except Exception as e:
            shared.log.error(f'Network load: type=LoRA quant=bnb cls={self.__class__.__name__} type={self.quant_type} blocksize={self.blocksize} state={vars(self.quant_state)} weight={self.weight} bias={lora_weights} {e}')
    elif not bias and hasattr(self, "sdnq_dequantizer"):
        try:
            from modules.sdnq import sdnq_quantize_layer
            if hasattr(self, "sdnq_dequantizer_backup"):
                use_svd = bool(self.sdnq_svd_up_backup is not None)
                dequantize_fp32 = bool(self.sdnq_scale_backup.dtype == torch.float32)
                sdnq_dequantizer = self.sdnq_dequantizer_backup
                dequant_weight = self.sdnq_dequantizer_backup(
                    model_weights.to(devices.device),
                    self.sdnq_scale_backup.to(devices.device),
                    self.sdnq_zero_point_backup.to(devices.device) if self.sdnq_zero_point_backup is not None else None,
                    self.sdnq_svd_up_backup.to(devices.device) if use_svd else None,
                    self.sdnq_svd_down_backup.to(devices.device) if use_svd else None,
                    skip_quantized_matmul=self.sdnq_dequantizer_backup.use_quantized_matmul,
                    dtype=torch.float32,
                )
            else:
                use_svd = bool(self.svd_up is not None)
                dequantize_fp32 = bool(self.scale.dtype == torch.float32)
                sdnq_dequantizer = self.sdnq_dequantizer
                dequant_weight = self.sdnq_dequantizer(
                    model_weights.to(devices.device),
                    self.scale.to(devices.device),
                    self.zero_point.to(devices.device) if self.zero_point is not None else None,
                    self.svd_up.to(devices.device) if use_svd else None,
                    self.svd_down.to(devices.device) if use_svd else None,
                    skip_quantized_matmul=self.sdnq_dequantizer.use_quantized_matmul,
                    dtype=torch.float32,
                )
            new_weight = dequant_weight.to(devices.device, dtype=torch.float32) + lora_weights.to(devices.device, dtype=torch.float32)
            self.weight = torch.nn.Parameter(new_weight, requires_grad=False)
            del self.sdnq_dequantizer, self.scale, self.zero_point, self.svd_up, self.svd_down
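            # requantize the fused weight in-place, reusing the settings recorded on the original sdnq dequantizer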
            self = sdnq_quantize_layer(
                self,
                weights_dtype=sdnq_dequantizer.weights_dtype,
                quantized_matmul_dtype=sdnq_dequantizer.quantized_matmul_dtype,
                torch_dtype=sdnq_dequantizer.result_dtype,
                group_size=sdnq_dequantizer.group_size,
                svd_rank=sdnq_dequantizer.svd_rank,
                use_quantized_matmul=sdnq_dequantizer.use_quantized_matmul,
                use_quantized_matmul_conv=sdnq_dequantizer.use_quantized_matmul,
                use_svd=use_svd,
                dequantize_fp32=dequantize_fp32,
                svd_steps=shared.opts.sdnq_svd_steps,
                quant_conv=True, # quant_conv is True if a conv layer ends up here
                non_blocking=False,
                quantization_device=devices.device,
                return_device=device,
                param_name=getattr(self, 'network_layer_name', None),
            )[0].to(device)
            weight = None
            del dequant_weight
        except Exception as e:
            shared.log.error(f'Network load: type=LoRA quant=sdnq cls={self.__class__.__name__} weight={self.weight} lora_weights={lora_weights} {e}')
    else:
        try:
            new_weight = model_weights.to(devices.device) + lora_weights.to(devices.device)
        except Exception as e:
            shared.log.warning(f'Network load: {e}')
            if 'The size of tensor' in str(e):
                shared.log.error(f'Network load: type=LoRA model={shared.sd_model.__class__.__name__} incompatible lora shape')
                new_weight = model_weights
            else:
                new_weight = model_weights + lora_weights # try without device cast
        weight = torch.nn.Parameter(new_weight.to(device), requires_grad=False)
    if weight is not None:
        if not bias:
            self.weight = weight
        else:
            self.bias = weight
    del model_weights, lora_weights, new_weight, weight # required to avoid memory leak
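

# in-place (fused) apply path: used when the backups are boolean markers (lora_fuse_native),
# modifies the layer weight/bias directly via network_add_weights without tensor backups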
def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, deactivate: bool = False, device: torch.device = devices.device):
    weights_backup = getattr(self, "network_weights_backup", False)
    bias_backup = getattr(self, "network_bias_backup", False)
    if not isinstance(weights_backup, bool): # remove previous backup if we switched settings
        weights_backup = True
    if not isinstance(bias_backup, bool):
        bias_backup = True
    if not weights_backup and not bias_backup:
        return
    t0 = time.time()
    if weights_backup:
        if updown is not None and len(self.weight.shape) == 4 and self.weight.shape[1] == 9: # inpainting model, so zero-pad updown to expand channels from 4 to 9
            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable
        if updown is not None:
            network_add_weights(self, lora_weights=updown, deactivate=deactivate, device=device, bias=False)
    if bias_backup:
        if ex_bias is not None:
            network_add_weights(self, lora_weights=ex_bias, deactivate=deactivate, device=device, bias=True)
    if hasattr(self, "qweight") and hasattr(self, "freeze"):
        self.freeze()
    l.timer.apply += time.time() - t0
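

# backup-based apply path: rebuilds the layer weight/bias from the stored backup plus the lora delta,
# or restores the original backup (including sdnq quantization state) when no delta is given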
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, device: torch.device, deactivate: bool = False):
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)
    if weights_backup is None and bias_backup is None:
        return
    t0 = time.time()
    if weights_backup is not None:
        self.weight = None
        if updown is not None and len(weights_backup.shape) == 4 and weights_backup.shape[1] == 9: # inpainting model: zero-pad updown to expand channel[1] from 4 to 9
            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable
        if updown is not None:
            network_add_weights(self, model_weights=weights_backup, lora_weights=updown, deactivate=deactivate, device=device, bias=False)
        else:
            self.weight = torch.nn.Parameter(weights_backup.to(device), requires_grad=False)
            if hasattr(self, "sdnq_dequantizer_backup"):
                self.sdnq_dequantizer = self.sdnq_dequantizer_backup
                self.scale = torch.nn.Parameter(self.sdnq_scale_backup.to(device), requires_grad=False)
                if self.sdnq_zero_point_backup is not None:
                    self.zero_point = torch.nn.Parameter(self.sdnq_zero_point_backup.to(device), requires_grad=False)
                else:
                    self.zero_point = None
                if self.sdnq_svd_up_backup is not None:
                    self.svd_up = torch.nn.Parameter(self.sdnq_svd_up_backup.to(device), requires_grad=False)
                    self.svd_down = torch.nn.Parameter(self.sdnq_svd_down_backup.to(device), requires_grad=False)
                else:
                    self.svd_up, self.svd_down = None, None
                del self.sdnq_dequantizer_backup, self.sdnq_scale_backup, self.sdnq_zero_point_backup, self.sdnq_svd_up_backup, self.sdnq_svd_down_backup
    if bias_backup is not None:
        self.bias = None
        if ex_bias is not None:
            network_add_weights(self, model_weights=bias_backup, lora_weights=ex_bias, deactivate=deactivate, device=device, bias=True)
        else:
            self.bias = torch.nn.Parameter(bias_backup.to(device), requires_grad=False)
    if hasattr(self, "qweight") and hasattr(self, "freeze"):
        self.freeze()
    l.timer.apply += time.time() - t0
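

# hedged usage sketch (not part of upstream): these functions are normally driven by the network
# loader, which walks every model layer; the helper below only illustrates the expected call order
# for a single layer -- backup, calc, apply -- assuming l.loaded_networks is already populated and
# network_layer_name matches the keys used in each network's modules dict
def _example_apply_single_layer(layer: torch.nn.Linear, network_layer_name: str, wanted_names: tuple):
    backup_size = network_backup_weights(layer, network_layer_name, wanted_names) # snapshot original weights
    updown, ex_bias = network_calc_weights(layer, network_layer_name) # accumulate deltas from all loaded networks
    if shared.opts.lora_fuse_native:
        network_apply_direct(layer, updown, ex_bias, device=devices.device) # fuse in-place, boolean backup markers only
    else:
        network_apply_weights(layer, updown, ex_bias, device=devices.device) # rebuild weight from backup + delta
    return backup_size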