[bitsandbytes] improve replacement warnings for bnb (#11132)
* improve replacement warnings for bnb
* updates to docs.
@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     models by reducing the precision of the weights and activations, thus making models more efficient in terms
     of both storage and computation.
     """
-    model, has_been_replaced = _replace_with_bnb_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config
-    )
+    model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
 
+    has_been_replaced = any(
+        isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
+        for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
             "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
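The key change in this hunk: `has_been_replaced` is no longer taken from the return value of `_replace_with_bnb_linear` but recomputed by scanning the final module tree for bitsandbytes linear layers. A minimal sketch of that check in isolation (the conv-only model mirrors the new tests below; this snippet is illustrative, not part of the diff):

import bitsandbytes as bnb
import torch

# A model with no nn.Linear layers, like the one used in the new tests.
model = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())

# The post-replacement check from the hunk above: look for actual
# bnb linear modules instead of trusting a propagated flag.
has_been_replaced = any(
    isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
    for _, module in model.named_modules()
)
print(has_been_replaced)  # False -> the warning branch fires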
@@ -283,16 +285,18 @@ def dequantize_and_replace(
     modules_to_not_convert=None,
     quantization_config=None,
 ):
-    model, has_been_replaced = _dequantize_and_replace(
+    model, _ = _dequantize_and_replace(
         model,
         dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
 
+    has_been_replaced = any(
+        isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
-            "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+            "Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model."
         )
 
     return model
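The dequantization path gets the same treatment: success is now defined as plain `torch.nn.Linear` modules being present in the module tree after `_dequantize_and_replace` runs. A standalone sketch of the check (the helper name `looks_dequantized` is ours, not the library's):

import torch

def looks_dequantized(model: torch.nn.Module) -> bool:
    # Mirrors the hunk above: dequantization replaces bnb layers with
    # plain torch.nn.Linear, so finding any counts as success.
    return any(isinstance(m, torch.nn.Linear) for _, m in model.named_modules())

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
print(looks_dequantized(model))  # True -> no warning is logged

Note that the check uses `any`, so the warning fires only when no `torch.nn.Linear` survives at all.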
@@ -70,6 +70,8 @@ if is_torch_available():
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -371,6 +373,18 @@ class BnB4BitBasicTests(Base4bitTests):
 
         assert key_to_target in str(err_context.exception)
 
+    def test_bnb_4bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
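In the test above, `logger.setLevel(30)` uses the numeric value of Python's standard `WARNING` level, so the captured logger emits warnings. The stdlib equivalence (a fact about Python's logging module, not part of the diff):

import logging

# 30 is the stdlib numeric value for WARNING, so
# logger.setLevel(30) == logger.setLevel(logging.WARNING)
assert logging.WARNING == 30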
@@ -68,6 +68,8 @@ if is_torch_available():
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -317,6 +319,18 @@ class BnB8bitBasicTests(Base8bitTests):
         # Check that this does not throw an error
         _ = self.model_fp16.to(torch_device)
 
+    def test_bnb_8bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class Bnb8bitDeviceTests(Base8bitTests):
     def setUp(self) -> None:
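One small difference between the two test hunks: the 4-bit file imports the helper from `diffusers.quantizers.bitsandbytes.utils`, while the 8-bit file imports it from the package `diffusers.quantizers.bitsandbytes`. Since both imports are used in the diff, the package presumably re-exports the function; the sketch below would confirm that (an assumption to verify, not something the diff asserts):

from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear as from_package
from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear as from_utils

# If the package simply re-exports the helper from .utils,
# both import paths name the same function object.
assert from_package is from_utils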