Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Tests] improve quantization tests by additionally measuring the inference memory savings (#11021)
* memory usage tests
* fixes
* gguf
@@ -135,6 +135,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: Optional[List[str]] = None,
+        **kwargs,
     ):
         import bitsandbytes as bnb

@@ -445,6 +446,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: Optional[List[str]] = None,
+        **kwargs,
     ):
         import bitsandbytes as bnb
@@ -108,6 +108,7 @@ class GGUFQuantizer(DiffusersQuantizer):
         target_device: "torch.device",
         state_dict: Optional[Dict[str, Any]] = None,
         unexpected_keys: Optional[List[str]] = None,
+        **kwargs,
     ):
         module, tensor_name = get_module_from_name(model, param_name)
         if tensor_name not in module._parameters and tensor_name not in module._buffers:
@@ -215,6 +215,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: List[str],
+        **kwargs,
     ):
         r"""
         Each nn.Linear layer that needs to be quantized is processed here. First, we set the value of the weight tensor,
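Each quantizer hunk above appears to add a trailing **kwargs to the create_quantized_param-style signature. Below is a minimal sketch of the effect, under the assumption that the goal is simply to let call sites pass extra keyword arguments that a backend can safely ignore. The class and call are illustrative only, not the diffusers API.

from typing import Any, Dict, List, Optional


class DummyQuantizer:
    # Hypothetical stand-in for a quantizer backend, for illustration only.
    def create_quantized_param(
        self,
        model: Any,
        param_value: Any,
        param_name: str,
        target_device: str,
        state_dict: Optional[Dict[str, Any]] = None,
        unexpected_keys: Optional[List[str]] = None,
        **kwargs,  # extra keyword arguments from the call site are accepted and ignored
    ) -> Any:
        # A real backend would quantize param_value here; this sketch just returns it.
        return param_value


quantizer = DummyQuantizer()
# With **kwargs in place, an extra keyword argument no longer raises a TypeError.
quantizer.create_quantized_param(None, 1.0, "proj.weight", "cpu", some_future_flag=True)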
tests/quantization/__init__.py (new file, 0 lines)
@@ -54,29 +54,8 @@ if is_transformers_available():

 if is_torch_available():
     import torch
-    import torch.nn as nn
-
-
-    class LoRALayer(nn.Module):
-        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
-        Taken from
-        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
-        """
-
-        def __init__(self, module: nn.Module, rank: int):
-            super().__init__()
-            self.module = module
-            self.adapter = nn.Sequential(
-                nn.Linear(module.in_features, rank, bias=False),
-                nn.Linear(rank, module.out_features, bias=False),
-            )
-            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
-            nn.init.normal_(self.adapter[0].weight, std=small_std)
-            nn.init.zeros_(self.adapter[1].weight)
-            self.adapter.to(module.weight.device)
-
-        def forward(self, input, *args, **kwargs):
-            return self.module(input, *args, **kwargs) + self.adapter(input)
+
+from ..utils import LoRALayer, get_memory_consumption_stat


 if is_bitsandbytes_available():
@@ -96,6 +75,8 @@ class Base4bitTests(unittest.TestCase):
     # This was obtained on audace so the number might slightly change
     expected_rel_difference = 3.69
+
+    expected_memory_saving_ratio = 0.8

     prompt = "a beautiful sunset amidst the mountains."
     num_inference_steps = 10
     seed = 0
@@ -140,8 +121,10 @@ class BnB4BitBasicTests(Base4bitTests):
         )

     def tearDown(self):
-        del self.model_fp16
-        del self.model_4bit
+        if hasattr(self, "model_fp16"):
+            del self.model_fp16
+        if hasattr(self, "model_4bit"):
+            del self.model_4bit

         gc.collect()
         torch.cuda.empty_cache()
@@ -180,6 +163,32 @@ class BnB4BitBasicTests(Base4bitTests):
         linear = get_some_linear_layer(self.model_4bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)

+    def test_model_memory_usage(self):
+        # Delete to not let anything interfere.
+        del self.model_4bit, self.model_fp16
+
+        # Re-instantiate.
+        inputs = self.get_dummy_inputs()
+        inputs = {
+            k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
+        }
+        model_fp16 = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", torch_dtype=torch.float16
+        ).to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
+        del model_fp16
+
+        nf4_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        model_4bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16
+        )
+        quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs)
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
+
     def test_original_dtype(self):
         r"""
         A simple test to check if the model successfully stores the original dtype
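For reference, the new assertion treats expected_memory_saving_ratio as a lower bound on the ratio of peak inference memory between the unquantized and the quantized run. A worked example with assumed numbers (illustrative, not measured values):

# Assumed peak inference memory, in GiB (illustrative numbers only).
unquantized_model_memory = 10.0  # fp16 transformer forward pass
quantized_model_memory = 4.0     # NF4 transformer forward pass

expected_memory_saving_ratio = 0.8  # value used by Base4bitTests above

ratio = unquantized_model_memory / quantized_model_memory  # 2.5
assert ratio >= expected_memory_saving_ratio  # passes: 2.5 >= 0.8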
@@ -60,29 +60,8 @@ if is_transformers_available():

 if is_torch_available():
     import torch
-    import torch.nn as nn
-
-
-    class LoRALayer(nn.Module):
-        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
-        Taken from
-        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77
-        """
-
-        def __init__(self, module: nn.Module, rank: int):
-            super().__init__()
-            self.module = module
-            self.adapter = nn.Sequential(
-                nn.Linear(module.in_features, rank, bias=False),
-                nn.Linear(rank, module.out_features, bias=False),
-            )
-            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
-            nn.init.normal_(self.adapter[0].weight, std=small_std)
-            nn.init.zeros_(self.adapter[1].weight)
-            self.adapter.to(module.weight.device)
-
-        def forward(self, input, *args, **kwargs):
-            return self.module(input, *args, **kwargs) + self.adapter(input)
+
+from ..utils import LoRALayer, get_memory_consumption_stat


 if is_bitsandbytes_available():
@@ -102,6 +81,8 @@ class Base8bitTests(unittest.TestCase):
     # This was obtained on audace so the number might slightly change
     expected_rel_difference = 1.94
+
+    expected_memory_saving_ratio = 0.7

     prompt = "a beautiful sunset amidst the mountains."
     num_inference_steps = 10
     seed = 0
@@ -142,8 +123,10 @@ class BnB8bitBasicTests(Base8bitTests):
         )

     def tearDown(self):
-        del self.model_fp16
-        del self.model_8bit
+        if hasattr(self, "model_fp16"):
+            del self.model_fp16
+        if hasattr(self, "model_8bit"):
+            del self.model_8bit

         gc.collect()
         torch.cuda.empty_cache()
@@ -182,6 +165,28 @@ class BnB8bitBasicTests(Base8bitTests):
         linear = get_some_linear_layer(self.model_8bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)

+    def test_model_memory_usage(self):
+        # Delete to not let anything interfere.
+        del self.model_8bit, self.model_fp16
+
+        # Re-instantiate.
+        inputs = self.get_dummy_inputs()
+        inputs = {
+            k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
+        }
+        model_fp16 = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", torch_dtype=torch.float16
+        ).to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
+        del model_fp16
+
+        config = BitsAndBytesConfig(load_in_8bit=True)
+        model_8bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=config, torch_dtype=torch.float16
+        )
+        quantized_model_memory = get_memory_consumption_stat(model_8bit, inputs)
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
+
     def test_original_dtype(self):
         r"""
         A simple test to check if the model successfully stores the original dtype
@@ -248,7 +253,7 @@ class BnB8bitBasicTests(Base8bitTests):
         self.assertTrue(linear.weight.dtype == torch.int8)
         self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt))

-        self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear))
+        self.assertTrue(isinstance(model_8bit.proj_out, torch.nn.Linear))
         self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8)

     def test_config_from_pretrained(self):
tests/quantization/quanto/__init__.py (new file, 0 lines)
@@ -19,29 +19,8 @@ if is_optimum_quanto_available():

 if is_torch_available():
     import torch
-    import torch.nn as nn
-
-
-    class LoRALayer(nn.Module):
-        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
-        Taken from
-        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
-        """
-
-        def __init__(self, module: nn.Module, rank: int):
-            super().__init__()
-            self.module = module
-            self.adapter = nn.Sequential(
-                nn.Linear(module.in_features, rank, bias=False),
-                nn.Linear(rank, module.out_features, bias=False),
-            )
-            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
-            nn.init.normal_(self.adapter[0].weight, std=small_std)
-            nn.init.zeros_(self.adapter[1].weight)
-            self.adapter.to(module.weight.device)
-
-        def forward(self, input, *args, **kwargs):
-            return self.module(input, *args, **kwargs) + self.adapter(input)
+
+from ..utils import LoRALayer, get_memory_consumption_stat


 @nightly
@@ -85,20 +64,20 @@ class QuantoBaseTesterMixin:
         assert isinstance(module, QLinear)

     def test_quanto_memory_usage(self):
-        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
-        unquantized_model_memory = unquantized_model.get_memory_footprint() / 1024**3
-
-        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
         inputs = self.get_dummy_inputs()
         inputs = {
             k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
         }

-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
-
-        model.to(torch_device)
-        with torch.no_grad():
-            model(**inputs)
-        max_memory = torch.cuda.max_memory_allocated() / 1024**3
-        assert (1.0 - (max_memory / unquantized_model_memory)) >= self.expected_memory_reduction
+        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
+        unquantized_model.to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)
+
+        quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+        quantized_model.to(torch_device)
+        quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)
+
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction

     def test_keep_modules_in_fp32(self):
         r"""
@@ -318,14 +297,14 @@ class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):


 class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
-    expected_memory_reduction = 0.3
+    expected_memory_reduction = 0.6

     def get_dummy_init_kwargs(self):
         return {"weights_dtype": "float8"}


 class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
-    expected_memory_reduction = 0.3
+    expected_memory_reduction = 0.6
     _test_torch_compile = True

     def get_dummy_init_kwargs(self):
tests/quantization/torchao/__init__.py (new file, 0 lines)
@@ -50,27 +50,7 @@ if is_torch_available():
     import torch
-    import torch.nn as nn
-
-
-    class LoRALayer(nn.Module):
-        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
-        Taken from
-        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
-        """
-
-        def __init__(self, module: nn.Module, rank: int):
-            super().__init__()
-            self.module = module
-            self.adapter = nn.Sequential(
-                nn.Linear(module.in_features, rank, bias=False),
-                nn.Linear(rank, module.out_features, bias=False),
-            )
-            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
-            nn.init.normal_(self.adapter[0].weight, std=small_std)
-            nn.init.zeros_(self.adapter[1].weight)
-            self.adapter.to(module.weight.device)
-
-        def forward(self, input, *args, **kwargs):
-            return self.module(input, *args, **kwargs) + self.adapter(input)
+
+from ..utils import LoRALayer, get_memory_consumption_stat


 if is_torchao_available():
@@ -503,6 +483,22 @@ class TorchAoTest(unittest.TestCase):
         # there is additional overhead of scales and zero points
         self.assertTrue(total_bf16 < total_int4wo)

+    def test_model_memory_usage(self):
+        model_id = "hf-internal-testing/tiny-flux-pipe"
+        expected_memory_saving_ratio = 2.0
+
+        inputs = self.get_dummy_tensor_inputs(device=torch_device)
+
+        transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
+        transformer_bf16.to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
+        del transformer_bf16
+
+        transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
+        transformer_int8wo.to(torch_device)
+        quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
+        assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
+
     def test_wrong_config(self):
         with self.assertRaises(ValueError):
             self.get_dummy_components(TorchAoConfig("int42"))
tests/quantization/utils.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from diffusers.utils import is_torch_available


if is_torch_available():
    import torch
    import torch.nn as nn

    class LoRALayer(nn.Module):
        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only

        Taken from
        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
        """

        def __init__(self, module: nn.Module, rank: int):
            super().__init__()
            self.module = module
            self.adapter = nn.Sequential(
                nn.Linear(module.in_features, rank, bias=False),
                nn.Linear(rank, module.out_features, bias=False),
            )
            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
            nn.init.normal_(self.adapter[0].weight, std=small_std)
            nn.init.zeros_(self.adapter[1].weight)
            self.adapter.to(module.weight.device)

        def forward(self, input, *args, **kwargs):
            return self.module(input, *args, **kwargs) + self.adapter(input)

    @torch.no_grad()
    @torch.inference_mode()
    def get_memory_consumption_stat(model, inputs):
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        model(**inputs)
        max_memory_mem_allocated = torch.cuda.max_memory_allocated()
        return max_memory_mem_allocated
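A minimal usage sketch of the new helper, assuming a CUDA machine (the helper relies on torch.cuda) and assuming the tests package is importable from the repository root; the toy module and shapes below are illustrative.

import torch
import torch.nn as nn

from tests.quantization.utils import get_memory_consumption_stat  # path as added by this PR

# Toy stand-in for a real transformer: keys of `inputs` must match the forward signature.
model = nn.Linear(1024, 1024).to(device="cuda", dtype=torch.float16)
inputs = {"input": torch.randn(8, 1024, device="cuda", dtype=torch.float16)}

# Peak CUDA memory, in bytes, allocated while running one forward pass.
peak_bytes = get_memory_consumption_stat(model, inputs)
print(f"peak inference memory: {peak_bytes / 1024**2:.1f} MiB")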