From 1a04812439c82a9dd318d14a800bb04e84dbbfc0 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Tue, 8 Apr 2025 21:18:34 +0530
Subject: [PATCH] [bistandbytes] improve replacement warnings for bnb (#11132)

* improve replacement warnings for bnb

* updates to docs.
---
 src/diffusers/quantizers/bitsandbytes/utils.py | 16 ++++++++++------
 tests/quantization/bnb/test_4bit.py            | 14 ++++++++++++++
 tests/quantization/bnb/test_mixed_int8.py      | 14 ++++++++++++++
 3 files changed, 38 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py
index a9771b368a..e150281e81 100644
--- a/src/diffusers/quantizers/bitsandbytes/utils.py
+++ b/src/diffusers/quantizers/bitsandbytes/utils.py
@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
             models by reducing the precision of the weights and activations, thus making models more efficient in terms of
             both storage and computation.
     """
-    model, has_been_replaced = _replace_with_bnb_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config
-    )
+    model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
 
+    has_been_replaced = any(
+        isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
+        for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
             "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
@@ -283,16 +285,18 @@ def dequantize_and_replace(
     modules_to_not_convert=None,
     quantization_config=None,
 ):
-    model, has_been_replaced = _dequantize_and_replace(
+    model, _ = _dequantize_and_replace(
         model,
         dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
-
+    has_been_replaced = any(
+        isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
-            "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+            "Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model."
         )
 
     return model
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index fdcc5314d2..096ee4c344 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -70,6 +70,8 @@ if is_torch_available():
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -371,6 +373,18 @@ class BnB4BitBasicTests(Base4bitTests):
 
         assert key_to_target in str(err_context.exception)
 
+    def test_bnb_4bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index a5e38f931e..1049bfecba 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -68,6 +68,8 @@ if is_torch_available():
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -317,6 +319,18 @@ class BnB8bitBasicTests(Base8bitTests):
         # Check that this does not throw an error
         _ = self.model_fp16.to(torch_device)
 
+    def test_bnb_8bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class Bnb8bitDeviceTests(Base8bitTests):
     def setUp(self) -> None:
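
Editor's note (not part of the patch): below is a minimal, standalone sketch of the detection pattern the diff introduces, namely scanning `model.named_modules()` after replacement and warning only when none of the expected layer types are present. The helper name `warn_if_not_replaced` is hypothetical, and plain `torch.nn.Linear` stands in for `bnb.nn.Linear4bit` / `bnb.nn.Linear8bitLt` so the example runs without bitsandbytes installed.

    import warnings

    import torch


    def warn_if_not_replaced(model: torch.nn.Module, target_types=(torch.nn.Linear,)) -> bool:
        # Scan every submodule and check whether at least one has a target type.
        has_been_replaced = any(isinstance(m, target_types) for _, m in model.named_modules())
        if not has_been_replaced:
            warnings.warn(
                "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
            )
        return has_been_replaced


    # A Conv-only model triggers the warning, mirroring the new tests in the patch.
    conv_only = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
    warn_if_not_replaced(conv_only)

Checking the actual module types after the fact, rather than trusting the flag threaded through the recursive helper, is what makes the warning reflect the real state of the model.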