
[Quantization] Add TRT-ModelOpt as a Backend (#11173)

* initial commit

* update

* updates

* update

* update

* update

* update

* update

* update

* addressed PR comments

* update

* addressed PR comments

* update

* update

* update

* update

* update

* update

* updates

* update

* update

* addressed PR comments

* updates

* code formatting

* update

* addressed PR comments

* addressed PR comments

* addressed PR comments

* addressed PR comments

* fix docs and dependencies

* fixed dependency test

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Ishan Modi
2025-09-03 10:14:52 +05:30
committed by GitHub
parent 6549b04ec6
commit 4acbfbf13b
17 changed files with 936 additions and 3 deletions
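
For orientation, a minimal usage sketch of the backend this PR adds. The checkpoint id and dtype are illustrative assumptions, and `nvidia-modelopt` (with the `[hf]` extra) is assumed to be installed:

# Illustrative example: FP8 weight-only quantization while loading a transformer.
import torch
from diffusers import FluxTransformer2DModel, NVIDIAModelOptConfig
from modelopt.torch.opt import enable_huggingface_checkpointing

# Recommended by NVIDIAModelOptConfig.check_model_patching before loading/saving weights.
enable_huggingface_checkpointing()

quant_config = NVIDIAModelOptConfig(quant_type="FP8", weight_only=True)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # example checkpoint, not prescribed by this PR
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)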

View File

@@ -13,6 +13,7 @@ from .utils import (
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
is_nvidia_modelopt_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
@@ -111,6 +112,18 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["quantizers.quantization_config"].append("QuantoConfig")
try:
if not is_torch_available() and not is_accelerate_available() and not is_nvidia_modelopt_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_nvidia_modelopt_objects
_import_structure["utils.dummy_nvidia_modelopt_objects"] = [
name for name in dir(dummy_nvidia_modelopt_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig")
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -795,6 +808,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .quantizers.quantization_config import QuantoConfig
try:
if not is_nvidia_modelopt_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_nvidia_modelopt_objects import *
else:
from .quantizers.quantization_config import NVIDIAModelOptConfig
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()

View File

@@ -39,6 +39,7 @@ deps = {
"gguf": "gguf>=0.10.0",
"torchao": "torchao>=0.7.0",
"bitsandbytes": "bitsandbytes>=0.43.3",
"nvidia_modelopt[hf]": "nvidia_modelopt[hf]>=0.33.1",
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",

View File

@@ -21,9 +21,11 @@ from typing import Dict, Optional, Union
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .modelopt import NVIDIAModelOptQuantizer
from .quantization_config import (
BitsAndBytesConfig,
GGUFQuantizationConfig,
NVIDIAModelOptConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
@@ -39,6 +41,7 @@ AUTO_QUANTIZER_MAPPING = {
"gguf": GGUFQuantizer,
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
"modelopt": NVIDIAModelOptQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,6 +50,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"gguf": GGUFQuantizationConfig,
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
"modelopt": NVIDIAModelOptConfig,
}
@@ -137,6 +141,9 @@ class DiffusersAutoQuantizer:
if isinstance(quantization_config, dict):
quantization_config = cls.from_dict(quantization_config)
if isinstance(quantization_config, NVIDIAModelOptConfig):
quantization_config.check_model_patching()
if warning_msg != "":
warnings.warn(warning_msg)
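
The two mappings above are what `DiffusersAutoQuantizer` uses to resolve a serialized config. A small sketch (the dict is illustrative; nvidia-modelopt must be installed since building the config imports modelopt):

from diffusers.quantizers import DiffusersAutoQuantizer

serialized = {"quant_method": "modelopt", "quant_type": "FP8"}  # e.g. read back from a saved model config
quantizer = DiffusersAutoQuantizer.from_config(serialized)
print(type(quantizer).__name__)  # expected: NVIDIAModelOptQuantizer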

View File

@@ -0,0 +1 @@
from .modelopt_quantizer import NVIDIAModelOptQuantizer

View File

@@ -0,0 +1,190 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union
from ...utils import (
get_module_from_name,
is_accelerate_available,
is_nvidia_modelopt_available,
is_torch_available,
logging,
)
from ..base import DiffusersQuantizer
if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin
if is_torch_available():
import torch
import torch.nn as nn
if is_accelerate_available():
from accelerate.utils import set_module_tensor_to_device
logger = logging.get_logger(__name__)
class NVIDIAModelOptQuantizer(DiffusersQuantizer):
r"""
Diffusers Quantizer for TensorRT Model Optimizer
"""
use_keep_in_fp32_modules = True
requires_calibration = False
required_packages = ["nvidia_modelopt"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
def validate_environment(self, *args, **kwargs):
if not is_nvidia_modelopt_available():
raise ImportError(
"Loading an nvidia-modelopt quantized model requires nvidia-modelopt library (`pip install nvidia-modelopt`)"
)
self.offload = False
device_map = kwargs.get("device_map", None)
if isinstance(device_map, dict):
if "cpu" in device_map.values() or "disk" in device_map.values():
if self.pre_quantized:
raise ValueError(
"You are attempting to perform cpu/disk offload with a pre-quantized modelopt model "
"This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
)
else:
self.offload = True
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
):
# ModelOpt imports diffusers internally. This is here to prevent circular imports
from modelopt.torch.quantization.utils import is_quantized
module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized:
return True
elif is_quantized(module) and "weight" in tensor_name:
return True
return False
def create_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
*args,
**kwargs,
):
"""
Create the quantized parameter by calling .calibrate() after setting it to the module.
"""
# ModelOpt imports diffusers internally. This is here to prevent circular imports
import modelopt.torch.quantization as mtq
dtype = kwargs.get("dtype", torch.float32)
module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized:
module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
else:
set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
mtq.calibrate(
module, self.quantization_config.modelopt_config["algorithm"], self.quantization_config.forward_loop
)
mtq.compress(module)
module.weight.requires_grad = False
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
return max_memory
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
if self.quantization_config.quant_type == "FP8":
target_dtype = torch.float8_e4m3fn
return target_dtype
def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
if torch_dtype is None:
logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
torch_dtype = torch.float32
return torch_dtype
def get_conv_param_names(self, model: "ModelMixin") -> List[str]:
"""
Get parameter names for all convolutional layers in a HuggingFace ModelMixin. Includes Conv1d/2d/3d and
ConvTranspose1d/2d/3d.
"""
conv_types = (
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
)
conv_param_names = []
for name, module in model.named_modules():
if isinstance(module, conv_types):
for param_name, _ in module.named_parameters(recurse=False):
conv_param_names.append(f"{name}.{param_name}")
return conv_param_names
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
# ModelOpt imports diffusers internally. This is here to prevent circular imports
import modelopt.torch.opt as mto
if self.pre_quantized:
return
modules_to_not_convert = self.quantization_config.modules_to_not_convert
if modules_to_not_convert is None:
modules_to_not_convert = []
if isinstance(modules_to_not_convert, str):
modules_to_not_convert = [modules_to_not_convert]
modules_to_not_convert.extend(keep_in_fp32_modules)
if self.quantization_config.disable_conv_quantization:
modules_to_not_convert.extend(self.get_conv_param_names(model))
for module in modules_to_not_convert:
self.quantization_config.modelopt_config["quant_cfg"]["*" + module + "*"] = {"enable": False}
self.quantization_config.modules_to_not_convert = modules_to_not_convert
mto.apply_mode(model, mode=[("quantize", self.quantization_config.modelopt_config)])
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model, **kwargs):
# ModelOpt imports diffusers internally. This is here to prevent circular imports
from modelopt.torch.opt import ModeloptStateManager
if self.pre_quantized:
return model
for _, m in model.named_modules():
if hasattr(m, ModeloptStateManager._state_key) and m is not model:
ModeloptStateManager.remove_state(m)
return model
@property
def is_trainable(self):
return True
@property
def is_serializable(self):
self.quantization_config.check_model_patching(operation="saving")
return True
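
For context, the calibrate-then-compress flow that `create_quantized_param` drives per module corresponds roughly to the following standalone modelopt sketch (toy model and built-in FP8 config; nvidia-modelopt assumed installed):

import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

def forward_loop(m):
    # Tiny calibration loop; real calibration would iterate over representative inputs.
    for _ in range(4):
        m(torch.randn(2, 64))

# Insert quantizers and calibrate them (the quantizer class above splits this into
# mto.apply_mode + mtq.calibrate), then fold weights into their compressed form.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
mtq.compress(model)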

View File

@@ -25,10 +25,11 @@ import importlib.metadata
import inspect
import json
import os
import warnings
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from packaging import version
@@ -46,6 +47,7 @@ class QuantizationMethod(str, Enum):
GGUF = "gguf"
TORCHAO = "torchao"
QUANTO = "quanto"
MODELOPT = "modelopt"
if is_torchao_available():
@@ -268,7 +270,14 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
if bnb_4bit_quant_storage is None:
self.bnb_4bit_quant_storage = torch.uint8
elif isinstance(bnb_4bit_quant_storage, str):
if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]:
if bnb_4bit_quant_storage not in [
"float16",
"float32",
"int8",
"uint8",
"float64",
"bfloat16",
]:
raise ValueError(
"`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') "
)
@@ -479,7 +488,12 @@ class TorchAoConfig(QuantizationConfigMixin):
```
"""
def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]] = None, **kwargs) -> None:
def __init__(
self,
quant_type: str,
modules_to_not_convert: Optional[List[str]] = None,
**kwargs,
) -> None:
self.quant_method = QuantizationMethod.TORCHAO
self.quant_type = quant_type
self.modules_to_not_convert = modules_to_not_convert
@@ -724,3 +738,194 @@ class QuantoConfig(QuantizationConfigMixin):
accepted_weights = ["float8", "int8", "int4", "int2"]
if self.weights_dtype not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
@dataclass
class NVIDIAModelOptConfig(QuantizationConfigMixin):
"""This is a config class to use nvidia modelopt for quantization.
Args:
quant_type (`str`):
The type of quantization to use, given as `weightquant_activationquant` (for example, `FP8_FP8` uses FP8 for
both weight and activation quantization). The following are all the options:
- FP8
- INT8
- INT4
- NF4
- NVFP4
modules_to_not_convert (`List[str]`, *optional*, defaults to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision.
weight_only (`bool`, *optional*, defaults to `True`):
If set to `True`, the quantization will be applied only to the weights of the model.
channel_quantize (`int`, *optional*, defaults to `None`):
The channel quantization axis, useful for quantizing models across different axes.
block_quantize (`int`, *optional*, defaults to `None`):
The block size, useful to further quantize each channel/axis into blocks.
scale_channel_quantize (`int`, *optional*, defaults to `None`):
The scale channel quantization axis, useful for quantizing the calculated scales across different axes.
scale_block_quantize (`int`, *optional*, defaults to `None`):
The scale block size, useful for quantizing each scale channel/axis into blocks.
algorithm (`str`, *optional*, defaults to `"max"`):
The calibration algorithm to use for quantization; currently only `"max"` is supported.
forward_loop (`Callable`, *optional*, defaults to `None`):
The forward loop function to use for calibration during quantization.
modelopt_config (`dict`, *optional*, defaults to `None`):
A full modelopt config; when provided, it is used directly instead of the config derived from the options
above.
disable_conv_quantization (`bool`, *optional*, defaults to `False`):
If set to `True`, quantization will be disabled for all convolutional layers.
kwargs (`Dict[str, Any]`, *optional*):
Additional parameters which are to be used for calibration.
"""
quanttype_to_numbits = {
"FP8": (4, 3),
"INT8": 8,
"INT4": 4,
"NF4": 4,
"NVFP4": (2, 1),
}
quanttype_to_scalingbits = {
"NF4": 8,
"NVFP4": (4, 3),
}
def __init__(
self,
quant_type: str,
modules_to_not_convert: Optional[List[str]] = None,
weight_only: bool = True,
channel_quantize: Optional[int] = None,
block_quantize: Optional[int] = None,
scale_channel_quantize: Optional[int] = None,
scale_block_quantize: Optional[int] = None,
algorithm: str = "max",
forward_loop: Optional[Callable] = None,
modelopt_config: Optional[dict] = None,
disable_conv_quantization: bool = False,
**kwargs,
) -> None:
self.quant_method = QuantizationMethod.MODELOPT
self._normalize_quant_type(quant_type)
self.modules_to_not_convert = modules_to_not_convert
self.weight_only = weight_only
self.channel_quantize = channel_quantize
self.block_quantize = block_quantize
self.calib_cfg = {
"method": algorithm,
# add more options here if needed
}
self.forward_loop = forward_loop
self.scale_channel_quantize = scale_channel_quantize
self.scale_block_quantize = scale_block_quantize
self.modelopt_config = self.get_config_from_quant_type() if not modelopt_config else modelopt_config
self.disable_conv_quantization = disable_conv_quantization
def check_model_patching(self, operation: str = "loading"):
# ModelOpt imports diffusers internally. This is here to prevent circular imports
from modelopt.torch.opt.plugins.huggingface import _PATCHED_CLASSES
if len(_PATCHED_CLASSES) == 0:
warning_msg = (
f"Not {operation} weights in modelopt format. This might cause unreliable behavior."
"Please make sure to run the following code before loading/saving model weights:\n\n"
" from modelopt.torch.opt import enable_huggingface_checkpointing\n"
" enable_huggingface_checkpointing()\n"
)
warnings.warn(warning_msg)
def _normalize_quant_type(self, quant_type: str) -> str:
"""
Validates and normalizes the quantization type string.
Splits the quant_type into weight and activation components, verifies them against supported types, and
replaces unsupported values with safe defaults.
Args:
quant_type (str): The input quantization type string (e.g., 'FP8_INT8').
Returns:
str: A valid quantization type string (e.g., 'FP8_INT8' or 'FP8').
"""
parts = quant_type.split("_")
w_type = parts[0]
act_type = parts[1] if len(parts) > 1 else None
if len(parts) > 2:
logger.warning(f"Quantization type {quant_type} is not supported. Picking FP8_INT8 as default")
w_type = "FP8"
act_type = None
else:
if w_type not in NVIDIAModelOptConfig.quanttype_to_numbits:
logger.warning(f"Weight Quantization type {w_type} is not supported. Picking FP8 as default")
w_type = "FP8"
if act_type is not None and act_type not in NVIDIAModelOptConfig.quanttype_to_numbits:
logger.warning(f"Activation Quantization type {act_type} is not supported. Picking INT8 as default")
act_type = None
self.quant_type = w_type + ("_" + act_type if act_type is not None else "")
def get_config_from_quant_type(self) -> Dict[str, Any]:
"""
Get the config from the quantization type.
"""
import modelopt.torch.quantization as mtq
BASE_CONFIG = {
"quant_cfg": {
"*weight_quantizer": {"fake_quant": False},
"*input_quantizer": {},
"*output_quantizer": {"enable": False},
"*q_bmm_quantizer": {},
"*k_bmm_quantizer": {},
"*v_bmm_quantizer": {},
"*softmax_quantizer": {},
**mtq.config._default_disabled_quantizer_cfg,
},
"algorithm": self.calib_cfg,
}
quant_cfg = BASE_CONFIG["quant_cfg"]
if self.weight_only:
for k in quant_cfg:
if "*weight_quantizer" not in k and not quant_cfg[k]:
quant_cfg[k]["enable"] = False
parts = self.quant_type.split("_")
w_type = parts[0]
act_type = parts[1].replace("A", "") if len(parts) > 1 else None
for k in quant_cfg:
if k not in mtq.config._default_disabled_quantizer_cfg and "enable" not in quant_cfg[k]:
if k == "*input_quantizer":
if act_type is not None:
quant_cfg[k]["num_bits"] = NVIDIAModelOptConfig.quanttype_to_numbits[act_type]
continue
quant_cfg[k]["num_bits"] = NVIDIAModelOptConfig.quanttype_to_numbits[w_type]
if self.block_quantize is not None and self.channel_quantize is not None:
quant_cfg["*weight_quantizer"]["block_sizes"] = {self.channel_quantize: self.block_quantize}
quant_cfg["*input_quantizer"]["block_sizes"] = {
self.channel_quantize: self.block_quantize,
"type": "dynamic",
}
elif self.channel_quantize is not None:
quant_cfg["*weight_quantizer"]["axis"] = self.channel_quantize
quant_cfg["*input_quantizer"]["axis"] = self.channel_quantize
quant_cfg["*input_quantizer"]["type"] = "dynamic"
# Only fixed scaling sizes are supported for now in modelopt
if self.scale_channel_quantize is not None and self.scale_block_quantize is not None:
if w_type in NVIDIAModelOptConfig.quanttype_to_scalingbits:
quant_cfg["*weight_quantizer"]["block_sizes"].update(
{
"scale_bits": NVIDIAModelOptConfig.quanttype_to_scalingbits[w_type],
"scale_block_sizes": {self.scale_channel_quantize: self.scale_block_quantize},
}
)
if act_type and act_type in NVIDIAModelOptConfig.quanttype_to_scalingbits:
quant_cfg["*input_quantizer"]["block_sizes"].update(
{
"scale_bits": NVIDIAModelOptConfig.quanttype_to_scalingbits[act_type],
"scale_block_sizes": {self.scale_channel_quantize: self.scale_block_quantize},
}
)
return BASE_CONFIG
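
As a reference for how the options above translate into a modelopt `quant_cfg`, a hedged example (values are illustrative; nvidia-modelopt must be installed because the constructor builds the config eagerly):

from diffusers import NVIDIAModelOptConfig

# INT4 weight-only quantization, per-channel on axis 0 with block size 128.
cfg = NVIDIAModelOptConfig(
    quant_type="INT4",
    weight_only=True,
    channel_quantize=0,
    block_quantize=128,
)
print(cfg.modelopt_config["quant_cfg"]["*weight_quantizer"])
# Per get_config_from_quant_type above, roughly:
# {'fake_quant': False, 'num_bits': 4, 'block_sizes': {0: 128}}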

View File

@@ -89,6 +89,8 @@ from .import_utils import (
is_matplotlib_available,
is_nltk_available,
is_note_seq_available,
is_nvidia_modelopt_available,
is_nvidia_modelopt_version,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,

View File

@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class NVIDIAModelOptConfig(metaclass=DummyObject):
_backends = ["nvidia_modelopt"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["nvidia_modelopt"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["nvidia_modelopt"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["nvidia_modelopt"])

View File

@@ -226,6 +226,7 @@ _sageattention_available, _sageattention_version = _is_package_available("sageat
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
def is_torch_available():
@@ -364,6 +365,10 @@ def is_optimum_quanto_available():
return _optimum_quanto_available
def is_nvidia_modelopt_available():
return _nvidia_modelopt_available
def is_timm_available():
return _timm_available
@@ -830,6 +835,21 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version)
def is_nvidia_modelopt_version(operation: str, version: str):
"""
Compares the current Nvidia ModelOpt version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _nvidia_modelopt_available:
return False
return compare_versions(parse(_nvidia_modelopt_version), operation, version)
def is_xformers_version(operation: str, version: str):
"""
Compares the current xformers version to a given reference with an operation.

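A short guard sketch using the two helpers added here (the version string mirrors the pinned requirement in the dependency table above):

from diffusers.utils import is_nvidia_modelopt_available, is_nvidia_modelopt_version

if is_nvidia_modelopt_available() and is_nvidia_modelopt_version(">=", "0.33.1"):
    # Safe to take the modelopt-specific code paths added in this PR.
    ...
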
View File

@@ -38,6 +38,7 @@ from .import_utils import (
is_gguf_available,
is_kernels_available,
is_note_seq_available,
is_nvidia_modelopt_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
@@ -638,6 +639,18 @@ def require_torchao_version_greater_or_equal(torchao_version):
return decorator
def require_modelopt_version_greater_or_equal(modelopt_version):
def decorator(test_case):
correct_nvidia_modelopt_version = is_nvidia_modelopt_available() and version.parse(
version.parse(importlib.metadata.version("modelopt")).base_version
) >= version.parse(modelopt_version)
return unittest.skipUnless(
correct_nvidia_modelopt_version, f"Test requires modelopt with version greater than {modelopt_version}."
)(test_case)
return decorator
def require_kernels_version_greater_or_equal(kernels_version):
def decorator(test_case):
correct_kernels_version = is_kernels_available() and version.parse(