1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Tests] improve quantization tests by additionally measuring the inference memory savings (#11021)

* memory usage tests

* fixes

* gguf
This commit is contained in:
Sayak Paul
2025-03-10 21:42:24 +05:30
committed by GitHub
parent 8eefed65bd
commit e7e6d85282
11 changed files with 136 additions and 105 deletions

View File

@@ -135,6 +135,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
**kwargs,
):
import bitsandbytes as bnb
@@ -445,6 +446,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
**kwargs,
):
import bitsandbytes as bnb

View File

@@ -108,6 +108,7 @@ class GGUFQuantizer(DiffusersQuantizer):
target_device: "torch.device",
state_dict: Optional[Dict[str, Any]] = None,
unexpected_keys: Optional[List[str]] = None,
**kwargs,
):
module, tensor_name = get_module_from_name(model, param_name)
if tensor_name not in module._parameters and tensor_name not in module._buffers:

View File

@@ -215,6 +215,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: List[str],
**kwargs,
):
r"""
Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor,

View File

View File

@@ -54,29 +54,8 @@ if is_transformers_available():
if is_torch_available():
import torch
import torch.nn as nn
class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
Taken from
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
"""
def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
nn.init.zeros_(self.adapter[1].weight)
self.adapter.to(module.weight.device)
def forward(self, input, *args, **kwargs):
return self.module(input, *args, **kwargs) + self.adapter(input)
from ..utils import LoRALayer, get_memory_consumption_stat
if is_bitsandbytes_available():
@@ -96,6 +75,8 @@ class Base4bitTests(unittest.TestCase):
# This was obtained on audace so the number might slightly change
expected_rel_difference = 3.69
expected_memory_saving_ratio = 0.8
prompt = "a beautiful sunset amidst the mountains."
num_inference_steps = 10
seed = 0
@@ -140,8 +121,10 @@ class BnB4BitBasicTests(Base4bitTests):
)
def tearDown(self):
del self.model_fp16
del self.model_4bit
if hasattr(self, "model_fp16"):
del self.model_fp16
if hasattr(self, "model_4bit"):
del self.model_4bit
gc.collect()
torch.cuda.empty_cache()
@@ -180,6 +163,32 @@ class BnB4BitBasicTests(Base4bitTests):
linear = get_some_linear_layer(self.model_4bit)
self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
def test_model_memory_usage(self):
# Delete to not let anything interfere.
del self.model_4bit, self.model_fp16
# Re-instantiate.
inputs = self.get_dummy_inputs()
inputs = {
k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
}
model_fp16 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", torch_dtype=torch.float16
).to(torch_device)
unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
del model_fp16
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16
)
quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs)
assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
def test_original_dtype(self):
r"""
A simple test to check if the model succesfully stores the original dtype

View File

@@ -60,29 +60,8 @@ if is_transformers_available():
if is_torch_available():
import torch
import torch.nn as nn
class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
Taken from
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77
"""
def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
nn.init.zeros_(self.adapter[1].weight)
self.adapter.to(module.weight.device)
def forward(self, input, *args, **kwargs):
return self.module(input, *args, **kwargs) + self.adapter(input)
from ..utils import LoRALayer, get_memory_consumption_stat
if is_bitsandbytes_available():
@@ -102,6 +81,8 @@ class Base8bitTests(unittest.TestCase):
# This was obtained on audace so the number might slightly change
expected_rel_difference = 1.94
expected_memory_saving_ratio = 0.7
prompt = "a beautiful sunset amidst the mountains."
num_inference_steps = 10
seed = 0
@@ -142,8 +123,10 @@ class BnB8bitBasicTests(Base8bitTests):
)
def tearDown(self):
del self.model_fp16
del self.model_8bit
if hasattr(self, "model_fp16"):
del self.model_fp16
if hasattr(self, "model_8bit"):
del self.model_8bit
gc.collect()
torch.cuda.empty_cache()
@@ -182,6 +165,28 @@ class BnB8bitBasicTests(Base8bitTests):
linear = get_some_linear_layer(self.model_8bit)
self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
def test_model_memory_usage(self):
# Delete to not let anything interfere.
del self.model_8bit, self.model_fp16
# Re-instantiate.
inputs = self.get_dummy_inputs()
inputs = {
k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
}
model_fp16 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", torch_dtype=torch.float16
).to(torch_device)
unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
del model_fp16
config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=config, torch_dtype=torch.float16
)
quantized_model_memory = get_memory_consumption_stat(model_8bit, inputs)
assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
def test_original_dtype(self):
r"""
A simple test to check if the model succesfully stores the original dtype
@@ -248,7 +253,7 @@ class BnB8bitBasicTests(Base8bitTests):
self.assertTrue(linear.weight.dtype == torch.int8)
self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt))
self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear))
self.assertTrue(isinstance(model_8bit.proj_out, torch.nn.Linear))
self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8)
def test_config_from_pretrained(self):

View File

View File

@@ -19,29 +19,8 @@ if is_optimum_quanto_available():
if is_torch_available():
import torch
import torch.nn as nn
class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
Taken from
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
"""
def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
nn.init.zeros_(self.adapter[1].weight)
self.adapter.to(module.weight.device)
def forward(self, input, *args, **kwargs):
return self.module(input, *args, **kwargs) + self.adapter(input)
from ..utils import LoRALayer, get_memory_consumption_stat
@nightly
@@ -85,20 +64,20 @@ class QuantoBaseTesterMixin:
assert isinstance(module, QLinear)
def test_quanto_memory_usage(self):
unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
unquantized_model_memory = unquantized_model.get_memory_footprint() / 1024**3
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
inputs = self.get_dummy_inputs()
inputs = {
k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
}
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
unquantized_model.to(torch_device)
unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)
model.to(torch_device)
with torch.no_grad():
model(**inputs)
max_memory = torch.cuda.max_memory_allocated() / 1024**3
assert (1.0 - (max_memory / unquantized_model_memory)) >= self.expected_memory_reduction
quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
quantized_model.to(torch_device)
quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)
assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction
def test_keep_modules_in_fp32(self):
r"""
@@ -318,14 +297,14 @@ class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.3
expected_memory_reduction = 0.6
def get_dummy_init_kwargs(self):
return {"weights_dtype": "float8"}
class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.3
expected_memory_reduction = 0.6
_test_torch_compile = True
def get_dummy_init_kwargs(self):

View File

View File

@@ -50,27 +50,7 @@ if is_torch_available():
import torch
import torch.nn as nn
class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
Taken from
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
"""
def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
nn.init.zeros_(self.adapter[1].weight)
self.adapter.to(module.weight.device)
def forward(self, input, *args, **kwargs):
return self.module(input, *args, **kwargs) + self.adapter(input)
from ..utils import LoRALayer, get_memory_consumption_stat
if is_torchao_available():
@@ -503,6 +483,22 @@ class TorchAoTest(unittest.TestCase):
# there is additional overhead of scales and zero points
self.assertTrue(total_bf16 < total_int4wo)
def test_model_memory_usage(self):
model_id = "hf-internal-testing/tiny-flux-pipe"
expected_memory_saving_ratio = 2.0
inputs = self.get_dummy_tensor_inputs(device=torch_device)
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
transformer_bf16.to(torch_device)
unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
del transformer_bf16
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
transformer_int8wo.to(torch_device)
quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
def test_wrong_config(self):
with self.assertRaises(ValueError):
self.get_dummy_components(TorchAoConfig("int42"))

View File

@@ -0,0 +1,38 @@
from diffusers.utils import is_torch_available
if is_torch_available():
import torch
import torch.nn as nn
class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
Taken from
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
"""
def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
nn.init.zeros_(self.adapter[1].weight)
self.adapter.to(module.weight.device)
def forward(self, input, *args, **kwargs):
return self.module(input, *args, **kwargs) + self.adapter(input)
@torch.no_grad()
@torch.inference_mode()
def get_memory_consumption_stat(model, inputs):
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
model(**inputs)
max_memory_mem_allocated = torch.cuda.max_memory_allocated()
return max_memory_mem_allocated