
[bitsandbytes] improve replacement warnings for bnb (#11132)

* improve replacement warnings for bnb

* updates to docs.
Sayak Paul
2025-04-08 21:18:34 +05:30
committed by GitHub
parent 4b27c4a494
commit 1a04812439
3 changed files with 38 additions and 6 deletions

src/diffusers/quantizers/bitsandbytes/utils.py

@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     models by reducing the precision of the weights and activations, thus making models more efficient in terms
     of both storage and computation.
     """
-    model, has_been_replaced = _replace_with_bnb_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config
-    )
+    model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
+    has_been_replaced = any(
+        isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
+        for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
             "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
@@ -283,16 +285,18 @@ def dequantize_and_replace(
     modules_to_not_convert=None,
     quantization_config=None,
 ):
-    model, has_been_replaced = _dequantize_and_replace(
+    model, _ = _dequantize_and_replace(
         model,
         dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
+    has_been_replaced = any(
+        isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
-            "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+            "Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model."
         )
     return model
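
The net effect in both helpers is that the flag returned by the recursive call is discarded and has_been_replaced is recomputed from the model's actual module tree, so the warnings now reflect what the model really contains. A minimal standalone sketch of that detection pattern (not the library code itself; the helper name below is made up for illustration, and torch plus bitsandbytes are assumed to be installed):

import torch
import bitsandbytes as bnb


def contains_bnb_linear(model: torch.nn.Module) -> bool:
    # True if at least one layer was swapped for a bitsandbytes quantized linear.
    return any(
        isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
        for _, module in model.named_modules()
    )


# A model with no nn.Linear layers can never be quantized this way,
# which is exactly the case the improved warning reports.
conv_only = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
print(contains_bnb_linear(conv_only))  # False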

tests/quantization/bnb/test_4bit.py

@@ -70,6 +70,8 @@ if is_torch_available():
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
+    from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear


 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -371,6 +373,18 @@ class BnB4BitBasicTests(Base4bitTests):
         assert key_to_target in str(err_context.exception)

+    def test_bnb_4bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )


 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
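
One detail worth noting in the new test (and its 8-bit counterpart below): logger.setLevel(30) uses the numeric value of the standard-library WARNING level, so the captured output includes warnings but not info or debug messages. A quick check of that constant, using only the standard library:

import logging

# 30 is the numeric value of logging.WARNING in the standard library.
assert logging.WARNING == 30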

tests/quantization/bnb/test_mixed_int8.py

@@ -68,6 +68,8 @@ if is_torch_available():
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
+    from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear


 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -317,6 +319,18 @@ class BnB8bitBasicTests(Base8bitTests):
         # Check that this does not throw an error
         _ = self.model_fp16.to(torch_device)

+    def test_bnb_8bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )


 class Bnb8bitDeviceTests(Base8bitTests):
     def setUp(self) -> None:
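
Outside the unittest harness, the same warning can be reproduced with a short script. This is a hedged sketch that mirrors the new tests, assuming diffusers, torch, and bitsandbytes are installed:

import torch

from diffusers import BitsAndBytesConfig
from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear
from diffusers.utils import logging

# Make sure warnings from diffusers loggers are printed.
logging.set_verbosity_warning()

# No nn.Linear anywhere, so there is nothing for bitsandbytes to replace.
model = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
config = BitsAndBytesConfig(load_in_8bit=True)

# Logs: "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
model = replace_with_bnb_linear(model, quantization_config=config)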