1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/sdnq/quantizer.py
2026-01-14 16:23:26 +03:00

1131 lines
54 KiB
Python

# pylint: disable=redefined-builtin,no-member,protected-access
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
from enum import Enum
import re
import torch
from transformers.quantizers import HfQuantizer
from diffusers.quantizers.base import DiffusersQuantizer
from diffusers.quantizers.quantization_config import QuantizationConfigMixin
from diffusers.utils import get_module_from_name
from accelerate import init_empty_weights
from modules import devices, shared
from .common import sdnq_version, dtype_dict, common_skip_keys, module_skip_keys_dict, accepted_weight_dtypes, accepted_matmul_dtypes, weights_dtype_order, weights_dtype_order_fp32, allowed_types, linear_types, conv_types, conv_transpose_types, compile_func, use_tensorwise_fp8_matmul, use_contiguous_mm, check_torch_compile
from .dequantizer import SDNQDequantizer, dequantize_sdnq_model
from .packed_int import pack_int_symetric, pack_int_asymetric
from .packed_float import pack_float
from .forward import get_forward_func
from .layers import get_sdnq_wrapper_class
class QuantizationMethod(str, Enum):
    """String-valued identifiers for SDNQ quantization; ``sdnq_post_load_quant`` stores one on ``model.quantization_method``."""

    SDNQ = "sdnq"
    SDNQ_TRAINING = "sdnq_training"
@devices.inference_context()
def get_scale_asymmetric(weight: torch.FloatTensor, reduction_axes: Union[int, List[int]], weights_dtype: str) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Compute scale and zero point for asymmetric (unsigned) quantization.

    The returned pair maps the ``[min, max]`` range of ``weights_dtype`` onto the
    ``[amin, amax]`` range of ``weight`` along ``reduction_axes`` (dims kept), so that
    ``weight ~= q * scale + zero_point``.

    Returns:
        ``(scale, zero_point)``, both with the reduced dims kept as size 1.
    """
    type_info = dtype_dict[weights_dtype]
    q_min = type_info["min"]
    q_max = type_info["max"]
    zero_point = torch.amin(weight, dim=reduction_axes, keepdims=True)
    value_range = torch.amax(weight, dim=reduction_axes, keepdims=True).sub_(zero_point)
    scale = value_range.div_(q_max - q_min)
    # shift the zero point so that q == q_min maps back onto the minimum value
    if q_min != 0:
        zero_point.sub_(torch.mul(scale, q_min))
    return scale, zero_point
@devices.inference_context()
def get_scale_symmetric(weight: torch.FloatTensor, reduction_axes: Union[int, List[int]], weights_dtype: str) -> torch.FloatTensor:
    """Return ``max(|weight|) / dtype_max`` along ``reduction_axes`` (dims kept) for signed quantization."""
    abs_max = torch.amax(weight.abs(), dim=reduction_axes, keepdims=True)
    dtype_max = dtype_dict[weights_dtype]["max"]
    return abs_max.div_(dtype_max)
@devices.inference_context()
def quantize_weight(weight: torch.FloatTensor, reduction_axes: Union[int, List[int]], weights_dtype: str, use_stochastic_rounding: bool = False) -> Tuple[torch.Tensor, torch.FloatTensor, torch.FloatTensor]:
    """Quantize ``weight`` to ``weights_dtype`` with scales reduced over ``reduction_axes``.

    Unsigned dtypes use an asymmetric scale + zero point, signed dtypes a symmetric scale.
    Integer dtypes are rounded (optionally with a small zero-mean random dither); float
    dtypes can optionally get random low mantissa bits added before the final cast.

    Args:
        weight: tensor to quantize; it is converted to float32 first.
        reduction_axes: axis/axes over which one scale (and zero point) is shared.
        weights_dtype: key into ``dtype_dict`` describing the target dtype.
        use_stochastic_rounding: add random noise before rounding/casting.

    Returns:
        ``(quantized_weight, scale, zero_point)``; ``zero_point`` is None for signed dtypes.
        ``quantized_weight`` is clamped to the dtype range and cast to its ``torch_dtype``.
    """
    weight = weight.to(dtype=torch.float32)
    if dtype_dict[weights_dtype]["is_unsigned"]:
        scale, zero_point = get_scale_asymmetric(weight, reduction_axes, weights_dtype)
        quantized_weight = torch.sub(weight, zero_point).div_(scale)
    else:
        scale = get_scale_symmetric(weight, reduction_axes, weights_dtype)
        quantized_weight = torch.div(weight, scale)
        zero_point = None
    if dtype_dict[weights_dtype]["is_integer"]:
        if use_stochastic_rounding:
            # zero-mean dither (std 0.1); uniform [0, 1) noise would have a positive mean
            # and bias every rounded value upward
            quantized_weight.add_(torch.randn_like(quantized_weight), alpha=0.1)
        quantized_weight.round_()
    else:
        if use_stochastic_rounding:
            # add random bits below the target mantissa width to the float32 bit pattern
            mantissa_difference = 1 << (23 - dtype_dict[weights_dtype]["mantissa"])
            quantized_weight = quantized_weight.view(dtype=torch.int32).add_(torch.randint_like(quantized_weight, low=0, high=mantissa_difference, dtype=torch.int32)).view(dtype=torch.float32)
    # a constant group gives scale == 0 and 0 / 0 == NaN; map it to 0 on both the float and
    # integer paths so no NaN reaches the integer cast (the dequantized value is 0 * scale anyway)
    quantized_weight.nan_to_num_()
    quantized_weight = quantized_weight.clamp_(dtype_dict[weights_dtype]["min"], dtype_dict[weights_dtype]["max"]).to(dtype_dict[weights_dtype]["torch_dtype"])
    return quantized_weight, scale, zero_point
@devices.inference_context()
def apply_svdquant(weight: torch.FloatTensor, rank: int = 32, niter: int = 8) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Split off a rank-``rank`` approximation of ``weight`` via ``torch.svd_lowrank``.

    Tensors with more than two dims (conv kernels) are flattened to
    ``(out, prod(rest))`` for the decomposition, and the residual is reshaped back.

    Returns:
        ``(residual, svd_up, svd_down)`` with ``weight ~= svd_up @ svd_down + residual``;
        all float32, ``svd_up`` is ``(out, rank)`` and ``svd_down`` is ``(rank, prod(rest))``.
    """
    conv_shape = weight.shape if weight.ndim > 2 else None
    if conv_shape is not None:
        weight = weight.flatten(1, -1)
    weight = weight.to(dtype=torch.float32)
    left, singular_values, right = torch.svd_lowrank(weight, q=rank, niter=niter)
    svd_up = torch.mul(left, singular_values.unsqueeze(0))
    svd_down = right.t_()
    residual = weight.sub(torch.mm(svd_up, svd_down))
    if conv_shape is not None:
        residual = residual.unflatten(-1, tuple(conv_shape[1:]))
    return residual, svd_up, svd_down
@devices.inference_context()
def prepare_weight_for_matmul(weight: torch.Tensor) -> torch.Tensor:
    """Return ``weight`` in the memory layout used for matmul.

    With ``use_contiguous_mm`` the result is row-major contiguous. Otherwise a
    contiguous (2-D) tensor is re-laid-out column-major (same shape, transposed strides),
    and a tensor that is already non-contiguous is returned unchanged.

    NOTE: in the column-major case ``t_()`` transposes the passed-in tensor object in place.
    """
    if use_contiguous_mm:
        return weight.contiguous()
    if not weight.is_contiguous():
        return weight
    return weight.t_().contiguous().t_()
@devices.inference_context()
def prepare_svd_for_matmul(svd_up: torch.FloatTensor, svd_down: torch.FloatTensor, use_quantized_matmul: bool) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Lay out the SVD factors for matmul; either factor may be None and is then passed through.

    ``svd_down`` always goes through ``prepare_weight_for_matmul``. ``svd_up`` does so only
    when ``use_quantized_matmul`` is set, otherwise it is made plain contiguous.
    """
    if svd_up is not None:
        svd_up = prepare_weight_for_matmul(svd_up) if use_quantized_matmul else svd_up.contiguous()
    if svd_down is not None:
        svd_down = prepare_weight_for_matmul(svd_down)
    return svd_up, svd_down
def check_param_name_in(param_name: str, param_list: List[str]) -> bool:
    """Return True if ``param_name`` matches any pattern in ``param_list``.

    Pattern forms:
        * ``".prefix"``: a leading dot means "``param_name`` starts with ``prefix``"; no
          other rule is tried for such a pattern.
        * exact match of the whole dotted name.
        * match of any single dot-separated component (e.g. ``"attn"`` matches ``"a.attn.weight"``).
        * patterns containing ``*``: turned into a regex (``".*"`` means a literal dot followed
          by anything, any other ``*`` means anything) and matched at the start of the name.
    """
    name_parts = param_name.split(".")
    for pattern in param_list:
        if pattern.startswith("."):
            if param_name.startswith(pattern[1:]):
                return True
            continue
        if param_name == pattern or pattern in name_parts:
            return True
        if "*" in pattern:
            regex = pattern.replace(".*", "\\.*").replace("*", ".*")
            if re.match(regex, param_name):
                return True
    return False
def get_quant_args_from_config(quantization_config: Union["SDNQConfig", dict]) -> dict:
    """Turn an ``SDNQConfig`` (or its dict form) into keyword arguments for quantization.

    An ``SDNQConfig`` is converted with ``to_dict()``; a dict is shallow-copied so the
    input is not modified. Bookkeeping and loader-level keys (method tag, devices,
    non_blocking, skip-key/dynamic/static/stochastic/grad-ckpt/training flags, version)
    are removed.
    """
    if isinstance(quantization_config, SDNQConfig):
        quant_args = quantization_config.to_dict()
    else:
        quant_args = quantization_config.copy()
    excluded_keys = (
        "is_integer",
        "quant_method",
        "quantization_device",
        "return_device",
        "non_blocking",
        "add_skip_keys",
        "use_dynamic_quantization",
        "use_static_quantization",
        "use_stochastic_rounding",
        "use_grad_ckpt",
        "is_training",
        "sdnq_version",
    )
    for excluded_key in excluded_keys:
        quant_args.pop(excluded_key, None)
    return quant_args
def get_minimum_dtype(weights_dtype: str, param_name: str, modules_dtype_dict: Dict[str, List[str]]):
    """Resolve the weights dtype for one parameter from per-module dtype overrides.

    Override keys are scanned in dict order. For the first key whose pattern list matches
    ``param_name`` (see ``check_param_name_in``) and yields a decision:

    * ``"8bit"`` / ``"8bits"`` -> ``"int8"`` unless ``weights_dtype`` is already 8 bits
      (then scanning continues).
    * ``"minimum_<N>bit(s)"``, ``"minimum_int<N>bit(s)"``, ``"minimum_uint<N>bit(s)"`` ->
      ``"uint<N>"`` / ``"int<N>"`` when ``weights_dtype`` has fewer than N bits; unsigned is
      chosen for ``uint`` keys and for N <= 4. Otherwise scanning continues.
    * any other key -> returned (lower-cased) as the dtype.

    Returns ``weights_dtype`` when no override decides.
    """
    for override_key, patterns in modules_dtype_dict.items():
        if not check_param_name_in(param_name, patterns):
            continue
        override_key = override_key.lower()
        if override_key in {"8bit", "8bits"}:
            if dtype_dict[weights_dtype]["num_bits"] != 8:
                return "int8"
            continue
        if not override_key.startswith("minimum_"):
            return override_key
        bits_spec = override_key.removeprefix("minimum_").removesuffix("bits").removesuffix("bit")
        wants_unsigned = bits_spec.startswith("uint")
        bits_spec = bits_spec.removeprefix("uint") if wants_unsigned else bits_spec.removeprefix("int")
        minimum_bits = int(bits_spec)
        if dtype_dict[weights_dtype]["num_bits"] < minimum_bits:
            prefix = "uint" if (wants_unsigned or minimum_bits <= 4) else "int"
            return prefix + bits_spec
    return weights_dtype
def add_module_skip_keys(model, modules_to_not_convert: List[str] = None, modules_dtype_dict: Dict[str, List[str]] = None):
    """Merge the model-specific skip patterns and dtype overrides into the given containers.

    Adds, in order:
        * ``model._keep_in_fp32_modules``;
        * ``model._tied_weights_keys`` (both keys and values when it is a dict);
        * the entry for the model's class name in ``module_skip_keys_dict``: its skip list
          and its per-dtype overrides, or ``common_skip_keys`` when the class has no entry;
        * ``model._skip_layerwise_casting_patterns``.

    ``modules_to_not_convert`` is extended in place. Keys of ``modules_dtype_dict`` are
    assigned in place, but the per-dtype pattern lists are replaced by new lists rather than
    extended. Lists shared with the caller (e.g. through a shallow ``dict.copy()``) or with
    ``module_skip_keys_dict`` are therefore never mutated.

    Returns:
        ``(model, modules_to_not_convert, modules_dtype_dict)``, with every pattern list
        de-duplicated (order is not preserved).
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    if modules_dtype_dict is None:
        modules_dtype_dict = {}
    if getattr(model, "_keep_in_fp32_modules", None) is not None:
        modules_to_not_convert.extend(model._keep_in_fp32_modules) # pylint: disable=protected-access
    tied_weights_keys = getattr(model, "_tied_weights_keys", None)
    if tied_weights_keys is not None:
        if isinstance(tied_weights_keys, dict):
            modules_to_not_convert.extend(tied_weights_keys.keys())
            modules_to_not_convert.extend(tied_weights_keys.values())
        else:
            modules_to_not_convert.extend(tied_weights_keys)
    skip_key_list = module_skip_keys_dict.get(model.__class__.__name__, None)
    if skip_key_list is not None:
        modules_to_not_convert.extend(skip_key_list[0])
        for key, value in skip_key_list[1].items():
            # build a fresh list: the existing one may belong to the caller, and `value`
            # belongs to the module-level skip table
            if key in modules_dtype_dict:
                modules_dtype_dict[key] = list(modules_dtype_dict[key]) + list(value)
            else:
                modules_dtype_dict[key] = list(value)
    else:
        modules_to_not_convert.extend(common_skip_keys)
    if getattr(model, "_skip_layerwise_casting_patterns", None) is not None:
        modules_to_not_convert.extend(model._skip_layerwise_casting_patterns) # pylint: disable=protected-access
    # dedupe
    modules_to_not_convert = list(set(modules_to_not_convert))
    for key, value in modules_dtype_dict.items():
        modules_dtype_dict[key] = list(set(value))
    return model, modules_to_not_convert, modules_dtype_dict
@devices.inference_context()
def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, use_svd=False, use_quantized_matmul=False, use_stochastic_rounding=False, dequantize_fp32=False, using_pre_calculated_svd=False, param_name=None): # pylint: disable=unused-argument
    """Quantize one layer weight and build the matching ``SDNQDequantizer``.

    Args:
        weight: the layer weight; it is detached before use.
        layer_class_name: owning layer's class name. Membership in ``conv_types``,
            ``conv_transpose_types`` or ``linear_types`` selects the channel/reduction axes;
            any other type is reduced over its last dim.
        weights_dtype: ``dtype_dict`` key of the storage dtype.
        quantized_matmul_dtype: ``dtype_dict`` key of the matmul dtype. None selects
            ``"int8"`` for integer weights, ``"float8_e4m3fn"`` for 8-bit float weights and
            ``"float16"`` otherwise.
        torch_dtype: result dtype of the dequantized weight (default: ``weight.dtype``).
        group_size: 0 picks a default from the bit width; a positive value is adjusted so
            it evenly divides ``channel_size`` (falling back to a single group); a negative
            value disables grouping.
        svd_rank, svd_steps: forwarded to ``apply_svdquant``.
        use_svd: subtract a low-rank component first. Any exception during SVD silently
            drops the low-rank factors.
        use_quantized_matmul: request the quantized-matmul path. It is forced off for
            conv-transpose and unknown layer types, and for channel sizes < 32 or not
            divisible by 16.
        use_stochastic_rounding: forwarded to ``quantize_weight``.
        dequantize_fp32: keep scale / zero point / SVD factors in float32.
        using_pre_calculated_svd: caller already removed an SVD component; only affects
            the default group size.
        param_name: used only in the error message.

    Returns:
        ``(weight, scale, zero_point, svd_up, svd_down, sdnq_dequantizer)``. ``weight`` is the
        quantized (possibly packed) tensor. ``zero_point`` is None for signed dtypes, and
        ``svd_up``/``svd_down`` are None when SVD is unused or failed.

    Raises:
        ValueError: a ``linear_types`` weight whose shape does not unpack into two dims.
    """
    num_of_groups = 1
    is_conv_type = False
    is_conv_transpose_type = False
    is_linear_type = False
    result_shape = None
    # the incoming layout is handed to the dequantizer
    original_shape = weight.shape
    original_stride = weight.stride()
    weight = weight.detach()
    if torch_dtype is None:
        torch_dtype = weight.dtype
    if quantized_matmul_dtype is None:
        if dtype_dict[weights_dtype]["is_integer"]:
            quantized_matmul_dtype = "int8"
        elif dtype_dict[weights_dtype]["num_bits"] == 8:
            quantized_matmul_dtype = "float8_e4m3fn"
        else:
            quantized_matmul_dtype = "float16"
    # True when the stored weight cannot be fed to the quantized matmul as-is. That is the
    # case for unsigned (asymmetric) weights, an integer/float mismatch with the matmul dtype,
    # a wider weight than the matmul dtype, or a packed float that does not fit the matmul
    # float dtype. Grouping (num_of_groups > 1, below) also forces it.
    re_quantize_for_matmul = bool(
        dtype_dict[weights_dtype]["is_unsigned"]
        or dtype_dict[weights_dtype]["is_integer"] != dtype_dict[quantized_matmul_dtype]["is_integer"]
        or dtype_dict[weights_dtype]["num_bits"] > dtype_dict[quantized_matmul_dtype]["num_bits"]
        or (
            dtype_dict[weights_dtype]["is_packed"]
            and not dtype_dict[weights_dtype]["is_integer"]
            and not dtype_dict[quantized_matmul_dtype]["is_integer"]
            and (
                dtype_dict[weights_dtype]["num_bits"] >= dtype_dict[quantized_matmul_dtype]["num_bits"]
                or dtype_dict[weights_dtype]["max"] > dtype_dict[quantized_matmul_dtype]["max"]
            )
        )
    )
    if layer_class_name in conv_types:
        is_conv_type = True
        reduction_axes = 1
        output_channel_size, channel_size = weight.shape[:2]
        if use_quantized_matmul:
            use_quantized_matmul = channel_size >= 32 and output_channel_size >= 32
            use_quantized_matmul = use_quantized_matmul and output_channel_size % 16 == 0 and channel_size % 16 == 0
        if use_quantized_matmul and not re_quantize_for_matmul and not dtype_dict[weights_dtype]["is_packed"]:
            # matmul-ready conv weights are quantized as a 2-D (out, in * kernel) matrix;
            # result_shape keeps the conv shape for the dequantizer
            result_shape = weight.shape
            weight = weight.flatten(1,-1)
            reduction_axes = -1
    elif layer_class_name in conv_transpose_types:
        is_conv_transpose_type = True
        reduction_axes = 0
        # conv-transpose weights have the input channels first
        channel_size, output_channel_size = weight.shape[:2]
        use_quantized_matmul = False
    elif layer_class_name in linear_types:
        is_linear_type = True
        reduction_axes = -1
        try:
            output_channel_size, channel_size = weight.shape
        except Exception as e:
            raise ValueError(f"SDNQ: param_name={param_name} layer_class_name={layer_class_name} weight_shape={weight.shape} weights_dtype={weights_dtype} quantized_matmul_dtype={quantized_matmul_dtype} unsupported") from e
        if use_quantized_matmul:
            use_quantized_matmul = channel_size >= 32 and output_channel_size >= 32
            use_quantized_matmul = use_quantized_matmul and output_channel_size % 16 == 0 and channel_size % 16 == 0
    else:
        # any other layer type: quantize along the last dim, no quantized matmul
        if weight.ndim > 1:
            output_channel_size, channel_size = weight.shape[-2:]
        else:
            output_channel_size, channel_size = 1, weight.shape[-1]
        reduction_axes = -1
        use_quantized_matmul = False
    if use_svd:
        try:
            weight, svd_up, svd_down = apply_svdquant(weight, rank=svd_rank, niter=svd_steps)
            if use_quantized_matmul:
                svd_up = svd_up.t_()
                svd_down = svd_down.t_()
            svd_up, svd_down = prepare_svd_for_matmul(svd_up, svd_down, use_quantized_matmul)
        except Exception:
            # best effort: fall back to plain quantization without the low-rank component
            svd_up, svd_down = None, None
    else:
        svd_up, svd_down = None, None
    if group_size == 0:
        # default: no grouping for >= 6-bit weights that go straight into the quantized
        # matmul; otherwise 2 ** (bits + k), with a larger k for linear layers and when an
        # SVD component has been removed
        if use_quantized_matmul and not re_quantize_for_matmul and dtype_dict[weights_dtype]["num_bits"] >= 6:
            group_size = -1
        elif is_linear_type:
            group_size = 2 ** ((3 if (svd_up is not None or using_pre_calculated_svd) else 2) + dtype_dict[weights_dtype]["num_bits"])
        else:
            group_size = 2 ** ((2 if (svd_up is not None or using_pre_calculated_svd) else 1) + dtype_dict[weights_dtype]["num_bits"])
    if group_size > 0:
        if group_size >= channel_size:
            group_size = channel_size
            num_of_groups = 1
        else:
            num_of_groups = channel_size // group_size
            # reduce the group count until it divides channel_size; give up at one group
            while num_of_groups * group_size != channel_size: # find something divisible
                num_of_groups -= 1
                if num_of_groups <= 1:
                    group_size = channel_size
                    num_of_groups = 1
                    break
                group_size = channel_size // num_of_groups
        group_size = int(group_size)
        num_of_groups = int(num_of_groups)
        if num_of_groups > 1:
            if result_shape is None:
                result_shape = weight.shape
            new_shape = list(result_shape)
            if is_conv_type:
                # output_channel_size, channel_size, X, X
                # output_channel_size, num_of_groups, group_size, X, X
                new_shape[1] = group_size
                new_shape.insert(1, num_of_groups)
                reduction_axes = 2
            elif is_conv_transpose_type:
                #channel_size, output_channel_size, X, X
                #num_of_groups, group_size, output_channel_size, X, X
                new_shape[0] = group_size
                new_shape.insert(0, num_of_groups)
                reduction_axes = 1
            else:
                # output_channel_size, channel_size
                # output_channel_size, num_of_groups, group_size
                last_dim_index = weight.ndim
                new_shape[last_dim_index - 1 : last_dim_index] = (num_of_groups, group_size)
            weight = weight.reshape(new_shape)
        else:
            group_size = -1
    weight, scale, zero_point = quantize_weight(weight, reduction_axes, weights_dtype, use_stochastic_rounding=use_stochastic_rounding)
    # scale / zero point / SVD factors are kept in float32 for dequantize_fp32, > 8-bit weights,
    # and float quantized matmul unless it is tensorwise fp8 (16-bit float matmul always keeps float32)
    if (
        not dequantize_fp32
        and dtype_dict[weights_dtype]["num_bits"] <= 8
        and not (
            use_quantized_matmul
            and not dtype_dict[quantized_matmul_dtype]["is_integer"]
            and (not use_tensorwise_fp8_matmul or dtype_dict[quantized_matmul_dtype]["num_bits"] == 16)
        )
    ):
        scale = scale.to(dtype=torch_dtype)
        if zero_point is not None:
            zero_point = zero_point.to(dtype=torch_dtype)
        if svd_up is not None:
            svd_up = svd_up.to(dtype=torch_dtype)
            svd_down = svd_down.to(dtype=torch_dtype)
    re_quantize_for_matmul = re_quantize_for_matmul or num_of_groups > 1
    if use_quantized_matmul and not re_quantize_for_matmul and not dtype_dict[weights_dtype]["is_packed"]:
        # weight and scale are stored transposed (in-place t_()) in the matmul layout
        scale.t_()
        weight.t_()
        weight = prepare_weight_for_matmul(weight)
        if not use_tensorwise_fp8_matmul and not dtype_dict[quantized_matmul_dtype]["is_integer"]:
            scale = scale.to(dtype=torch.float32)
    # quantized_weight_shape is taken before packing below
    sdnq_dequantizer = SDNQDequantizer(
        result_dtype=torch_dtype,
        result_shape=result_shape,
        original_shape=original_shape,
        original_stride=original_stride,
        quantized_weight_shape=weight.shape,
        weights_dtype=weights_dtype,
        quantized_matmul_dtype=quantized_matmul_dtype,
        group_size=group_size,
        svd_rank=svd_rank,
        svd_steps=svd_steps,
        use_quantized_matmul=use_quantized_matmul,
        re_quantize_for_matmul=re_quantize_for_matmul,
        use_stochastic_rounding=use_stochastic_rounding,
        layer_class_name=layer_class_name,
    )
    if dtype_dict[weights_dtype]["is_packed"]:
        if dtype_dict[weights_dtype]["is_integer"]:
            if dtype_dict[weights_dtype]["is_unsigned"]:
                weight = pack_int_asymetric(weight, weights_dtype)
            else:
                weight = pack_int_symetric(weight, weights_dtype)
        else:
            weight = pack_float(weight, weights_dtype)
    else:
        weight = weight.to(dtype=dtype_dict[weights_dtype]["torch_dtype"])
    return weight, scale, zero_point, svd_up, svd_down, sdnq_dequantizer
@devices.inference_context()
def sdnq_quantize_layer_weight_dynamic(weight, layer_class_name=None, weights_dtype="int2", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, use_quantized_matmul=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=False, param_name=None): # pylint: disable=unused-argument
    """Quantize ``weight`` with the first dtype whose reconstruction loss is acceptable.

    Candidates come from ``weights_dtype_order`` (``weights_dtype_order_fp32`` when
    ``torch_dtype`` is float32/float64), starting at ``weights_dtype``. Each candidate is
    quantized with ``sdnq_quantize_layer_weight`` and dequantized again. The MSE against the
    float32 weight, divided by the weight variance, is compared with ``dynamic_loss_threshold``.

    When ``use_svd`` is set, one low-rank component is removed up front and shared by all
    candidates. If the SVD fails, quantization proceeds without it.

    Returns:
        The ``sdnq_quantize_layer_weight`` result tuple for the first passing candidate, with
        the shared SVD factors (or None) in the svd slots, or None if no candidate passes.
    """
    if torch_dtype is None:
        torch_dtype = weight.dtype
    weights_dtype_order_to_use = weights_dtype_order_fp32 if torch_dtype in {torch.float32, torch.float64} else weights_dtype_order
    weight = weight.to(dtype=torch.float32)
    # variance of the weight, used to normalise the reconstruction loss
    weight_std = weight.std().square()
    if use_svd:
        try:
            svd_weight, svd_up, svd_down = apply_svdquant(weight, rank=svd_rank, niter=svd_steps)
            svd_up, svd_down = prepare_svd_for_matmul(svd_up, svd_down, use_quantized_matmul)
            svd_up = svd_up.to(dtype=torch_dtype)
            svd_down = svd_down.to(dtype=torch_dtype)
        except Exception:
            svd_up, svd_down = None, None
            svd_weight = weight
    else:
        svd_up, svd_down = None, None
        svd_weight = weight
    # the SVD above may have failed: only treat the residual as SVD-compensated (and only
    # touch the factors) when they actually exist
    has_svd = svd_up is not None
    svd_is_transposed = False
    start_index = weights_dtype_order_to_use.index(weights_dtype)
    for candidate_dtype in weights_dtype_order_to_use[start_index:]:
        quantized_weight, scale, zero_point, _, _, sdnq_dequantizer = sdnq_quantize_layer_weight(
            svd_weight,
            layer_class_name=layer_class_name,
            weights_dtype=candidate_dtype,
            quantized_matmul_dtype=quantized_matmul_dtype,
            torch_dtype=torch_dtype,
            group_size=group_size,
            svd_rank=svd_rank,
            svd_steps=svd_steps,
            use_svd=False,
            using_pre_calculated_svd=has_svd,
            use_quantized_matmul=use_quantized_matmul,
            use_stochastic_rounding=use_stochastic_rounding,
            dequantize_fp32=dequantize_fp32,
            param_name=param_name,
        )
        # the quantized-matmul layout stores the SVD factors transposed; do it once
        if has_svd and not svd_is_transposed and sdnq_dequantizer.use_quantized_matmul:
            svd_up = svd_up.t_()
            svd_down = svd_down.t_()
            svd_is_transposed = True
        quantization_loss = torch.nn.functional.mse_loss(weight, sdnq_dequantizer(quantized_weight, scale, zero_point, svd_up, svd_down, skip_quantized_matmul=sdnq_dequantizer.use_quantized_matmul, dtype=torch.float32, skip_compile=True)).div_(weight_std)
        if quantization_loss <= dynamic_loss_threshold:
            return (quantized_weight, scale, zero_point, svd_up, svd_down, sdnq_dequantizer)
    return None
@devices.inference_context()
def sdnq_quantize_layer(layer, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=False, non_blocking=False, modules_to_not_convert=None, modules_dtype_dict=None, quantization_device=None, return_device=None, param_name=None): # pylint: disable=unused-argument
    """Quantize ``layer.weight`` and wrap ``layer`` via ``get_sdnq_wrapper_class``.

    Conv and conv-transpose layers are returned untouched unless ``quant_conv`` is set. When
    they are quantized, ``use_quantized_matmul_conv`` replaces ``use_quantized_matmul``. The
    weight is moved to ``quantization_device`` (if given) for the computation. Weight,
    scale, zero point, SVD factors and the layer end up on ``return_device``, which defaults
    to the weight's device on entry. All of them are frozen ``Parameter``s.

    With ``use_dynamic_quantization``:
        * the chosen dtype is recorded by appending ``param_name`` to
          ``modules_dtype_dict[<dtype>]``;
        * if no dtype meets ``dynamic_loss_threshold``, the layer is only moved/cast to
          ``torch_dtype`` and ``param_name`` is appended to ``modules_to_not_convert``.

    Returns:
        ``(layer, modules_to_not_convert, modules_dtype_dict)``. ``layer`` is the object
        returned by ``get_sdnq_wrapper_class``, so callers should use the returned value.
    """
    layer_class_name = layer.__class__.__name__
    if layer_class_name in conv_transpose_types or layer_class_name in conv_types:
        if not quant_conv:
            return layer, modules_to_not_convert, modules_dtype_dict
        use_quantized_matmul = use_quantized_matmul_conv
    layer.weight.requires_grad_(False)
    if return_device is None:
        return_device = layer.weight.device
    if quantization_device is not None:
        layer.weight.data = layer.weight.to(quantization_device, non_blocking=non_blocking)
    if use_dynamic_quantization:
        # may return None when no candidate dtype meets the loss threshold
        weight_data = sdnq_quantize_layer_weight_dynamic(
            layer.weight,
            layer_class_name=layer_class_name,
            weights_dtype=weights_dtype,
            quantized_matmul_dtype=quantized_matmul_dtype,
            torch_dtype=torch_dtype,
            group_size=group_size,
            svd_rank=svd_rank,
            svd_steps=svd_steps,
            dynamic_loss_threshold=dynamic_loss_threshold,
            use_svd=use_svd,
            use_quantized_matmul=use_quantized_matmul,
            use_stochastic_rounding=use_stochastic_rounding,
            dequantize_fp32=dequantize_fp32,
            param_name=param_name,
        )
    else:
        weight_data = sdnq_quantize_layer_weight(
            layer.weight,
            layer_class_name=layer_class_name,
            weights_dtype=weights_dtype,
            quantized_matmul_dtype=quantized_matmul_dtype,
            torch_dtype=torch_dtype,
            group_size=group_size,
            svd_rank=svd_rank,
            svd_steps=svd_steps,
            use_svd=use_svd,
            use_quantized_matmul=use_quantized_matmul,
            use_stochastic_rounding=use_stochastic_rounding,
            dequantize_fp32=dequantize_fp32,
            param_name=param_name,
        )
    if weight_data is not None:
        # attach the quantized tensors and the dequantizer directly onto the layer
        (
            layer.weight.data,
            layer.scale, layer.zero_point,
            layer.svd_up, layer.svd_down,
            layer.sdnq_dequantizer,
        ) = weight_data
        del weight_data
        # the forward function depends on the dequantizer's final matmul dtype/decision
        layer = get_sdnq_wrapper_class(layer, get_forward_func(layer_class_name, layer.sdnq_dequantizer.quantized_matmul_dtype, layer.sdnq_dequantizer.use_quantized_matmul))
        layer.weight = torch.nn.Parameter(layer.weight.to(return_device, non_blocking=non_blocking), requires_grad=False)
        layer.scale = torch.nn.Parameter(layer.scale.to(return_device, non_blocking=non_blocking), requires_grad=False)
        if layer.zero_point is not None:
            layer.zero_point = torch.nn.Parameter(layer.zero_point.to(return_device, non_blocking=non_blocking), requires_grad=False)
        if layer.svd_up is not None:
            layer.svd_up = torch.nn.Parameter(layer.svd_up.to(return_device, non_blocking=non_blocking), requires_grad=False)
            layer.svd_down = torch.nn.Parameter(layer.svd_down.to(return_device, non_blocking=non_blocking), requires_grad=False)
        layer = layer.to(return_device, non_blocking=non_blocking)
        if use_dynamic_quantization:
            # record which dtype the dynamic search picked for this parameter
            if modules_dtype_dict is None:
                modules_dtype_dict = {}
            if layer.sdnq_dequantizer.weights_dtype not in modules_dtype_dict.keys():
                modules_dtype_dict[layer.sdnq_dequantizer.weights_dtype] = [param_name]
            else:
                modules_dtype_dict[layer.sdnq_dequantizer.weights_dtype].append(param_name)
    else:
        # dynamic search found no acceptable dtype: keep the layer unquantized
        layer = layer.to(return_device, dtype=torch_dtype, non_blocking=non_blocking)
        if use_dynamic_quantization:
            if modules_to_not_convert is None:
                modules_to_not_convert = []
            modules_to_not_convert.append(param_name)
    return layer, modules_to_not_convert, modules_dtype_dict
@devices.inference_context()
def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=False, non_blocking=False, modules_to_not_convert: List[str] = None, modules_dtype_dict: Dict[str, List[str]] = None, quantization_device=None, return_device=None, full_param_name=""): # pylint: disable=unused-argument
    """Recursively quantize every eligible descendant layer of ``model``.

    Each child from ``named_children()`` is named ``<full_param_name>.<child>``, with
    ``.weight`` appended when it has a non-None ``weight``. A child whose name matches
    ``modules_to_not_convert`` is skipped together with its own children. Conv layers are
    also skipped that way when ``quant_conv`` is off. Layers of ``allowed_types`` with a
    float32/float16/bfloat16 weight are quantized by ``sdnq_quantize_layer`` with the
    dtype from ``get_minimum_dtype``. Every remaining child is then recursed into and
    re-attached with ``setattr``.

    Returns:
        ``(model, modules_to_not_convert, modules_dtype_dict)``. A ``model`` without
        children is returned immediately with the containers as passed in (possibly None).
    """
    has_children = list(model.children())
    if not has_children:
        return model, modules_to_not_convert, modules_dtype_dict
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    if modules_dtype_dict is None:
        modules_dtype_dict = {}
    for module_name, module in model.named_children():
        if full_param_name:
            param_name = full_param_name + "." + module_name
        else:
            param_name = module_name
        if hasattr(module, "weight") and module.weight is not None:
            param_name = param_name + ".weight"
            if check_param_name_in(param_name, modules_to_not_convert):
                continue
            layer_class_name = module.__class__.__name__
            if layer_class_name in allowed_types and module.weight.dtype in {torch.float32, torch.float16, torch.bfloat16}:
                if (layer_class_name in conv_types or layer_class_name in conv_transpose_types) and not quant_conv:
                    continue
                module, modules_to_not_convert, modules_dtype_dict = sdnq_quantize_layer(
                    module,
                    weights_dtype=get_minimum_dtype(weights_dtype, param_name, modules_dtype_dict),
                    quantized_matmul_dtype=quantized_matmul_dtype,
                    torch_dtype=torch_dtype,
                    group_size=group_size,
                    svd_rank=svd_rank,
                    svd_steps=svd_steps,
                    dynamic_loss_threshold=dynamic_loss_threshold,
                    use_svd=use_svd,
                    quant_conv=quant_conv,
                    use_quantized_matmul=use_quantized_matmul,
                    use_quantized_matmul_conv=use_quantized_matmul_conv,
                    use_dynamic_quantization=use_dynamic_quantization,
                    use_stochastic_rounding=use_stochastic_rounding,
                    dequantize_fp32=dequantize_fp32,
                    non_blocking=non_blocking,
                    quantization_device=quantization_device,
                    return_device=return_device,
                    modules_to_not_convert=modules_to_not_convert,
                    modules_dtype_dict=modules_dtype_dict,
                    param_name=param_name,
                )
                setattr(model, module_name, module)
        # NOTE(review): for a child that has a weight, param_name still ends in ".weight" here,
        # so its own children are named "<child>.weight.<grandchild>..." — confirm this is intended
        module, modules_to_not_convert, modules_dtype_dict = apply_sdnq_to_module(
            module,
            dynamic_loss_threshold=dynamic_loss_threshold,
            weights_dtype=weights_dtype,
            quantized_matmul_dtype=quantized_matmul_dtype,
            torch_dtype=torch_dtype,
            group_size=group_size,
            svd_rank=svd_rank,
            svd_steps=svd_steps,
            use_svd=use_svd,
            quant_conv=quant_conv,
            use_quantized_matmul=use_quantized_matmul,
            use_quantized_matmul_conv=use_quantized_matmul_conv,
            use_dynamic_quantization=use_dynamic_quantization,
            use_stochastic_rounding=use_stochastic_rounding,
            dequantize_fp32=dequantize_fp32,
            non_blocking=non_blocking,
            quantization_device=quantization_device,
            return_device=return_device,
            modules_to_not_convert=modules_to_not_convert,
            modules_dtype_dict=modules_dtype_dict,
            full_param_name=param_name,
        )
        setattr(model, module_name, module)
    return model, modules_to_not_convert, modules_dtype_dict
@devices.inference_context()
def sdnq_post_load_quant(
    model: torch.nn.Module,
    weights_dtype: str = "int8",
    quantized_matmul_dtype: str = None,
    torch_dtype: torch.dtype = None,
    group_size: int = 0,
    svd_rank: int = 32,
    svd_steps: int = 8,
    dynamic_loss_threshold: float = 1e-2,
    use_svd: bool = False,
    quant_conv: bool = False,
    use_quantized_matmul: bool = False,
    use_quantized_matmul_conv: bool = False,
    use_dynamic_quantization: bool = False,
    use_stochastic_rounding: bool = False,
    dequantize_fp32: bool = False,
    non_blocking: bool = False,
    add_skip_keys:bool = True,
    modules_to_not_convert: List[str] = None,
    modules_dtype_dict: Dict[str, List[str]] = None,
    quantization_device: Optional[torch.device] = None,
    return_device: Optional[torch.device] = None,
):
    """Quantize an already-loaded ``model`` in place and attach an ``SDNQConfig`` to it.

    ``modules_to_not_convert`` and ``modules_dtype_dict`` are shallow-copied first. With
    ``add_skip_keys`` the model-specific patterns from ``add_module_skip_keys`` are merged
    in. The model is put in eval mode and quantized with ``apply_sdnq_to_module``.

    The final skip list and dtype overrides, which dynamic quantization may extend, are
    stored on the config. The config is exposed as ``model.quantization_config`` and, on a
    best-effort basis, as ``model.config.quantization_config`` / ``model.config["quantization_config"]``.

    Returns:
        The same ``model`` object with ``quantization_method`` set to ``QuantizationMethod.SDNQ``.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    if modules_dtype_dict is None:
        modules_dtype_dict = {}
    modules_to_not_convert = modules_to_not_convert.copy()
    modules_dtype_dict = modules_dtype_dict.copy()
    if add_skip_keys:
        model, modules_to_not_convert, modules_dtype_dict = add_module_skip_keys(model, modules_to_not_convert, modules_dtype_dict)
    quantization_config = SDNQConfig(
        weights_dtype=weights_dtype,
        # record the matmul dtype as well: it changes the stored weight layout, and reloading
        # rebuilds the dequantizers from this config
        quantized_matmul_dtype=quantized_matmul_dtype,
        group_size=group_size,
        svd_rank=svd_rank,
        svd_steps=svd_steps,
        dynamic_loss_threshold=dynamic_loss_threshold,
        use_svd=use_svd,
        quant_conv=quant_conv,
        use_quantized_matmul=use_quantized_matmul,
        use_quantized_matmul_conv=use_quantized_matmul_conv,
        use_dynamic_quantization=use_dynamic_quantization,
        use_stochastic_rounding=use_stochastic_rounding,
        dequantize_fp32=dequantize_fp32,
        non_blocking=non_blocking,
        add_skip_keys=add_skip_keys,
        modules_to_not_convert=modules_to_not_convert,
        modules_dtype_dict=modules_dtype_dict,
        quantization_device=quantization_device,
        return_device=return_device,
    )
    model.eval()
    model, modules_to_not_convert, modules_dtype_dict = apply_sdnq_to_module(
        model,
        weights_dtype=weights_dtype,
        quantized_matmul_dtype=quantized_matmul_dtype,
        torch_dtype=torch_dtype,
        group_size=group_size,
        svd_rank=svd_rank,
        svd_steps=svd_steps,
        dynamic_loss_threshold=dynamic_loss_threshold,
        use_svd=use_svd,
        quant_conv=quant_conv,
        use_quantized_matmul=use_quantized_matmul,
        use_quantized_matmul_conv=use_quantized_matmul_conv,
        use_dynamic_quantization=use_dynamic_quantization,
        use_stochastic_rounding=use_stochastic_rounding,
        dequantize_fp32=dequantize_fp32,
        non_blocking=non_blocking,
        modules_to_not_convert=modules_to_not_convert,
        modules_dtype_dict=modules_dtype_dict,
        quantization_device=quantization_device,
        return_device=return_device,
    )
    quantization_config.modules_to_not_convert = modules_to_not_convert
    quantization_config.modules_dtype_dict = modules_dtype_dict
    model.quantization_config = quantization_config
    if hasattr(model, "config"):
        # config may be an attribute-style object or a dict; try both, best effort
        try:
            model.config.quantization_config = model.quantization_config
        except Exception:
            pass
        try:
            model.config["quantization_config"] = model.quantization_config.to_dict()
        except Exception:
            pass
    model.quantization_method = QuantizationMethod.SDNQ
    return model
class SDNQQuantize():
    """Conversion op returned by ``SDNQQuantizer.get_quantize_ops``.

    ``convert`` quantizes one incoming tensor through the owning quantizer's
    ``create_quantized_param`` and returns the resulting attribute keyed by its full name.
    """
    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer
    def convert(
        self,
        input_dict: dict[str, list[torch.Tensor]],
        model: torch.nn.Module = None,
        full_layer_name: str = None,
        missing_keys: list[str] = None, # pylint: disable=unused-argument
        **kwargs, # pylint: disable=unused-argument
    ) -> dict[str, torch.FloatTensor]:
        # only the first tensor of the first entry is used
        incoming = list(input_dict.values())[0][0]
        self.hf_quantizer.create_quantized_param(model, incoming, full_layer_name, incoming.device)
        owner, attribute_name = get_module_from_name(model, full_layer_name)
        return {full_layer_name: getattr(owner, attribute_name)}
    @property
    def reverse_op(self):
        # there is no inverse conversion for quantized tensors
        raise NotImplementedError
class SDNQQuantizer(DiffusersQuantizer, HfQuantizer):
r"""
Diffusers and Transformers Quantizer for SDNQ
"""
requires_parameters_quantization = True
use_keep_in_fp32_modules = True
requires_calibration = False
required_packages = None
torch_dtype = None
def check_if_quantized_param(
    self,
    model,
    param_value: "torch.Tensor",
    param_name: str,
    *args, **kwargs, # pylint: disable=unused-argument,keyword-arg-before-vararg
):
    """Return True when ``param_name`` must go through ``create_quantized_param``.

    Pre-quantized checkpoints: any tensor of a layer that already has an
    ``sdnq_dequantizer``. Otherwise: a ``.weight`` that is not matched by
    ``modules_to_not_convert`` and belongs to an ``allowed_types`` layer; conv and
    conv-transpose layers additionally require ``quant_conv``.
    """
    if self.pre_quantized:
        owner, _tensor_name = get_module_from_name(model, param_name)
        return hasattr(owner, "sdnq_dequantizer")
    if not param_name.endswith(".weight"):
        return False
    if check_param_name_in(param_name, self.quantization_config.modules_to_not_convert):
        return False
    layer_class_name = get_module_from_name(model, param_name)[0].__class__.__name__
    if layer_class_name not in allowed_types:
        return False
    if layer_class_name in conv_types or layer_class_name in conv_transpose_types:
        return bool(self.quantization_config.quant_conv)
    return True
def check_quantized_param(self, *args, **kwargs) -> bool:
    """Transformers-compatible alias that forwards everything to ``check_if_quantized_param``."""
    needs_quantization = self.check_if_quantized_param(*args, **kwargs)
    return needs_quantization
def param_needs_quantization(self, model, param_name: str, *args, **kwargs) -> bool:
    """Transformers-compatible hook; runs ``check_if_quantized_param`` without a tensor value."""
    no_param_value = None
    return self.check_if_quantized_param(model, no_param_value, param_name, *args, **kwargs)
@devices.inference_context()
def create_quantized_param( # pylint: disable=arguments-differ
    self,
    model,
    param_value: torch.FloatTensor,
    param_name: str,
    target_device: torch.device,
    *args, **kwargs, # pylint: disable=unused-argument
):
    """Load or quantize one parameter into ``model`` during weight loading.

    Pre-quantized checkpoints: the tensor is copied/moved to ``target_device``. ``weight``
    keeps its stored dtype. Any other tensor is cast to float32 when ``dequantize_fp32`` is
    set, otherwise to ``kwargs["dtype"]``, falling back to ``self.torch_dtype`` or the
    tensor's own dtype. The tensor is then laid out for matmul according to the layer's
    ``sdnq_dequantizer`` and set on the layer as a frozen, ``_is_hf_initialized`` Parameter.

    Otherwise: the float weight is moved to ``quantization_device`` (if configured) as
    float32. ``sdnq_quantize_layer`` quantizes it, using ``return_device`` if configured,
    else ``target_device``. The resulting (wrapped) layer replaces the original on its
    parent module.
    """
    if self.pre_quantized:
        layer, tensor_name = get_module_from_name(model, param_name)
        if param_value is not None:
            # NOTE(review): non-weight tensors (e.g. scale) are cast to the model dtype here unless
            # dequantize_fp32 is set, even if they were saved in float32 — confirm this is intended
            return_dtype = param_value.dtype if tensor_name == "weight" else torch.float32 if self.quantization_config.dequantize_fp32 else kwargs.get("dtype", param_value.dtype if self.torch_dtype is None else self.torch_dtype)
            if param_value.dtype == return_dtype and devices.same_device(param_value.device, target_device):
                # already in place: clone so the checkpoint tensor is not shared
                param_value = param_value.clone()
            else:
                param_value = param_value.to(target_device, dtype=return_dtype)
            if tensor_name == "weight" and layer.sdnq_dequantizer.use_quantized_matmul and not layer.sdnq_dequantizer.re_quantize_for_matmul:
                param_value = prepare_weight_for_matmul(param_value)
            elif tensor_name == "svd_up":
                param_value, _ = prepare_svd_for_matmul(param_value, None, layer.sdnq_dequantizer.use_quantized_matmul)
            elif tensor_name == "svd_down":
                _, param_value = prepare_svd_for_matmul(None, param_value, layer.sdnq_dequantizer.use_quantized_matmul)
            param_value = torch.nn.Parameter(param_value, requires_grad=False)
            param_value._is_hf_initialized = True # pylint: disable=protected-access
            setattr(layer, tensor_name, param_value)
        return
    torch_dtype = kwargs.get("dtype", param_value.dtype if self.torch_dtype is None else self.torch_dtype)
    weights_dtype = get_minimum_dtype(self.quantization_config.weights_dtype, param_name, self.quantization_config.modules_dtype_dict)
    if self.quantization_config.return_device is not None:
        return_device = self.quantization_config.return_device
    else:
        return_device = target_device
    if self.quantization_config.quantization_device is not None:
        target_device = self.quantization_config.quantization_device
    # quantization works on a float32 copy on the quantization device
    if param_value.dtype == torch.float32 and devices.same_device(param_value.device, target_device):
        param_value = param_value.clone()
    else:
        param_value = param_value.to(target_device, non_blocking=self.quantization_config.non_blocking).to(dtype=torch.float32)
    layer, tensor_name = get_module_from_name(model, param_name)
    layer.weight = torch.nn.Parameter(param_value, requires_grad=False)
    # the weight is already on the quantization device, hence quantization_device=None
    layer, self.quantization_config.modules_to_not_convert, self.quantization_config.modules_dtype_dict = sdnq_quantize_layer(
        layer,
        weights_dtype=weights_dtype,
        quantized_matmul_dtype=self.quantization_config.quantized_matmul_dtype,
        torch_dtype=torch_dtype,
        group_size=self.quantization_config.group_size,
        svd_rank=self.quantization_config.svd_rank,
        svd_steps=self.quantization_config.svd_steps,
        dynamic_loss_threshold=self.quantization_config.dynamic_loss_threshold,
        use_svd=self.quantization_config.use_svd,
        quant_conv=self.quantization_config.quant_conv,
        use_quantized_matmul=self.quantization_config.use_quantized_matmul,
        use_quantized_matmul_conv=self.quantization_config.use_quantized_matmul_conv,
        use_dynamic_quantization=self.quantization_config.use_dynamic_quantization,
        use_stochastic_rounding=self.quantization_config.use_stochastic_rounding,
        dequantize_fp32=self.quantization_config.dequantize_fp32,
        non_blocking=self.quantization_config.non_blocking,
        modules_to_not_convert=self.quantization_config.modules_to_not_convert,
        modules_dtype_dict=self.quantization_config.modules_dtype_dict,
        quantization_device=None,
        return_device=return_device,
        param_name=param_name,
    )
    layer.weight._is_hf_initialized = True # pylint: disable=protected-access
    # `scale` only exists when quantization succeeded (dynamic quantization may leave the layer unquantized)
    if hasattr(layer, "scale"):
        layer.scale._is_hf_initialized = True # pylint: disable=protected-access
        if layer.zero_point is not None:
            layer.zero_point._is_hf_initialized = True # pylint: disable=protected-access
        if layer.svd_up is not None:
            layer.svd_up._is_hf_initialized = True # pylint: disable=protected-access
            layer.svd_down._is_hf_initialized = True # pylint: disable=protected-access
    # replace the original module on its parent with the (possibly wrapped) layer
    parent_module, tensor_name = get_module_from_name(model, param_name.removesuffix(tensor_name).removesuffix("."))
    setattr(parent_module, tensor_name, layer)
def get_quantize_ops(self):
    """
    Build the quantize-ops helper for this quantizer.

    Returns:
        An `SDNQQuantize` instance constructed around this quantizer.
    """
    quantize_ops = SDNQQuantize(self)
    return quantize_ops
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
    """
    Scale every per-device memory budget down to 80% of its value.

    Args:
        max_memory: mapping of device key -> memory budget. Values are multiplied
            directly, so they must already be numeric here (presumably the caller
            has converted strings such as "10GiB" beforehand — confirm).

    Returns:
        A new dict with the same keys and each budget multiplied by 0.80.
    """
    scaled_memory = {}
    for device_key, budget in max_memory.items():
        scaled_memory[device_key] = budget * 0.80
    return scaled_memory
def adjust_target_dtype(self, target_dtype: torch.dtype) -> torch.dtype:  # pylint: disable=unused-argument,arguments-renamed
    """
    Return the storage dtype for the configured weights dtype.

    The requested `target_dtype` is ignored. The result is the `"target_dtype"` entry
    registered in `dtype_dict` for `quantization_config.weights_dtype`.
    """
    weights_dtype_info = dtype_dict[self.quantization_config.weights_dtype]
    return weights_dtype_info["target_dtype"]
def update_torch_dtype(self, torch_dtype: torch.dtype = None) -> torch.dtype:
    """
    Store `torch_dtype` on the quantizer as `self.torch_dtype` and return it unchanged.
    """
    chosen_dtype = torch_dtype
    self.torch_dtype = chosen_dtype
    return chosen_dtype
def update_dtype(self, dtype: torch.dtype = None) -> torch.dtype:
    """
    Alias kept for transformers compatibility.

    Forwards `dtype` to `self.update_torch_dtype` and returns whatever it returns.
    """
    result = self.update_torch_dtype(dtype)
    return result
def _process_model_before_weight_loading(  # pylint: disable=arguments-differ
    self,
    model,
    device_map,  # pylint: disable=unused-argument
    keep_in_fp32_modules: List[str] = None,
    **kwargs,  # pylint: disable=unused-argument
):
    """
    Prepare `model` for SDNQ before its weights are loaded.

    Pre-quantized checkpoints:
        The config's device, non-blocking and skip-key options are reset. The model is
        then converted with `sdnq_post_load_quant` inside `init_empty_weights()`, using
        `add_skip_keys=False` and `use_dynamic_quantization=False`.

    Other checkpoints, when `add_skip_keys` is enabled:
        `modules_to_not_convert` is extended with `keep_in_fp32_modules`. It is also
        extended with `self.get_modules_to_not_convert(...)` when that method exists and
        the model has `tie_weights`. Finally it passes through `add_module_skip_keys`.

    In both cases the config is attached to `model.config` (best effort, both attribute
    and item assignment are attempted). Then `model.quantization_config` and
    `model.quantization_method` are set.

    Returns None. NOTE(review): the rebound local `model` is not returned, so this relies
    on `sdnq_post_load_quant` / `add_module_skip_keys` updating the model in place.
    """
    config = self.quantization_config

    if self.pre_quantized:
        # Weights arrive already quantized: no device shuffling and no extra skip keys.
        config.quantization_device = None
        config.return_device = None
        config.non_blocking = False
        config.add_skip_keys = False
        with init_empty_weights():
            model = sdnq_post_load_quant(
                model,
                torch_dtype=self.torch_dtype,
                add_skip_keys=False,
                use_dynamic_quantization=False,
                **get_quant_args_from_config(config),
            )

    if config.add_skip_keys:
        if keep_in_fp32_modules is not None:
            config.modules_to_not_convert.extend(keep_in_fp32_modules)
        if hasattr(self, "get_modules_to_not_convert") and hasattr(model, "tie_weights"):
            config.modules_to_not_convert.extend(self.get_modules_to_not_convert(model, add_default_skips=True))
        model, config.modules_to_not_convert, config.modules_dtype_dict = add_module_skip_keys(
            model, config.modules_to_not_convert, config.modules_dtype_dict
        )

    if hasattr(model, "config"):
        # The config object may support attribute assignment, item assignment, or neither;
        # each attempt is best effort and failures are ignored.
        try:
            model.config.quantization_config = config
        except Exception:
            pass
        try:
            model.config["quantization_config"] = config.to_dict()
        except Exception:
            pass

    model.quantization_config = config
    model.quantization_method = QuantizationMethod.SDNQ
def _process_model_after_weight_loading(self, model, **kwargs):  # pylint: disable=unused-argument
    """
    Finalize `model` once its weights are loaded and return it.

    Steps, in order:
      1. Pre-quantized checkpoints go through `loader.post_process_model`.
      2. When `quantization_config.is_training` is set, the model is converted with
         `training.convert_sdnq_model_to_training`.
      3. Unless `shared.opts.diffusers_offload_mode` is "none", the model is moved to
         `devices.cpu`. If `.to()` raises, it falls back to `to_empty()`.
      4. `devices.torch_gc(force=True, reason="sdnq")` is requested.
    """
    if self.pre_quantized:
        from .loader import post_process_model
        model = post_process_model(model)

    config = self.quantization_config
    if config.is_training:
        from .training import convert_sdnq_model_to_training
        model = convert_sdnq_model_to_training(
            model,
            dtype=self.torch_dtype,
            quantized_matmul_dtype=config.quantized_matmul_dtype,
            use_grad_ckpt=config.use_grad_ckpt,
            use_quantized_matmul=config.use_quantized_matmul,
            use_stochastic_rounding=config.use_stochastic_rounding,
            dequantize_fp32=config.dequantize_fp32,
        )

    offload_enabled = shared.opts.diffusers_offload_mode != "none"
    if offload_enabled:
        try:
            model = model.to(device=devices.cpu)
        except Exception:
            # NOTE(review): `to_empty` allocates storage on the device without copying values.
            # Presumably this is only reached for meta-device tensors; confirm it cannot
            # discard already-loaded weights.
            model = model.to_empty(device=devices.cpu)

    devices.torch_gc(force=True, reason="sdnq")
    return model
def get_accelerator_warm_up_factor(self):
    """
    Return `32 // num_bits` for the configured `weights_dtype`.

    `num_bits` is read from `dtype_dict`.
    """
    num_bits = dtype_dict[self.quantization_config.weights_dtype]["num_bits"]
    return 32 // num_bits
def get_cuda_warm_up_factor(self):
    """
    Compatibility alias kept for transformers.

    Returns the result of `self.get_accelerator_warm_up_factor()`.
    """
    warm_up_factor = self.get_accelerator_warm_up_factor()
    return warm_up_factor
def _dequantize(self, model):
    """Undo SDNQ quantization on `model` via `dequantize_sdnq_model` and return its result."""
    dequantized_model = dequantize_sdnq_model(model)
    return dequantized_model
def is_serializable(self, *args, **kwargs) -> bool:  # pylint: disable=unused-argument, invalid-overridden-method
    """
    Report whether the quantized model can be saved.

    Only non-training configs are serializable. All positional and keyword
    arguments are accepted and ignored.
    """
    training_mode = self.quantization_config.is_training
    return not training_mode
@property
def is_trainable(self):
    """Whether the quantized model may be trained. Mirrors `quantization_config.is_training`."""
    config = self.quantization_config
    return config.is_training
@property
def is_qat_trainable(self) -> bool:
    """
    Whether quantization-aware training is supported. Same value as `is_trainable`.

    `is_trainable` is a property, so it must be read as an attribute. Calling it as
    `self.is_trainable()` would try to call the bool it returns and raise TypeError.
    """
    return self.is_trainable
@property
def is_compileable(self):
    """Always True: this quantizer unconditionally reports its models as compileable."""
    compileable = True
    return compileable
@dataclass
class SDNQConfig(QuantizationConfigMixin):
    """
    Configuration holding every option that controls how a model is loaded and quantized with `sdnq`.

    Args:
        weights_dtype (`str`, *optional*, defaults to `"int8"`):
            The target dtype for the weights after quantization.
            Check out `sdnq.common.accepted_weight_dtypes` for all the supported values.
            These are some of the recommended values to use: ("int8", "int7", "int6", "uint5", "uint4", "uint3", "uint2", "float8_e4m3fn", "float7_e3m3fn", "float6_e3m2fn", "float5_e2m2fn", "float4_e2m1fn", "float3_e1m1fn", "float2_e1m0fn")
        quantized_matmul_dtype (`str`, *optional*, defaults to `None`):
            The target dtype for quantized matmul.
            `None` will use "int8" with integer weight dtypes and "float8_e4m3fn" or "float16" with float weight dtypes.
            Supported values are: ("int8", "float8_e4m3fn", "float16")
        group_size (`int`, *optional*, defaults to `0`):
            Used to decide how many elements of a tensor will share the same quantization group.
            group_size = 0 will automatically select a group size based on weights_dtype.
        svd_rank (`int`, *optional*, defaults to `32`):
            The rank size used for the SVDQuant algorithm.
        svd_steps (`int`, *optional*, defaults to `8`):
            The number of iterations to use in svd lowrank estimation.
        dynamic_loss_threshold (`float`, *optional*, defaults to `1e-2`):
            The target quantization mse loss threshold to use for dynamic quantization.
        use_svd (`bool`, *optional*, defaults to `False`):
            Enabling this option will use SVDQuant algorithm on top of SDNQ quantization.
        use_grad_ckpt (`bool`, *optional*, defaults to `True`):
            Forwarded as `use_grad_ckpt` to the training conversion when `is_training` is enabled
            (presumably toggles gradient checkpointing).
        quant_conv (`bool`, *optional*, defaults to `False`):
            Enabling this option will quantize the convolutional layers in UNet models too.
        use_quantized_matmul (`bool`, *optional*, defaults to `False`):
            Enabling this option will use quantized INT8 or FP8 MatMul instead of BF16 / FP16.
            Requires a working Triton install (checked in `post_init`).
        use_quantized_matmul_conv (`bool`, *optional*, defaults to `False`):
            Same as use_quantized_matmul but for the convolutional layers with UNets like SDXL.
        use_static_quantization (`bool`, *optional*, defaults to `True`):
            Stored on the config and serialized with it.
        use_dynamic_quantization (`bool`, *optional*, defaults to `False`):
            Enabling this option will dynamically select a per layer quantization type based on the dynamic_loss_threshold.
            weights_dtype will be used as the minimum allowed quantization type when this option is enabled.
        use_stochastic_rounding (`bool`, *optional*, defaults to `False`):
            Enabling this option will use stochastic rounding on the quantization step.
        dequantize_fp32 (`bool`, *optional*, defaults to `False`):
            Enabling this option will use FP32 on the dequantization step.
        non_blocking (`bool`, *optional*, defaults to `False`):
            Enabling this option will use non blocking ops when moving layers between the quantization device and the return device.
        add_skip_keys (`bool`, *optional*, defaults to `True`):
            Disabling this option won't add model specific modules_to_not_convert and modules_dtype_dict keys.
        quantization_device (`torch.device`, *optional*, defaults to `None`):
            Used to set which device will be used for the quantization calculation on model load.
        return_device (`torch.device`, *optional*, defaults to `None`):
            Used to set which device the quantized weights will be sent back to.
        modules_to_not_convert (`list`, *optional*, default to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have some
            modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
            A single string or a tuple is also accepted and converted to a list.
        modules_dtype_dict (`dict`, *optional*, default to `None`):
            The dict of dtypes and list of modules, useful for quantizing some modules with a different dtype.
            Values may also be a single string or a tuple and are converted to lists.
        is_training (`bool`, *optional*, defaults to `False`):
            Marks the config for training. `quant_method` becomes `"sdnq_training"`, and the
            SDNQ quantizer converts the model for training after loading and reports it as
            not serializable.

    Raises:
        RuntimeError: `use_quantized_matmul` is set but `check_torch_compile()` fails.
        ValueError: unsupported `weights_dtype` / `quantized_matmul_dtype`, or wrongly typed
            `modules_to_not_convert` / `modules_dtype_dict`.
    """
    def __init__(  # pylint: disable=super-init-not-called
        self,
        weights_dtype: str = "int8",
        quantized_matmul_dtype: str = None,
        group_size: int = 0,
        svd_rank: int = 32,
        svd_steps: int = 8,
        dynamic_loss_threshold: float = 1e-2,
        use_svd: bool = False,
        use_grad_ckpt: bool = True,
        quant_conv: bool = False,
        use_quantized_matmul: bool = False,
        use_quantized_matmul_conv: bool = False,
        use_static_quantization: bool = True,
        use_dynamic_quantization: bool = False,
        use_stochastic_rounding: bool = False,
        dequantize_fp32: bool = False,
        non_blocking: bool = False,
        add_skip_keys: bool = True,
        quantization_device: Optional[torch.device] = None,
        return_device: Optional[torch.device] = None,
        modules_to_not_convert: Optional[List[str]] = None,
        modules_dtype_dict: Optional[Dict[str, List[str]]] = None,
        is_training: bool = False,
        **kwargs,  # pylint: disable=unused-argument
    ):
        # Extra kwargs (e.g. serialized `quant_method`, `is_integer`, `sdnq_version` coming
        # back through from_dict) are accepted and ignored; they are recomputed here.
        self.weights_dtype = weights_dtype
        self.quantized_matmul_dtype = quantized_matmul_dtype
        self.is_training = is_training
        if self.is_training:
            self.quant_method = QuantizationMethod.SDNQ_TRAINING
        else:
            self.quant_method = QuantizationMethod.SDNQ
        self.group_size = group_size
        self.svd_rank = svd_rank
        self.dynamic_loss_threshold = dynamic_loss_threshold
        self.svd_steps = svd_steps
        self.use_svd = use_svd
        self.use_grad_ckpt = use_grad_ckpt
        self.quant_conv = quant_conv
        self.use_quantized_matmul = use_quantized_matmul
        self.use_quantized_matmul_conv = use_quantized_matmul_conv
        self.use_static_quantization = use_static_quantization
        self.use_dynamic_quantization = use_dynamic_quantization
        self.use_stochastic_rounding = use_stochastic_rounding
        self.dequantize_fp32 = dequantize_fp32
        self.non_blocking = non_blocking
        self.add_skip_keys = add_skip_keys
        self.quantization_device = quantization_device
        self.return_device = return_device
        self.modules_to_not_convert = modules_to_not_convert
        self.modules_dtype_dict = modules_dtype_dict
        self.sdnq_version = sdnq_version
        # `is_integer` is derived inside post_init, after weights_dtype has been validated, so an
        # unknown dtype raises the descriptive ValueError instead of a bare KeyError from dtype_dict.
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct.

        Also normalizes the module options: `modules_to_not_convert` becomes a copied list,
        and `modules_dtype_dict` becomes a copied dict whose values are lists. Finally it
        derives `is_integer` from the validated `weights_dtype`.

        Raises:
            RuntimeError: `use_quantized_matmul` is set but `check_torch_compile()` fails.
            ValueError: unsupported dtypes or wrongly typed module options.
        """
        if self.use_quantized_matmul and not check_torch_compile():
            raise RuntimeError("SDNQ Quantized MatMul requires a working Triton install.")
        if self.weights_dtype not in accepted_weight_dtypes:
            raise ValueError(f"SDNQ only support weight dtypes in {accepted_weight_dtypes} but found {self.weights_dtype}")
        if self.quantized_matmul_dtype is not None and self.quantized_matmul_dtype not in accepted_matmul_dtypes:
            raise ValueError(f"SDNQ only support quantized matmul dtypes in {accepted_matmul_dtypes} but found {self.quantized_matmul_dtype}")

        if self.modules_to_not_convert is None:
            self.modules_to_not_convert = []
        elif isinstance(self.modules_to_not_convert, str):
            self.modules_to_not_convert = [self.modules_to_not_convert]
        elif isinstance(self.modules_to_not_convert, tuple):
            self.modules_to_not_convert = list(self.modules_to_not_convert)
        elif not isinstance(self.modules_to_not_convert, list):
            raise ValueError(f"modules_to_not_convert must be a list but got {type(self.modules_to_not_convert)}")

        if self.modules_dtype_dict is None:
            self.modules_dtype_dict = {}
        elif not isinstance(self.modules_dtype_dict, dict):
            raise ValueError(f"modules_dtype_dict must be a dict but got {type(self.modules_dtype_dict)}")
        elif len(self.modules_dtype_dict.keys()) > 0:
            # copy first so the caller's dict is not modified; re-assigning existing keys while
            # iterating does not change the dict size, so it is safe here
            self.modules_dtype_dict = self.modules_dtype_dict.copy()
            for key, value in self.modules_dtype_dict.items():
                if isinstance(value, str):
                    value = [value]
                    self.modules_dtype_dict[key] = value
                elif isinstance(value, tuple):
                    value = list(value)
                    self.modules_dtype_dict[key] = value
                if not isinstance(key, str) or not isinstance(value, list):
                    raise ValueError(f"modules_dtype_dict must be a dictionary of strings and lists but got {type(key)} and {type(value)}")

        # shallow copies so later in-place extension does not touch the caller's containers
        self.modules_to_not_convert = self.modules_to_not_convert.copy()
        self.modules_dtype_dict = self.modules_dtype_dict.copy()
        self.is_integer = dtype_dict[self.weights_dtype]["is_integer"]

    def to_dict(self):
        """
        Return a shallow copy of the instance attributes.

        `quantization_device` and `return_device` are converted to strings (or kept as None)
        so the result can be serialized.
        """
        quantization_config_dict = self.__dict__.copy()  # make serializable
        quantization_config_dict["quantization_device"] = str(quantization_config_dict["quantization_device"]) if quantization_config_dict["quantization_device"] is not None else None
        quantization_config_dict["return_device"] = str(quantization_config_dict["return_device"]) if quantization_config_dict["return_device"] is not None else None
        return quantization_config_dict
# Register SDNQ under both method names ("sdnq" and "sdnq_training") in the diffusers and
# transformers auto-quantizer mapping tables. Entries for either name map to
# SDNQQuantizer / SDNQConfig.
import diffusers.quantizers.auto  # noqa: E402,RUF100 # pylint: disable=wrong-import-order
for _sdnq_method in ("sdnq", "sdnq_training"):
    diffusers.quantizers.auto.AUTO_QUANTIZER_MAPPING[_sdnq_method] = SDNQQuantizer
    diffusers.quantizers.auto.AUTO_QUANTIZATION_CONFIG_MAPPING[_sdnq_method] = SDNQConfig
import transformers.quantizers.auto  # noqa: E402,RUF100 # pylint: disable=wrong-import-order
for _sdnq_method in ("sdnq", "sdnq_training"):
    transformers.quantizers.auto.AUTO_QUANTIZER_MAPPING[_sdnq_method] = SDNQQuantizer
    transformers.quantizers.auto.AUTO_QUANTIZATION_CONFIG_MAPPING[_sdnq_method] = SDNQConfig
del _sdnq_method  # loop helper only; keep the module namespace unchanged

# Compiled variant of sdnq_quantize_layer_weight, built once at import time via compile_func.
sdnq_quantize_layer_weight_compiled = compile_func(sdnq_quantize_layer_weight)