Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
fix invalid component handling behaviour in PipelineQuantizationConfig (#11750)
* start

* updates
@@ -1131,3 +1131,26 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
                break
    if has_transformers_component and not is_transformers_version(">", "4.47.1"):
        raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")


def _maybe_warn_for_wrong_component_in_quant_config(pipe_init_dict, quant_config):
    if quant_config is None:
        return

    actual_pipe_components = set(pipe_init_dict.keys())
    missing = ""
    quant_components = None
    if getattr(quant_config, "components_to_quantize", None) is not None:
        quant_components = set(quant_config.components_to_quantize)
    elif getattr(quant_config, "quant_mapping", None) is not None and isinstance(quant_config.quant_mapping, dict):
        quant_components = set(quant_config.quant_mapping.keys())

    if quant_components and not quant_components.issubset(actual_pipe_components):
        missing = quant_components - actual_pipe_components

    if missing:
        logger.warning(
            f"The following components in the quantization config {missing} will be ignored "
            "as they do not belong to the underlying pipeline. Acceptable values for the pipeline "
            f"components are: {', '.join(actual_pipe_components)}."
        )
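The helper added to `diffusers.pipelines.pipeline_loading_utils` (the module name is confirmed by the logger used in the tests below) is plain set arithmetic: the component names in the quantization config are compared against the pipeline's init dict keys, and any leftovers are warned about and otherwise ignored. A minimal standalone sketch of just that check, with made-up component names for illustration:

# Standalone sketch of the subset check; the names here are illustrative only.
pipe_components = {"transformer", "text_encoder", "vae", "scheduler"}
requested = {"transformer", "foo"}  # "foo" does not exist in the pipeline

missing = requested - pipe_components  # -> {"foo"}
if missing:
    print(f"Components {missing} will be ignored: not part of the pipeline.")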
@@ -88,6 +88,7 @@ from .pipeline_loading_utils import (
    _identify_model_variants,
    _maybe_raise_error_for_incorrect_transformers,
    _maybe_raise_warning_for_inpainting,
    _maybe_warn_for_wrong_component_in_quant_config,
    _resolve_custom_pipeline_and_cls,
    _unwrap_model,
    _update_init_kwargs_with_connected_pipeline,
@@ -984,6 +985,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

        # 7. Load each module in the pipeline
        current_device_map = None
        _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
            # 7.1 device_map shenanigans
            if final_device_map is not None and len(final_device_map) > 0:
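Since the check runs once in `DiffusionPipeline.from_pretrained` before any component is loaded, the user-visible effect is that a stray component name no longer misbehaves silently: loading proceeds, valid entries are quantized, and a warning lists the invalid ones. A hedged end-to-end sketch (the checkpoint id is illustrative; any pipeline with a `transformer` component would do):

import torch

from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# "unet_typo" is not a component of this pipeline, so after this fix it is
# reported in a warning and skipped; "transformer" is quantized as requested.
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["transformer", "unet_typo"],
)
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)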
@@ -16,10 +16,13 @@ import tempfile
import unittest

import torch
from parameterized import parameterized

from diffusers import DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
    CaptureLogger,
    is_transformers_available,
    require_accelerate,
    require_bitsandbytes_version_greater,
@@ -188,3 +191,55 @@ class PipelineQuantizationTests(unittest.TestCase):
        output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images

        self.assertTrue(torch.allclose(output_1, output_2))

    @parameterized.expand(["quant_kwargs", "quant_mapping"])
    def test_warn_invalid_component(self, method):
        invalid_component = "foo"
        if method == "quant_kwargs":
            components_to_quantize = ["transformer", invalid_component]
            quant_config = PipelineQuantizationConfig(
                quant_backend="bitsandbytes_8bit",
                quant_kwargs={"load_in_8bit": True},
                components_to_quantize=components_to_quantize,
            )
        else:
            quant_config = PipelineQuantizationConfig(
                quant_mapping={
                    "transformer": QuantoConfig("int8"),
                    invalid_component: TranBitsAndBytesConfig(load_in_8bit=True),
                }
            )

        logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
        logger.setLevel(logging.WARNING)
        with CaptureLogger(logger) as cap_logger:
            _ = DiffusionPipeline.from_pretrained(
                self.model_name,
                quantization_config=quant_config,
                torch_dtype=torch.bfloat16,
            )
        self.assertTrue(invalid_component in cap_logger.out)

    @parameterized.expand(["quant_kwargs", "quant_mapping"])
    def test_no_quantization_for_all_invalid_components(self, method):
        invalid_component = "foo"
        if method == "quant_kwargs":
            components_to_quantize = [invalid_component]
            quant_config = PipelineQuantizationConfig(
                quant_backend="bitsandbytes_8bit",
                quant_kwargs={"load_in_8bit": True},
                components_to_quantize=components_to_quantize,
            )
        else:
            quant_config = PipelineQuantizationConfig(
                quant_mapping={invalid_component: TranBitsAndBytesConfig(load_in_8bit=True)}
            )

        pipe = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        )
        for name, component in pipe.components.items():
            if isinstance(component, torch.nn.Module):
                self.assertTrue(not hasattr(component.config, "quantization_config"))
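A note on `TranBitsAndBytesConfig` in these tests: it is not a typo for a diffusers class. Presumably it is transformers' `BitsAndBytesConfig` imported under an alias elsewhere in the test module (the import hunk above is truncated before any `is_transformers_available()` guard), the alias distinguishing it from diffusers' own `BitsAndBytesConfig`.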