1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix invalid component handling behaviour in PipelineQuantizationConfig (#11750)

* start

* updates
This commit is contained in:
Sayak Paul
2025-06-20 07:54:12 +05:30
committed by GitHub
parent 195926bbdc
commit 3d8d8485fc
3 changed files with 80 additions and 0 deletions

View File

@@ -1131,3 +1131,26 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
break
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
def _maybe_warn_for_wrong_component_in_quant_config(pipe_init_dict, quant_config):
    """Warn when the quantization config names components the pipeline does not have.

    The component names are read from `quant_config.components_to_quantize` when that
    attribute is set, otherwise from the keys of `quant_config.quant_mapping` when it is
    a dict. Names that are not keys of `pipe_init_dict` are reported through a single
    `logger.warning` call; nothing is raised and neither argument is modified.

    Args:
        pipe_init_dict: Mapping whose keys are the component names of the pipeline.
        quant_config: Pipeline-level quantization config, or `None` to skip the check.
    """
    if quant_config is None:
        return

    actual_pipe_components = set(pipe_init_dict)

    components_to_quantize = getattr(quant_config, "components_to_quantize", None)
    quant_mapping = getattr(quant_config, "quant_mapping", None)
    if components_to_quantize is not None:
        # A bare string would otherwise be exploded into single characters by `set()`.
        if isinstance(components_to_quantize, str):
            components_to_quantize = [components_to_quantize]
        quant_components = set(components_to_quantize)
    elif isinstance(quant_mapping, dict):
        quant_components = set(quant_mapping)
    else:
        # Nothing in the config names any component, so there is nothing to validate.
        return

    # Sort so the message is stable across runs (set order depends on string hashing).
    missing = sorted(quant_components - actual_pipe_components)
    if missing:
        logger.warning(
            f"The following components in the quantization config {missing} will be ignored "
            "as they do not belong to the underlying pipeline. Acceptable values for the pipeline "
            f"components are: {', '.join(sorted(actual_pipe_components))}."
        )

View File

@@ -88,6 +88,7 @@ from .pipeline_loading_utils import (
_identify_model_variants,
_maybe_raise_error_for_incorrect_transformers,
_maybe_raise_warning_for_inpainting,
_maybe_warn_for_wrong_component_in_quant_config,
_resolve_custom_pipeline_and_cls,
_unwrap_model,
_update_init_kwargs_with_connected_pipeline,
@@ -984,6 +985,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# 7. Load each module in the pipeline
current_device_map = None
_maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
# 7.1 device_map shenanigans
if final_device_map is not None and len(final_device_map) > 0:

View File

@@ -16,10 +16,13 @@ import tempfile
import unittest
import torch
from parameterized import parameterized
from diffusers import DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
CaptureLogger,
is_transformers_available,
require_accelerate,
require_bitsandbytes_version_greater,
@@ -188,3 +191,55 @@ class PipelineQuantizationTests(unittest.TestCase):
output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
self.assertTrue(torch.allclose(output_1, output_2))
@parameterized.expand(["quant_kwargs", "quant_mapping"])
def test_warn_invalid_component(self, method):
    """Loading with an unknown component in the quantization config logs a warning naming it.

    Runs once per way of specifying components: through `components_to_quantize`
    ("quant_kwargs") and through the keys of `quant_mapping` ("quant_mapping").
    """
    invalid_component = "foo"
    if method == "quant_kwargs":
        components_to_quantize = ["transformer", invalid_component]
        quant_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={"load_in_8bit": True},
            components_to_quantize=components_to_quantize,
        )
    else:
        quant_config = PipelineQuantizationConfig(
            quant_mapping={
                "transformer": QuantoConfig("int8"),
                invalid_component: TranBitsAndBytesConfig(load_in_8bit=True),
            }
        )

    # Capture the logger of the module that emits the invalid-component warning.
    logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
    logger.setLevel(logging.WARNING)
    with CaptureLogger(logger) as cap_logger:
        _ = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        )

    # assertIn reports both operands on failure, unlike assertTrue(x in y).
    self.assertIn(invalid_component, cap_logger.out)
@parameterized.expand(["quant_kwargs", "quant_mapping"])
def test_no_quantization_for_all_invalid_components(self, method):
    """When every configured component is unknown, no loaded module has a quantization config.

    Runs once per way of specifying components: through `components_to_quantize`
    ("quant_kwargs") and through the keys of `quant_mapping` ("quant_mapping").
    """
    invalid_component = "foo"
    if method == "quant_kwargs":
        components_to_quantize = [invalid_component]
        quant_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={"load_in_8bit": True},
            components_to_quantize=components_to_quantize,
        )
    else:
        quant_config = PipelineQuantizationConfig(
            quant_mapping={invalid_component: TranBitsAndBytesConfig(load_in_8bit=True)}
        )

    pipe = DiffusionPipeline.from_pretrained(
        self.model_name,
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
    )

    for name, component in pipe.components.items():
        # Only torch modules are inspected; other pipeline components are skipped.
        if isinstance(component, torch.nn.Module):
            self.assertFalse(
                hasattr(component.config, "quantization_config"),
                msg=f"Component `{name}` unexpectedly has a `quantization_config`.",
            )