mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[quant] QoL improvements for pipeline-level quant config (#11876)
* add repr for pipelinequantconfig. * update
This commit is contained in:
@@ -40,6 +40,7 @@ _import_structure = {
|
||||
"models": [],
|
||||
"modular_pipelines": [],
|
||||
"pipelines": [],
|
||||
"quantizers.pipe_quant_config": ["PipelineQuantizationConfig"],
|
||||
"quantizers.quantization_config": [],
|
||||
"schedulers": [],
|
||||
"utils": [
|
||||
|
||||
@@ -1096,6 +1096,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
if device_map is not None:
|
||||
setattr(model, "hf_device_map", final_device_map)
|
||||
if quantization_config is not None:
|
||||
setattr(model, "quantization_config", quantization_config)
|
||||
return model
|
||||
|
||||
@property
|
||||
|
||||
@@ -12,183 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ..utils import is_transformers_available, logging
|
||||
from .auto import DiffusersAutoQuantizer
|
||||
from .base import DiffusersQuantizer
|
||||
from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
|
||||
|
||||
|
||||
try:
|
||||
from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
|
||||
except ImportError:
|
||||
|
||||
class TransformersQuantConfigMixin:
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class PipelineQuantizationConfig:
|
||||
"""
|
||||
Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
|
||||
|
||||
Args:
|
||||
quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
|
||||
is available to both `diffusers` and `transformers`.
|
||||
quant_kwargs (`dict`): Params to initialize the quantization backend class.
|
||||
components_to_quantize (`list`): Components of a pipeline to be quantized.
|
||||
quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
|
||||
components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`,
|
||||
and `components_to_quantize`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_backend: str = None,
|
||||
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
|
||||
components_to_quantize: Optional[List[str]] = None,
|
||||
quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
|
||||
):
|
||||
self.quant_backend = quant_backend
|
||||
# Initialize kwargs to be {} to set to the defaults.
|
||||
self.quant_kwargs = quant_kwargs or {}
|
||||
self.components_to_quantize = components_to_quantize
|
||||
self.quant_mapping = quant_mapping
|
||||
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
quant_mapping = self.quant_mapping
|
||||
self.is_granular = True if quant_mapping is not None else False
|
||||
|
||||
self._validate_init_args()
|
||||
|
||||
def _validate_init_args(self):
|
||||
if self.quant_backend and self.quant_mapping:
|
||||
raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
|
||||
|
||||
if not self.quant_mapping and not self.quant_backend:
|
||||
raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
|
||||
|
||||
if not self.quant_kwargs and not self.quant_mapping:
|
||||
raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
|
||||
|
||||
if self.quant_backend is not None:
|
||||
self._validate_init_kwargs_in_backends()
|
||||
|
||||
if self.quant_mapping is not None:
|
||||
self._validate_quant_mapping_args()
|
||||
|
||||
def _validate_init_kwargs_in_backends(self):
|
||||
quant_backend = self.quant_backend
|
||||
|
||||
self._check_backend_availability(quant_backend)
|
||||
|
||||
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
|
||||
|
||||
if quant_config_mapping_transformers is not None:
|
||||
init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
|
||||
init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
|
||||
else:
|
||||
init_kwargs_transformers = None
|
||||
|
||||
init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
|
||||
init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
|
||||
|
||||
if init_kwargs_transformers != init_kwargs_diffusers:
|
||||
raise ValueError(
|
||||
"The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
|
||||
f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
|
||||
"this mapping would look like."
|
||||
)
|
||||
|
||||
def _validate_quant_mapping_args(self):
|
||||
quant_mapping = self.quant_mapping
|
||||
transformers_map, diffusers_map = self._get_quant_config_list()
|
||||
|
||||
available_transformers = list(transformers_map.values()) if transformers_map else None
|
||||
available_diffusers = list(diffusers_map.values())
|
||||
|
||||
for module_name, config in quant_mapping.items():
|
||||
if any(isinstance(config, cfg) for cfg in available_diffusers):
|
||||
continue
|
||||
|
||||
if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
|
||||
continue
|
||||
|
||||
if available_transformers:
|
||||
raise ValueError(
|
||||
f"Provided config for module_name={module_name} could not be found. "
|
||||
f"Available diffusers configs: {available_diffusers}; "
|
||||
f"Available transformers configs: {available_transformers}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Provided config for module_name={module_name} could not be found. "
|
||||
f"Available diffusers configs: {available_diffusers}."
|
||||
)
|
||||
|
||||
def _check_backend_availability(self, quant_backend: str):
|
||||
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
|
||||
|
||||
available_backends_transformers = (
|
||||
list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
|
||||
)
|
||||
available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
|
||||
|
||||
if (
|
||||
available_backends_transformers and quant_backend not in available_backends_transformers
|
||||
) or quant_backend not in quant_config_mapping_diffusers:
|
||||
error_message = f"Provided quant_backend={quant_backend} was not found."
|
||||
if available_backends_transformers:
|
||||
error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
|
||||
error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
|
||||
raise ValueError(error_message)
|
||||
|
||||
def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
|
||||
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
|
||||
|
||||
quant_mapping = self.quant_mapping
|
||||
components_to_quantize = self.components_to_quantize
|
||||
|
||||
# Granular case
|
||||
if self.is_granular and module_name in quant_mapping:
|
||||
logger.debug(f"Initializing quantization config class for {module_name}.")
|
||||
config = quant_mapping[module_name]
|
||||
return config
|
||||
|
||||
# Global config case
|
||||
else:
|
||||
should_quantize = False
|
||||
# Only quantize the modules requested for.
|
||||
if components_to_quantize and module_name in components_to_quantize:
|
||||
should_quantize = True
|
||||
# No specification for `components_to_quantize` means all modules should be quantized.
|
||||
elif not self.is_granular and not components_to_quantize:
|
||||
should_quantize = True
|
||||
|
||||
if should_quantize:
|
||||
logger.debug(f"Initializing quantization config class for {module_name}.")
|
||||
mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
|
||||
quant_config_cls = mapping_to_use[self.quant_backend]
|
||||
quant_kwargs = self.quant_kwargs
|
||||
return quant_config_cls(**quant_kwargs)
|
||||
|
||||
# Fallback: no applicable configuration found.
|
||||
return None
|
||||
|
||||
def _get_quant_config_list(self):
|
||||
if is_transformers_available():
|
||||
from transformers.quantizers.auto import (
|
||||
AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
|
||||
)
|
||||
else:
|
||||
quant_config_mapping_transformers = None
|
||||
|
||||
from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
|
||||
|
||||
return quant_config_mapping_transformers, quant_config_mapping_diffusers
|
||||
from .pipe_quant_config import PipelineQuantizationConfig
|
||||
|
||||
202
src/diffusers/quantizers/pipe_quant_config.py
Normal file
202
src/diffusers/quantizers/pipe_quant_config.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ..utils import is_transformers_available, logging
|
||||
from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
|
||||
|
||||
|
||||
try:
|
||||
from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
|
||||
except ImportError:
|
||||
|
||||
class TransformersQuantConfigMixin:
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class PipelineQuantizationConfig:
|
||||
"""
|
||||
Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
|
||||
|
||||
Args:
|
||||
quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
|
||||
is available to both `diffusers` and `transformers`.
|
||||
quant_kwargs (`dict`): Params to initialize the quantization backend class.
|
||||
components_to_quantize (`list`): Components of a pipeline to be quantized.
|
||||
quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
|
||||
components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`,
|
||||
and `components_to_quantize`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_backend: str = None,
|
||||
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
|
||||
components_to_quantize: Optional[List[str]] = None,
|
||||
quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
|
||||
):
|
||||
self.quant_backend = quant_backend
|
||||
# Initialize kwargs to be {} to set to the defaults.
|
||||
self.quant_kwargs = quant_kwargs or {}
|
||||
self.components_to_quantize = components_to_quantize
|
||||
self.quant_mapping = quant_mapping
|
||||
self.config_mapping = {} # book-keeping Example: `{module_name: quant_config}`
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
quant_mapping = self.quant_mapping
|
||||
self.is_granular = True if quant_mapping is not None else False
|
||||
|
||||
self._validate_init_args()
|
||||
|
||||
def _validate_init_args(self):
|
||||
if self.quant_backend and self.quant_mapping:
|
||||
raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
|
||||
|
||||
if not self.quant_mapping and not self.quant_backend:
|
||||
raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
|
||||
|
||||
if not self.quant_kwargs and not self.quant_mapping:
|
||||
raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
|
||||
|
||||
if self.quant_backend is not None:
|
||||
self._validate_init_kwargs_in_backends()
|
||||
|
||||
if self.quant_mapping is not None:
|
||||
self._validate_quant_mapping_args()
|
||||
|
||||
def _validate_init_kwargs_in_backends(self):
|
||||
quant_backend = self.quant_backend
|
||||
|
||||
self._check_backend_availability(quant_backend)
|
||||
|
||||
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
|
||||
|
||||
if quant_config_mapping_transformers is not None:
|
||||
init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
|
||||
init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
|
||||
else:
|
||||
init_kwargs_transformers = None
|
||||
|
||||
init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
|
||||
init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
|
||||
|
||||
if init_kwargs_transformers != init_kwargs_diffusers:
|
||||
raise ValueError(
|
||||
"The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
|
||||
f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
|
||||
"this mapping would look like."
|
||||
)
|
||||
|
||||
def _validate_quant_mapping_args(self):
|
||||
quant_mapping = self.quant_mapping
|
||||
transformers_map, diffusers_map = self._get_quant_config_list()
|
||||
|
||||
available_transformers = list(transformers_map.values()) if transformers_map else None
|
||||
available_diffusers = list(diffusers_map.values())
|
||||
|
||||
for module_name, config in quant_mapping.items():
|
||||
if any(isinstance(config, cfg) for cfg in available_diffusers):
|
||||
continue
|
||||
|
||||
if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
|
||||
continue
|
||||
|
||||
if available_transformers:
|
||||
raise ValueError(
|
||||
f"Provided config for module_name={module_name} could not be found. "
|
||||
f"Available diffusers configs: {available_diffusers}; "
|
||||
f"Available transformers configs: {available_transformers}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Provided config for module_name={module_name} could not be found. "
|
||||
f"Available diffusers configs: {available_diffusers}."
|
||||
)
|
||||
|
||||
def _check_backend_availability(self, quant_backend: str):
|
||||
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
|
||||
|
||||
available_backends_transformers = (
|
||||
list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
|
||||
)
|
||||
available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
|
||||
|
||||
if (
|
||||
available_backends_transformers and quant_backend not in available_backends_transformers
|
||||
) or quant_backend not in quant_config_mapping_diffusers:
|
||||
error_message = f"Provided quant_backend={quant_backend} was not found."
|
||||
if available_backends_transformers:
|
||||
error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
|
||||
error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
|
||||
raise ValueError(error_message)
|
||||
|
||||
def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
|
||||
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
|
||||
|
||||
quant_mapping = self.quant_mapping
|
||||
components_to_quantize = self.components_to_quantize
|
||||
|
||||
# Granular case
|
||||
if self.is_granular and module_name in quant_mapping:
|
||||
logger.debug(f"Initializing quantization config class for {module_name}.")
|
||||
config = quant_mapping[module_name]
|
||||
self.config_mapping.update({module_name: config})
|
||||
return config
|
||||
|
||||
# Global config case
|
||||
else:
|
||||
should_quantize = False
|
||||
# Only quantize the modules requested for.
|
||||
if components_to_quantize and module_name in components_to_quantize:
|
||||
should_quantize = True
|
||||
# No specification for `components_to_quantize` means all modules should be quantized.
|
||||
elif not self.is_granular and not components_to_quantize:
|
||||
should_quantize = True
|
||||
|
||||
if should_quantize:
|
||||
logger.debug(f"Initializing quantization config class for {module_name}.")
|
||||
mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
|
||||
quant_config_cls = mapping_to_use[self.quant_backend]
|
||||
quant_kwargs = self.quant_kwargs
|
||||
quant_obj = quant_config_cls(**quant_kwargs)
|
||||
self.config_mapping.update({module_name: quant_obj})
|
||||
return quant_obj
|
||||
|
||||
# Fallback: no applicable configuration found.
|
||||
return None
|
||||
|
||||
def _get_quant_config_list(self):
|
||||
if is_transformers_available():
|
||||
from transformers.quantizers.auto import (
|
||||
AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
|
||||
)
|
||||
else:
|
||||
quant_config_mapping_transformers = None
|
||||
|
||||
from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
|
||||
|
||||
return quant_config_mapping_transformers, quant_config_mapping_diffusers
|
||||
|
||||
def __repr__(self):
|
||||
out = ""
|
||||
config_mapping = dict(sorted(self.config_mapping.copy().items()))
|
||||
for module_name, config in config_mapping.items():
|
||||
out += f"{module_name} {config}"
|
||||
return out
|
||||
@@ -12,13 +12,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import DiffusionPipeline, QuantoConfig
|
||||
from diffusers import BitsAndBytesConfig, DiffusionPipeline, QuantoConfig
|
||||
from diffusers.quantizers import PipelineQuantizationConfig
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
@@ -243,3 +244,57 @@ class PipelineQuantizationTests(unittest.TestCase):
|
||||
for name, component in pipe.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
self.assertTrue(not hasattr(component.config, "quantization_config"))
|
||||
|
||||
@parameterized.expand(["quant_kwargs", "quant_mapping"])
|
||||
def test_quant_config_repr(self, method):
|
||||
component_name = "transformer"
|
||||
if method == "quant_kwargs":
|
||||
components_to_quantize = [component_name]
|
||||
quant_config = PipelineQuantizationConfig(
|
||||
quant_backend="bitsandbytes_8bit",
|
||||
quant_kwargs={"load_in_8bit": True},
|
||||
components_to_quantize=components_to_quantize,
|
||||
)
|
||||
else:
|
||||
quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={component_name: BitsAndBytesConfig(load_in_8bit=True)}
|
||||
)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
self.model_name,
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
self.assertTrue(getattr(pipe, "quantization_config", None) is not None)
|
||||
retrieved_config = pipe.quantization_config
|
||||
expected_config = """
|
||||
transformer BitsAndBytesConfig {
|
||||
"_load_in_4bit": false,
|
||||
"_load_in_8bit": true,
|
||||
"bnb_4bit_compute_dtype": "float32",
|
||||
"bnb_4bit_quant_storage": "uint8",
|
||||
"bnb_4bit_quant_type": "fp4",
|
||||
"bnb_4bit_use_double_quant": false,
|
||||
"llm_int8_enable_fp32_cpu_offload": false,
|
||||
"llm_int8_has_fp16_weight": false,
|
||||
"llm_int8_skip_modules": null,
|
||||
"llm_int8_threshold": 6.0,
|
||||
"load_in_4bit": false,
|
||||
"load_in_8bit": true,
|
||||
"quant_method": "bitsandbytes"
|
||||
}
|
||||
|
||||
"""
|
||||
expected_data = self._parse_config_string(expected_config)
|
||||
actual_data = self._parse_config_string(str(retrieved_config))
|
||||
self.assertTrue(actual_data == expected_data)
|
||||
|
||||
def _parse_config_string(self, config_string: str) -> tuple[str, dict]:
|
||||
first_brace = config_string.find("{")
|
||||
if first_brace == -1:
|
||||
raise ValueError("Could not find opening brace '{' in the string.")
|
||||
|
||||
json_part = config_string[first_brace:]
|
||||
data = json.loads(json_part)
|
||||
|
||||
return data
|
||||
|
||||
Reference in New Issue
Block a user