mirror of https://github.com/huggingface/diffusers.git

[quant] QoL improvements for pipeline-level quant config (#11876)

* add repr for pipelinequantconfig.

* update
Sayak Paul
2025-07-10 08:53:01 +05:30
committed by GitHub
parent f33b89bafb
commit b41abb2230
5 changed files with 262 additions and 178 deletions

src/diffusers/__init__.py

@@ -40,6 +40,7 @@ _import_structure = {
"models": [],
"modular_pipelines": [],
"pipelines": [],
"quantizers.pipe_quant_config": ["PipelineQuantizationConfig"],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [

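The new `_import_structure` entry registers the class with the package's lazy loader, making it importable from the top level. A minimal sketch of what this enables:

# Both imports now resolve to the same class object: the top-level name is
# routed through diffusers' lazy-module machinery to the module registered above.
from diffusers import PipelineQuantizationConfig
from diffusers.quantizers.pipe_quant_config import PipelineQuantizationConfig as _Direct

assert PipelineQuantizationConfig is _Direct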
src/diffusers/pipelines/pipeline_utils.py

@@ -1096,6 +1096,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
if device_map is not None:
setattr(model, "hf_device_map", final_device_map)
if quantization_config is not None:
setattr(model, "quantization_config", quantization_config)
return model
@property
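The `setattr` above is the quality-of-life hook: the resolved `PipelineQuantizationConfig` is pinned onto the pipeline so it can be inspected after loading. A minimal sketch, assuming an illustrative checkpoint name and an installed bitsandbytes backend:

import torch

from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# "some/checkpoint" is a placeholder repo id; any pipeline with a quantizable
# transformer component works.
pipe = DiffusionPipeline.from_pretrained(
    "some/checkpoint",
    quantization_config=PipelineQuantizationConfig(
        quant_backend="bitsandbytes_8bit",
        quant_kwargs={"load_in_8bit": True},
        components_to_quantize=["transformer"],
    ),
    torch_dtype=torch.bfloat16,
)

# The attribute attached by the new setattr call above:
print(pipe.quantization_config)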

src/diffusers/quantizers/__init__.py

@@ -12,183 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Dict, List, Optional, Union
from ..utils import is_transformers_available, logging
from .auto import DiffusersAutoQuantizer
from .base import DiffusersQuantizer
from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
try:
from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
except ImportError:
class TransformersQuantConfigMixin:
pass
logger = logging.get_logger(__name__)
class PipelineQuantizationConfig:
"""
Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
Args:
quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
is available to both `diffusers` and `transformers`.
quant_kwargs (`dict`): Params to initialize the quantization backend class.
components_to_quantize (`list`): Components of a pipeline to be quantized.
quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
components. When using this argument, users are not expected to provide `quant_backend`, `quant_kwargs`,
and `components_to_quantize`.
"""
def __init__(
self,
quant_backend: str = None,
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
components_to_quantize: Optional[List[str]] = None,
quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
):
self.quant_backend = quant_backend
# Initialize kwargs to be {} to set to the defaults.
self.quant_kwargs = quant_kwargs or {}
self.components_to_quantize = components_to_quantize
self.quant_mapping = quant_mapping
self.post_init()
def post_init(self):
quant_mapping = self.quant_mapping
self.is_granular = True if quant_mapping is not None else False
self._validate_init_args()
def _validate_init_args(self):
if self.quant_backend and self.quant_mapping:
raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
if not self.quant_mapping and not self.quant_backend:
raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
if not self.quant_kwargs and not self.quant_mapping:
raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
if self.quant_backend is not None:
self._validate_init_kwargs_in_backends()
if self.quant_mapping is not None:
self._validate_quant_mapping_args()
def _validate_init_kwargs_in_backends(self):
quant_backend = self.quant_backend
self._check_backend_availability(quant_backend)
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
if quant_config_mapping_transformers is not None:
init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
else:
init_kwargs_transformers = None
init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
if init_kwargs_transformers != init_kwargs_diffusers:
raise ValueError(
"The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about what "
"this mapping should look like."
)
def _validate_quant_mapping_args(self):
quant_mapping = self.quant_mapping
transformers_map, diffusers_map = self._get_quant_config_list()
available_transformers = list(transformers_map.values()) if transformers_map else None
available_diffusers = list(diffusers_map.values())
for module_name, config in quant_mapping.items():
if any(isinstance(config, cfg) for cfg in available_diffusers):
continue
if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
continue
if available_transformers:
raise ValueError(
f"Provided config for module_name={module_name} could not be found. "
f"Available diffusers configs: {available_diffusers}; "
f"Available transformers configs: {available_transformers}."
)
else:
raise ValueError(
f"Provided config for module_name={module_name} could not be found. "
f"Available diffusers configs: {available_diffusers}."
)
def _check_backend_availability(self, quant_backend: str):
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
available_backends_transformers = (
list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
)
available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
if (
available_backends_transformers and quant_backend not in available_backends_transformers
) or quant_backend not in quant_config_mapping_diffusers:
error_message = f"Provided quant_backend={quant_backend} was not found."
if available_backends_transformers:
error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
raise ValueError(error_message)
def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
quant_mapping = self.quant_mapping
components_to_quantize = self.components_to_quantize
# Granular case
if self.is_granular and module_name in quant_mapping:
logger.debug(f"Initializing quantization config class for {module_name}.")
config = quant_mapping[module_name]
return config
# Global config case
else:
should_quantize = False
# Only quantize the modules explicitly requested.
if components_to_quantize and module_name in components_to_quantize:
should_quantize = True
# No specification for `components_to_quantize` means all modules should be quantized.
elif not self.is_granular and not components_to_quantize:
should_quantize = True
if should_quantize:
logger.debug(f"Initializing quantization config class for {module_name}.")
mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
quant_config_cls = mapping_to_use[self.quant_backend]
quant_kwargs = self.quant_kwargs
return quant_config_cls(**quant_kwargs)
# Fallback: no applicable configuration found.
return None
def _get_quant_config_list(self):
if is_transformers_available():
from transformers.quantizers.auto import (
AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
)
else:
quant_config_mapping_transformers = None
from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
return quant_config_mapping_transformers, quant_config_mapping_diffusers
from .pipe_quant_config import PipelineQuantizationConfig
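For orientation, here is a sketch of the two mutually exclusive construction modes that `_validate_init_args` enforces, assuming the bitsandbytes backend is available:

from diffusers import BitsAndBytesConfig
from diffusers.quantizers import PipelineQuantizationConfig

# Global mode: one backend plus its kwargs, optionally restricted to a subset
# of pipeline components.
global_cfg = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["transformer", "text_encoder"],
)

# Granular mode: per-component config objects; quant_backend and quant_kwargs
# must be left unset.
granular_cfg = PipelineQuantizationConfig(
    quant_mapping={"transformer": BitsAndBytesConfig(load_in_8bit=True)}
)

# Mixing the two modes fails fast in _validate_init_args:
try:
    PipelineQuantizationConfig(
        quant_backend="bitsandbytes_8bit",
        quant_mapping={"transformer": BitsAndBytesConfig(load_in_8bit=True)},
    )
except ValueError as err:
    print(err)  # Both `quant_backend` and `quant_mapping` cannot be specified ...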

src/diffusers/quantizers/pipe_quant_config.py

@@ -0,0 +1,202 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Dict, List, Optional, Union
from ..utils import is_transformers_available, logging
from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
try:
from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
except ImportError:
class TransformersQuantConfigMixin:
pass
logger = logging.get_logger(__name__)
class PipelineQuantizationConfig:
"""
Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
Args:
quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
is available to both `diffusers` and `transformers`.
quant_kwargs (`dict`): Params to initialize the quantization backend class.
components_to_quantize (`list`): Components of a pipeline to be quantized.
quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
components. When using this argument, users are not expected to provide `quant_backend`, `quant_kwargs`,
and `components_to_quantize`.
"""
def __init__(
self,
quant_backend: str = None,
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
components_to_quantize: Optional[List[str]] = None,
quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
):
self.quant_backend = quant_backend
# Initialize kwargs to be {} to set to the defaults.
self.quant_kwargs = quant_kwargs or {}
self.components_to_quantize = components_to_quantize
self.quant_mapping = quant_mapping
self.config_mapping = {}  # book-keeping; example: `{module_name: quant_config}`
self.post_init()
def post_init(self):
quant_mapping = self.quant_mapping
self.is_granular = True if quant_mapping is not None else False
self._validate_init_args()
def _validate_init_args(self):
if self.quant_backend and self.quant_mapping:
raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
if not self.quant_mapping and not self.quant_backend:
raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
if not self.quant_kwargs and not self.quant_mapping:
raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
if self.quant_backend is not None:
self._validate_init_kwargs_in_backends()
if self.quant_mapping is not None:
self._validate_quant_mapping_args()
def _validate_init_kwargs_in_backends(self):
quant_backend = self.quant_backend
self._check_backend_availability(quant_backend)
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
if quant_config_mapping_transformers is not None:
init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
else:
init_kwargs_transformers = None
init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
if init_kwargs_transformers != init_kwargs_diffusers:
raise ValueError(
"The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about what "
"this mapping should look like."
)
def _validate_quant_mapping_args(self):
quant_mapping = self.quant_mapping
transformers_map, diffusers_map = self._get_quant_config_list()
available_transformers = list(transformers_map.values()) if transformers_map else None
available_diffusers = list(diffusers_map.values())
for module_name, config in quant_mapping.items():
if any(isinstance(config, cfg) for cfg in available_diffusers):
continue
if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
continue
if available_transformers:
raise ValueError(
f"Provided config for module_name={module_name} could not be found. "
f"Available diffusers configs: {available_diffusers}; "
f"Available transformers configs: {available_transformers}."
)
else:
raise ValueError(
f"Provided config for module_name={module_name} could not be found. "
f"Available diffusers configs: {available_diffusers}."
)
def _check_backend_availability(self, quant_backend: str):
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
available_backends_transformers = (
list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
)
available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
if (
available_backends_transformers and quant_backend not in available_backends_transformers
) or quant_backend not in quant_config_mapping_diffusers:
error_message = f"Provided quant_backend={quant_backend} was not found."
if available_backends_transformers:
error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
raise ValueError(error_message)
def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
quant_mapping = self.quant_mapping
components_to_quantize = self.components_to_quantize
# Granular case
if self.is_granular and module_name in quant_mapping:
logger.debug(f"Initializing quantization config class for {module_name}.")
config = quant_mapping[module_name]
self.config_mapping.update({module_name: config})
return config
# Global config case
else:
should_quantize = False
# Only quantize the modules explicitly requested.
if components_to_quantize and module_name in components_to_quantize:
should_quantize = True
# No specification for `components_to_quantize` means all modules should be quantized.
elif not self.is_granular and not components_to_quantize:
should_quantize = True
if should_quantize:
logger.debug(f"Initializing quantization config class for {module_name}.")
mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
quant_config_cls = mapping_to_use[self.quant_backend]
quant_kwargs = self.quant_kwargs
quant_obj = quant_config_cls(**quant_kwargs)
self.config_mapping.update({module_name: quant_obj})
return quant_obj
# Fallback: no applicable configuration found.
return None
def _get_quant_config_list(self):
if is_transformers_available():
from transformers.quantizers.auto import (
AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
)
else:
quant_config_mapping_transformers = None
from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
return quant_config_mapping_transformers, quant_config_mapping_diffusers
def __repr__(self):
out = ""
config_mapping = dict(sorted(self.config_mapping.copy().items()))
for module_name, config in config_mapping.items():
out += f"{module_name} {config}"
return out
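The additions relative to the old copy are the `config_mapping` book-keeping in `_resolve_quant_config` and the `__repr__` built on top of it. A sketch of that behavior in isolation (`_resolve_quant_config` is internal; it is called directly here only for illustration):

from diffusers.quantizers import PipelineQuantizationConfig

cfg = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["transformer"],
)

# from_pretrained resolves a config once per component; components outside
# components_to_quantize get None and are not recorded.
assert cfg._resolve_quant_config(is_diffusers=True, module_name="vae") is None
cfg._resolve_quant_config(is_diffusers=True, module_name="transformer")

print(list(cfg.config_mapping))  # ['transformer']
print(cfg)  # "transformer BitsAndBytesConfig { ... }" via the new __repr__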

tests/quantization/test_pipeline_level_quantization.py

@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import tempfile
import unittest
import torch
from parameterized import parameterized
-from diffusers import DiffusionPipeline, QuantoConfig
+from diffusers import BitsAndBytesConfig, DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
@@ -243,3 +244,57 @@ class PipelineQuantizationTests(unittest.TestCase):
for name, component in pipe.components.items():
if isinstance(component, torch.nn.Module):
self.assertTrue(not hasattr(component.config, "quantization_config"))
@parameterized.expand(["quant_kwargs", "quant_mapping"])
def test_quant_config_repr(self, method):
component_name = "transformer"
if method == "quant_kwargs":
components_to_quantize = [component_name]
quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_kwargs={"load_in_8bit": True},
components_to_quantize=components_to_quantize,
)
else:
quant_config = PipelineQuantizationConfig(
quant_mapping={component_name: BitsAndBytesConfig(load_in_8bit=True)}
)
pipe = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
self.assertTrue(getattr(pipe, "quantization_config", None) is not None)
retrieved_config = pipe.quantization_config
expected_config = """
transformer BitsAndBytesConfig {
"_load_in_4bit": false,
"_load_in_8bit": true,
"bnb_4bit_compute_dtype": "float32",
"bnb_4bit_quant_storage": "uint8",
"bnb_4bit_quant_type": "fp4",
"bnb_4bit_use_double_quant": false,
"llm_int8_enable_fp32_cpu_offload": false,
"llm_int8_has_fp16_weight": false,
"llm_int8_skip_modules": null,
"llm_int8_threshold": 6.0,
"load_in_4bit": false,
"load_in_8bit": true,
"quant_method": "bitsandbytes"
}
"""
expected_data = self._parse_config_string(expected_config)
actual_data = self._parse_config_string(str(retrieved_config))
self.assertTrue(actual_data == expected_data)
def _parse_config_string(self, config_string: str) -> dict:
first_brace = config_string.find("{")
if first_brace == -1:
raise ValueError("Could not find opening brace '{' in the string.")
json_part = config_string[first_brace:]
data = json.loads(json_part)
return data
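The helper deliberately drops everything before the first `{` (the `transformer BitsAndBytesConfig` prefix produced by the new `__repr__`) and compares parsed JSON, which keeps the assertion independent of key order and whitespace. For instance:

import json

# A trimmed-down stand-in for the repr string checked in the test above.
repr_string = 'transformer BitsAndBytesConfig {"load_in_8bit": true}'
payload = json.loads(repr_string[repr_string.find("{"):])
assert payload == {"load_in_8bit": True}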