# sdnext/modules/model_quant_nncf.py
from typing import Any, Dict, List, Optional, Union
from dataclasses import dataclass
from enum import Enum
import os
import torch
from diffusers.quantizers.base import DiffusersQuantizer
from diffusers.quantizers.quantization_config import QuantizationConfigMixin
from diffusers.utils import get_module_from_name
from accelerate import init_empty_weights
from accelerate.utils import CustomDtype
from modules import devices, shared
debug = os.environ.get('SD_QUANT_DEBUG', None) is not None
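# Map quantized weight dtype names to torch/accelerate dtypes. accelerate's CustomDtype.INT4 is a
# placeholder used for device-map memory estimation, since torch has no native 4-bit dtype.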
torch_dtype_dict = {
"int8": torch.int8,
"uint8": torch.uint8,
"int4": CustomDtype.INT4,
"uint4": CustomDtype.INT4,
}
weights_dtype_dict = {
"int8_asym": "uint8",
"int8_sym": "int8",
"int4_asym": "uint4",
"int4_sym": "int4",
"int8": "uint8",
"int4": "uint4",
}
linear_types = ["NNCFLinear", "Linear"]
conv_types = ["NNCFConv1d", "NNCFConv2d", "NNCFConv3d", "Conv1d", "Conv2d", "Conv3d"]
conv_transpose_types = ["NNCFConvTranspose1d", "NNCFConvTranspose2d", "NNCFConvTranspose3d", "ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d"]
allowed_types = []
allowed_types.extend(linear_types)
allowed_types.extend(conv_types)
allowed_types.extend(conv_transpose_types)
class QuantizationMethod(str, Enum):
NNCF = "nncf"
# de-abstracted and modified from the actual quant functions of nncf 2.16.0:
def nncf_compress_layer(layer, num_bits, is_asym_mode, torch_dtype=None, quant_conv=False, param_name=None):
if layer.__class__.__name__ in allowed_types:
if torch_dtype is None:
torch_dtype = devices.dtype
result_shape = None
if layer.__class__.__name__ in conv_types:
if is_asym_mode or not quant_conv: # skip convs in asym mode or when conv quantization is disabled
return layer
reduction_axes = [i for i in range(layer.weight.ndim) if i != 0]
elif layer.__class__.__name__ in conv_transpose_types:
if is_asym_mode or not quant_conv: # skip convs in asym mode or when conv quantization is disabled
return layer
reduction_axes = [i for i in range(layer.weight.ndim) if i != 1]
else:
reduction_axes = -1
if shared.opts.nncf_compress_weights_num_of_groups > 1:
num_of_groups = shared.opts.nncf_compress_weights_num_of_groups
channel_size = layer.weight.shape[-1]
group_size = channel_size / num_of_groups
while channel_size % group_size != 0: # shrink num_of_groups until group_size evenly divides channel_size
num_of_groups -= 1
group_size = channel_size / num_of_groups
if num_of_groups > 1:
result_shape = layer.weight.shape
new_shape = list(result_shape)
last_dim_index = layer.weight.ndim
new_shape[last_dim_index - 1 : last_dim_index] = (int(num_of_groups), int(group_size))
layer.weight.data = layer.weight.reshape(new_shape)
if shared.opts.diffusers_offload_mode != "none":
return_device = layer.weight.data.device
else:
return_device = devices.device
layer.weight.data = layer.weight.data.to(devices.device, dtype=torch.float32)
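# Asymmetric mode: per-output-channel min/max define the scale and an integer zero point so the full
# unsigned range is used: w_q = round(w / scale) + zero_point, w ~= (w_q - zero_point) * scale.
# Symmetric mode: a single per-channel scale (magnitude max(|w_min|, w_max) / 2**(num_bits-1)) and no
# zero point: w_q = round(w / scale), w ~= w_q * scale.
# Worked example (illustrative, num_bits=8 asymmetric): for a channel with min=-0.5 and max=1.5,
# scale = 2.0 / 255 ~= 0.00784 and zero_point = 0 - round(-0.5 / 0.00784) = 64, so a weight of 0.0
# quantizes to 64 and dequantizes back to ~0.0.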
if is_asym_mode:
level_low = 0
level_high = 2**num_bits - 1
min_values = torch.amin(layer.weight, dim=reduction_axes, keepdims=True) # [a1, r, a2] -> [a1, 1, a2]
max_values = torch.amax(layer.weight, dim=reduction_axes, keepdims=True) # [a1, r, a2] -> [a1, 1, a2]
levels = level_high - level_low + 1
scale = ((max_values - min_values) / (levels - 1)).to(dtype=torch.float32)
eps = torch.finfo(scale.dtype).eps
scale = torch.where(torch.abs(scale) < eps, eps, scale)
zero_point = level_low - torch.round(min_values / scale)
zero_point = torch.clip(zero_point.to(dtype=torch.int32), level_low, level_high).to(dtype=torch.float32)
else:
factor = 2 ** (num_bits - 1)
w_abs_min = torch.abs(torch.amin(layer.weight, dim=reduction_axes, keepdims=True))
w_max = torch.amax(layer.weight, dim=reduction_axes, keepdims=True)
scale = torch.where(w_abs_min >= w_max, w_abs_min, -w_max)
scale /= factor
eps = torch.finfo(scale.dtype).eps
scale = torch.where(torch.abs(scale) < eps, eps, scale)
zero_point = None
dtype = torch.uint8 if is_asym_mode else torch.int8
level_low = 0 if is_asym_mode else -(2 ** (num_bits - 1))
level_high = 2**num_bits - 1 if is_asym_mode else 2 ** (num_bits - 1) - 1
compressed_weight = layer.weight.data / scale
if not shared.opts.nncf_decompress_fp32:
scale = scale.to(torch_dtype)
if zero_point is not None:
compressed_weight += zero_point
zero_point = zero_point.to(scale.dtype)
compressed_weight = torch.round(compressed_weight)
compressed_weight = torch.clip(compressed_weight, level_low, level_high).to(dtype)
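# Pick the matching decompressor: it packs the quantized weight for storage (two 4-bit values per byte
# for the int4 modes) and is registered as a pre-forward operation so the weight is dequantized just
# before each forward pass.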
if num_bits == 4:
if is_asym_mode:
decompressor = INT4AsymmetricWeightsDecompressor(
scale=scale.data,
zero_point=zero_point.data,
compressed_weight_shape=compressed_weight.shape,
result_dtype=torch_dtype,
result_shape=result_shape,
)
else:
decompressor = INT4SymmetricWeightsDecompressor(
scale=scale.data,
compressed_weight_shape=compressed_weight.shape,
result_dtype=torch_dtype,
result_shape=result_shape,
)
else:
if is_asym_mode:
decompressor = INT8AsymmetricWeightsDecompressor(
scale=scale.data,
zero_point=zero_point.data,
result_dtype=torch_dtype,
result_shape=result_shape,
)
else:
decompressor = INT8SymmetricWeightsDecompressor(
scale=scale.data,
result_dtype=torch_dtype,
result_shape=result_shape,
)
compressed_weight = decompressor.pack_weight(compressed_weight)
compressed_weight = compressed_weight.to(return_device)
decompressor = decompressor.to(return_device)
layer.register_pre_forward_operation(decompressor)
layer.weight.requires_grad = False
layer.weight.data = compressed_weight
return layer
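# Recursively walk an already NNCF-wrapped model and compress every NNCF module that has a weight.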
def apply_nncf_to_module(model, num_bits, is_asym_mode, quant_conv=False):
has_children = list(model.children())
if not has_children:
return model
for param_name, module in model.named_children():
if module.__class__.__name__.startswith("NNCF") and hasattr(module, "weight") and module.weight is not None:
module = nncf_compress_layer(module, num_bits, is_asym_mode, torch_dtype=devices.dtype, quant_conv=quant_conv, param_name=param_name)
module = apply_nncf_to_module(module, num_bits, is_asym_mode, quant_conv=quant_conv)
return model
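# The modified decompressors keep scale/zero_point as plain tensor attributes rather than registered
# buffers, so model.to(device) does not move them; walk the module tree and move them explicitly.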
def nncf_send_to_device(model, device):
for child in model.children():
if "WeightsDecompressor" in child.__class__.__name__:
child.scale = child.scale.to(device)
if hasattr(child, "zero_point"):
child.zero_point = child.zero_point.to(device)
nncf_send_to_device(child, device)
class NNCFQuantizer(DiffusersQuantizer):
r"""
Diffusers quantizer for NNCF: replaces eligible modules with NNCF wrappers before weight loading and compresses their weights to int4/int8 as parameters are created on the target device.
"""
requires_parameters_quantization = True
use_keep_in_fp32_modules = True
requires_calibration = False
required_packages = ["nncf"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
def check_if_quantized_param(
self,
model,
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
):
module, tensor_name = get_module_from_name(model, param_name)
return module.__class__.__name__.startswith("NNCF") and param_name.endswith(".weight")
def check_quantized_param(self, *args, **kwargs) -> bool:
"""
needed for transformers compatibility, returns self.check_if_quantized_param
"""
return self.check_if_quantized_param(*args, **kwargs)
def create_quantized_param(
self,
model,
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: List[str],
**kwargs,
):
# load the model params to target_device first
layer, tensor_name = get_module_from_name(model, param_name)
layer._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
split_param_name = param_name.split(".")
if param_name not in self.modules_to_not_convert and not any(param in split_param_name for param in self.modules_to_not_convert):
layer = nncf_compress_layer(
layer,
self.quantization_config.num_bits,
self.quantization_config.is_asym_mode,
torch_dtype=self.torch_dtype,
param_name=param_name
)
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
max_memory = {key: val * 0.70 for key, val in max_memory.items()}
return max_memory
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
return torch_dtype_dict[self.quantization_config.weights_dtype]
def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
if torch_dtype is None:
torch_dtype = devices.dtype
self.torch_dtype = torch_dtype
return torch_dtype
def _process_model_before_weight_loading(
self,
model,
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]
if keep_in_fp32_modules is not None:
self.modules_to_not_convert.extend(keep_in_fp32_modules)
model.config.quantization_config = self.quantization_config
with init_empty_weights():
model, _ = replace_modules_by_nncf_modules(model)
def _process_model_after_weight_loading(self, model, **kwargs):
nncf_send_to_device(model, devices.device)
return model
def update_tp_plan(self, config):
"""
needed for transformers compatibility, no-op function
"""
return config
def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]:
"""
needed for transformers compatibility, no-op function
"""
return unexpected_keys
def update_missing_keys_after_loading(self, model, missing_keys: List[str], prefix: str) -> List[str]:
"""
needed for transformers compatibility, no-op function
"""
return missing_keys
def update_expected_keys(self, model, expected_keys: List[str], loaded_keys: List[str]) -> List[str]:
"""
needed for transformers compatibility, no-op function
"""
return expected_keys
@property
def is_trainable(self):
return False
@property
def is_serializable(self):
return False
@dataclass
class NNCFConfig(QuantizationConfigMixin):
"""
This is a wrapper class holding all of the attributes and features applicable to a model that has
been loaded with `nncf` weight compression.
Args:
weights_dtype (`str`, *optional*, defaults to `"int8_sym"`):
The target dtype for the weights after quantization. Supported values are ("int8", "int8_asym", "int8_sym", "int4", "int4_asym", "int4_sym").
modules_to_not_convert (`list`, *optional*, defaults to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require
some modules to be left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
"""
def __init__(
self,
weights_dtype: str = "int8_sym",
modules_to_not_convert: Optional[List[str]] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.NNCF
self.weights_dtype = weights_dtype_dict[weights_dtype.lower()]
self.modules_to_not_convert = modules_to_not_convert
self.post_init()
self.num_bits = 8 if self.weights_dtype in {"int8", "uint8"} else 4
self.is_asym_mode = self.weights_dtype in {"uint8", "uint4"}
self.is_integer = True
self.group_size = -1
def post_init(self):
r"""
Safety checker that arguments are correct
"""
accepted_weights = ["int8", "uint8", "int4", "uint4"]
if self.weights_dtype not in accepted_weights:
raise ValueError(f"Only weights in {accepted_weights} are supported, but found {self.weights_dtype}")
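# Example (illustrative; the module name is hypothetical): 4-bit symmetric weight compression while
# leaving "proj_out" modules in their original precision.
#   config = NNCFConfig(weights_dtype="int4_sym", modules_to_not_convert=["proj_out"])
#   config.num_bits      # -> 4
#   config.is_asym_mode  # -> False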
class NNCF_T5DenseGatedActDense(torch.nn.Module): # a full class is needed so forward() gets a proper `self` when replacing T5DenseGatedActDense
def __init__(self, T5DenseGatedActDense, dtype):
super().__init__()
self.wi_0 = T5DenseGatedActDense.wi_0
self.wi_1 = T5DenseGatedActDense.wi_1
self.wo = T5DenseGatedActDense.wo
self.dropout = T5DenseGatedActDense.dropout
self.act = T5DenseGatedActDense.act
self.torch_dtype = dtype
def forward(self, hidden_states):
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states.to(self.torch_dtype) # this cast must be forced; the stock T5 forward skips it when wo.weight is int8
hidden_states = self.wo(hidden_states)
return hidden_states
# WeightsDecompressor classes and functions are modified from NNCF 2.16.0
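# 4-bit values are stored two per uint8 byte, low nibble first. Signed int4 values are offset by +8
# into [0, 15] before packing and shifted back by -8 when unpacked.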
def unpack_uint4(packed_tensor: torch.Tensor, shape: torch.Size) -> torch.Tensor:
return torch.stack((torch.bitwise_and(packed_tensor, 15), torch.bitwise_right_shift(packed_tensor, 4)), dim=-1).reshape(shape)
def unpack_int4(packed_tensor: torch.Tensor, shape: torch.Size, dtype: Optional[torch.dtype] = torch.int8) -> torch.Tensor:
return unpack_uint4(packed_tensor, shape).to(dtype=dtype) - 8
def pack_uint4(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype != torch.uint8:
raise RuntimeError(f"Invalid tensor dtype {tensor.dtype}. Only torch.uint8 is supported.")
packed_tensor = tensor.contiguous().reshape(-1, 2)
packed_tensor = torch.bitwise_and(packed_tensor[..., ::2], 15) | packed_tensor[..., 1::2] << 4
return packed_tensor
def pack_int4(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype != torch.int8:
raise RuntimeError(f"Invalid tensor dtype {tensor.dtype}. Only torch.int8 is supported.")
tensor = tensor + 8
return pack_uint4(tensor.to(dtype=torch.uint8))
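# Dequantization helpers: asymmetric w ~= (w_q - zero_point) * scale, symmetric w ~= w_q * scale.
# When result_shape is set (grouped quantization), the result is reshaped back to the original layer shape.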
def decompress_asymmetric(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, dtype: torch.dtype, result_shape: torch.Size) -> torch.Tensor:
result = torch.mul(torch.sub(input.to(dtype=scale.dtype), zero_point), scale).to(dtype=dtype)
if result_shape is not None:
result = result.reshape(result_shape)
return result
def decompress_symmetric(input: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, result_shape: torch.Size) -> torch.Tensor:
result = torch.mul(input.to(dtype=scale.dtype), scale).to(dtype=dtype)
if result_shape is not None:
result = result.reshape(result_shape)
return result
def decompress_int4_asymmetric(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, shape: torch.Size, dtype: torch.dtype, result_shape: torch.Size) -> torch.Tensor:
return decompress_asymmetric(unpack_uint4(input, shape), scale, zero_point, dtype, result_shape)
def decompress_int4_symmetric(input: torch.Tensor, scale: torch.Tensor, shape: torch.Size, dtype: torch.dtype, result_shape: torch.Size) -> torch.Tensor:
return decompress_symmetric(unpack_int4(input, shape, dtype=scale.dtype), scale, dtype, result_shape)
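# Optionally compile the decompressors. The dynamo cache limit is raised because each distinct
# weight/scale shape across the model's layers produces its own compiled graph.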
if shared.opts.nncf_decompress_compile:
try:
torch._dynamo.config.cache_size_limit = max(8192, torch._dynamo.config.cache_size_limit) # pylint: disable=protected-access
decompress_asymmetric = torch.compile(decompress_asymmetric, fullgraph=True)
decompress_symmetric = torch.compile(decompress_symmetric, fullgraph=True)
decompress_int4_asymmetric = torch.compile(decompress_int4_asymmetric, fullgraph=True)
decompress_int4_symmetric = torch.compile(decompress_int4_symmetric, fullgraph=True)
except Exception as e:
shared.log.warning(f"Quantization: type=nncf Decompress using torch.compile is not available: {e}")
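# Each decompressor is attached to its layer as a pre-forward operation: forward() receives the layer,
# dequantizes layer.weight in place for the upcoming call, and pack_weight() is used once at
# compression time to store the quantized (and, for 4-bit, nibble-packed) weight.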
class INT8AsymmetricWeightsDecompressor(torch.nn.Module):
def __init__(
self,
scale: torch.Tensor,
zero_point: torch.Tensor,
result_dtype: torch.dtype,
result_shape: torch.Size,
):
super().__init__()
self.scale = scale
self.zero_point = zero_point
self.result_dtype = result_dtype
self.result_shape = result_shape
@property
def num_bits(self):
return 8
@property
def quantization_mode(self):
return "asymmetric"
def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if debug:
if torch.any((weight < 0) | (weight > 255)):
raise ValueError("Weight values are not in [0, 255].")
return weight.to(dtype=torch.uint8)
def forward(self, x, *args, return_decompressed_only=False):
result = decompress_asymmetric(x.weight, self.scale, self.zero_point, self.result_dtype, self.result_shape)
if return_decompressed_only:
return result
else:
x.weight = result
class INT8SymmetricWeightsDecompressor(torch.nn.Module):
def __init__(
self,
scale: torch.Tensor,
result_dtype: torch.dtype,
result_shape: torch.Size,
):
super().__init__()
self.scale = scale
self.result_dtype = result_dtype
self.result_shape = result_shape
@property
def num_bits(self):
return 8
@property
def quantization_mode(self):
return "symmetric"
def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if debug:
if torch.any((weight < -128) | (weight > 127)):
raise ValueError("Weight values are not in [-128, 127].")
return weight.to(dtype=torch.int8)
def forward(self, x, *args, return_decompressed_only=False):
result = decompress_symmetric(x.weight, self.scale, self.result_dtype, self.result_shape)
if return_decompressed_only:
return result
else:
x.weight = result
class INT4AsymmetricWeightsDecompressor(torch.nn.Module):
def __init__(
self,
scale: torch.Tensor,
zero_point: torch.Tensor,
compressed_weight_shape: torch.Size,
result_dtype: torch.dtype,
result_shape: torch.Size,
):
super().__init__()
self.scale = scale
self.zero_point = zero_point
self.compressed_weight_shape = compressed_weight_shape
self.result_dtype = result_dtype
self.result_shape = result_shape
@property
def num_bits(self):
return 4
@property
def quantization_mode(self):
return "asymmetric"
def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if debug:
if torch.any((weight < 0) | (weight > 15)):
raise ValueError("Weight values are not in [0, 15].")
return pack_uint4(weight.to(dtype=torch.uint8))
def forward(self, x, *args, return_decompressed_only=False):
result = decompress_int4_asymmetric(x.weight, self.scale, self.zero_point, self.compressed_weight_shape, self.result_dtype, self.result_shape)
if return_decompressed_only:
return result
else:
x.weight = result
class INT4SymmetricWeightsDecompressor(torch.nn.Module):
def __init__(
self,
scale: torch.Tensor,
compressed_weight_shape: torch.Size,
result_dtype: torch.dtype,
result_shape: torch.Size,
):
super().__init__()
self.scale = scale
self.compressed_weight_shape = compressed_weight_shape
self.result_dtype = result_dtype
self.result_shape = result_shape
@property
def num_bits(self):
return 4
@property
def quantization_mode(self):
return "symmetric"
def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if debug:
if torch.any((weight < -8) | (weight > 7)):
raise ValueError("Weight values are not in [-8, 7].")
return pack_int4(weight.to(dtype=torch.int8))
def forward(self, x, *args, return_decompressed_only=False):
result = decompress_int4_symmetric(x.weight, self.scale, self.compressed_weight_shape, self.result_dtype, self.result_shape)
if return_decompressed_only:
return result
else:
x.weight = result