From e82001e40dd66e44eede3696414cdccc9f577597 Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 18 Dec 2025 13:16:50 +0530 Subject: [PATCH] update --- tests/models/testing_utils/__init__.py | 38 + tests/models/testing_utils/cache.py | 536 +++++++++++ tests/models/testing_utils/compile.py | 12 +- tests/models/testing_utils/memory.py | 158 ++-- tests/models/testing_utils/quantization.py | 877 +++++++++++++----- tests/models/testing_utils/single_file.py | 33 +- .../test_models_transformer_flux.py | 99 ++ tests/testing_utils.py | 20 +- 8 files changed, 1445 insertions(+), 328 deletions(-) create mode 100644 tests/models/testing_utils/cache.py diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py index 6dfb77c713..ea076b3ec7 100644 --- a/tests/models/testing_utils/__init__.py +++ b/tests/models/testing_utils/__init__.py @@ -1,4 +1,13 @@ from .attention import AttentionTesterMixin +from .cache import ( + CacheTesterMixin, + FasterCacheConfigMixin, + FasterCacheTesterMixin, + FirstBlockCacheConfigMixin, + FirstBlockCacheTesterMixin, + PyramidAttentionBroadcastConfigMixin, + PyramidAttentionBroadcastTesterMixin, +) from .common import BaseModelTesterConfig, ModelTesterMixin from .compile import TorchCompileTesterMixin from .ip_adapter import IPAdapterTesterMixin @@ -6,11 +15,22 @@ from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin from .parallelism import ContextParallelTesterMixin from .quantization import ( + BitsAndBytesCompileTesterMixin, + BitsAndBytesConfigMixin, BitsAndBytesTesterMixin, + GGUFCompileTesterMixin, + GGUFConfigMixin, GGUFTesterMixin, + ModelOptCompileTesterMixin, + ModelOptConfigMixin, ModelOptTesterMixin, + QuantizationCompileTesterMixin, QuantizationTesterMixin, + QuantoCompileTesterMixin, + QuantoConfigMixin, QuantoTesterMixin, + TorchAoCompileTesterMixin, + TorchAoConfigMixin, TorchAoTesterMixin, ) from .single_file import SingleFileTesterMixin @@ -20,9 +40,18 @@ from .training import TrainingTesterMixin __all__ = [ "AttentionTesterMixin", "BaseModelTesterConfig", + "BitsAndBytesCompileTesterMixin", + "BitsAndBytesConfigMixin", "BitsAndBytesTesterMixin", + "CacheTesterMixin", "ContextParallelTesterMixin", "CPUOffloadTesterMixin", + "FasterCacheConfigMixin", + "FasterCacheTesterMixin", + "FirstBlockCacheConfigMixin", + "FirstBlockCacheTesterMixin", + "GGUFCompileTesterMixin", + "GGUFConfigMixin", "GGUFTesterMixin", "GroupOffloadTesterMixin", "IPAdapterTesterMixin", @@ -30,11 +59,20 @@ __all__ = [ "LoraHotSwappingForModelTesterMixin", "LoraTesterMixin", "MemoryTesterMixin", + "ModelOptCompileTesterMixin", + "ModelOptConfigMixin", "ModelOptTesterMixin", "ModelTesterMixin", + "PyramidAttentionBroadcastConfigMixin", + "PyramidAttentionBroadcastTesterMixin", + "QuantizationCompileTesterMixin", "QuantizationTesterMixin", + "QuantoCompileTesterMixin", + "QuantoConfigMixin", "QuantoTesterMixin", "SingleFileTesterMixin", + "TorchAoCompileTesterMixin", + "TorchAoConfigMixin", "TorchAoTesterMixin", "TorchCompileTesterMixin", "TrainingTesterMixin", diff --git a/tests/models/testing_utils/cache.py b/tests/models/testing_utils/cache.py new file mode 100644 index 0000000000..c1b916d34e --- /dev/null +++ b/tests/models/testing_utils/cache.py @@ -0,0 +1,536 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc + +import pytest +import torch + +from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig +from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK +from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK +from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK +from diffusers.models.cache_utils import CacheMixin + +from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device + + +def require_cache_mixin(func): + """Decorator to skip tests if model doesn't use CacheMixin.""" + + def wrapper(self, *args, **kwargs): + if not issubclass(self.model_class, CacheMixin): + pytest.skip(f"{self.model_class.__name__} does not use CacheMixin.") + return func(self, *args, **kwargs) + + return wrapper + + +class CacheTesterMixin: + """ + Base mixin class providing common test implementations for cache testing. + + Cache-specific mixins should: + 1. Inherit from their respective config mixin (e.g., PyramidAttentionBroadcastConfigMixin) + 2. Inherit from this mixin + 3. Define the cache config to use for tests + + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + + Expected methods in test classes: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + """ + + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def _get_cache_config(self): + """ + Get the cache config for testing. + Should be implemented by subclasses. + """ + raise NotImplementedError("Subclass must implement _get_cache_config") + + def _get_hook_names(self): + """ + Get the hook names to check for this cache type. + Should be implemented by subclasses. + Returns a list of hook name strings. + """ + raise NotImplementedError("Subclass must implement _get_hook_names") + + def _test_cache_enable_disable_state(self): + """Test that cache enable/disable updates the is_cache_enabled state correctly.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + # Initially cache should not be enabled + assert not model.is_cache_enabled, "Cache should not be enabled initially." + + config = self._get_cache_config() + + # Enable cache + model.enable_cache(config) + assert model.is_cache_enabled, "Cache should be enabled after enable_cache()." + + # Disable cache + model.disable_cache() + assert not model.is_cache_enabled, "Cache should not be enabled after disable_cache()." 
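+
+    # Illustrative wiring (a sketch; the model and argument names below are hypothetical and not
+    # part of this patch): a concrete test class pairs a cache tester mixin with model-specific
+    # init/inputs, e.g.
+    #
+    #     class MyTransformerPABTests(PyramidAttentionBroadcastTesterMixin):
+    #         model_class = MyTransformer2DModel
+    #
+    #         def get_init_dict(self):
+    #             return {"num_layers": 2, "num_attention_heads": 2, "attention_head_dim": 8}
+    #
+    #         def get_dummy_inputs(self):
+    #             return {
+    #                 "hidden_states": torch.randn(1, 16, 16, device=torch_device),
+    #                 "timestep": torch.tensor([500], device=torch_device),
+    #             }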
+ + def _test_cache_double_enable_raises_error(self): + """Test that enabling cache twice raises an error.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + config = self._get_cache_config() + + model.enable_cache(config) + + # Trying to enable again should raise ValueError + with pytest.raises(ValueError, match="Caching has already been enabled"): + model.enable_cache(config) + + # Cleanup + model.disable_cache() + + def _test_cache_hooks_registered(self): + """Test that cache hooks are properly registered and removed.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + config = self._get_cache_config() + hook_names = self._get_hook_names() + + model.enable_cache(config) + + # Check that at least one hook was registered + hook_count = 0 + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + for hook_name in hook_names: + hook = module._diffusers_hook.get_hook(hook_name) + if hook is not None: + hook_count += 1 + + assert hook_count > 0, f"At least one cache hook should be registered. Hook names: {hook_names}" + + # Disable and verify hooks are removed + model.disable_cache() + + hook_count_after = 0 + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + for hook_name in hook_names: + hook = module._diffusers_hook.get_hook(hook_name) + if hook is not None: + hook_count_after += 1 + + assert hook_count_after == 0, "Cache hooks should be removed after disable_cache()." + + @torch.no_grad() + def _test_cache_inference(self): + """Test that model can run inference with cache enabled.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + config = self._get_cache_config() + + model.enable_cache(config) + + # First pass populates the cache + _ = model(**inputs_dict, return_dict=False)[0] + + # Create modified inputs for second pass (vary hidden_states to simulate denoising) + inputs_dict_step2 = inputs_dict.copy() + if "hidden_states" in inputs_dict_step2: + inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1 + + # Second pass uses cached attention with different hidden_states (produces approximated output) + output_with_cache = model(**inputs_dict_step2, return_dict=False)[0] + + assert output_with_cache is not None, "Model output should not be None with cache enabled." + assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled." + + # Run same inputs without cache to compare + model.disable_cache() + output_without_cache = model(**inputs_dict_step2, return_dict=False)[0] + + # Cached output should be different from non-cached output (due to approximation) + assert not torch.allclose( + output_without_cache, output_with_cache, atol=1e-5 + ), "Cached output should be different from non-cached output due to cache approximation." 
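+
+    # For reference, a typical multi-step run scopes the cache to a named context so per-run
+    # state stays isolated (a sketch; `timesteps` and `latents` are placeholders):
+    #
+    #     with model.cache_context("cond"):
+    #         for t in timesteps:
+    #             noise_pred = model(hidden_states=latents, timestep=t, return_dict=False)[0]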
+ + def _test_cache_context_manager(self): + """Test the cache_context context manager.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + config = self._get_cache_config() + + model.enable_cache(config) + + # Test cache_context works without error + with model.cache_context("test_context"): + pass + + model.disable_cache() + + def _test_reset_stateful_cache(self): + """Test that _reset_stateful_cache resets the cache state.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + config = self._get_cache_config() + + model.enable_cache(config) + + # Run forward to populate cache state + with torch.no_grad(): + _ = model(**inputs_dict, return_dict=False)[0] + + # Reset should not raise any errors + model._reset_stateful_cache() + + model.disable_cache() + + +@is_cache +class PyramidAttentionBroadcastConfigMixin: + """ + Base mixin providing PyramidAttentionBroadcast cache config. + + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + """ + + # Default PAB config - can be overridden by subclasses + PAB_CONFIG = { + "spatial_attention_block_skip_range": 2, + } + + # Store timestep for callback (must be within default range (100, 800) for skipping to trigger) + _current_timestep = 500 + + def _get_cache_config(self): + config_kwargs = self.PAB_CONFIG.copy() + config_kwargs["current_timestep_callback"] = lambda: self._current_timestep + return PyramidAttentionBroadcastConfig(**config_kwargs) + + def _get_hook_names(self): + return [_PYRAMID_ATTENTION_BROADCAST_HOOK] + + +@is_cache +class PyramidAttentionBroadcastTesterMixin(PyramidAttentionBroadcastConfigMixin, CacheTesterMixin): + """ + Mixin class for testing PyramidAttentionBroadcast caching on models. + + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: cache + Use `pytest -m "not cache"` to skip these tests + """ + + @require_cache_mixin + def test_pab_cache_enable_disable_state(self): + self._test_cache_enable_disable_state() + + @require_cache_mixin + def test_pab_cache_double_enable_raises_error(self): + self._test_cache_double_enable_raises_error() + + @require_cache_mixin + def test_pab_cache_hooks_registered(self): + self._test_cache_hooks_registered() + + @require_cache_mixin + def test_pab_cache_inference(self): + self._test_cache_inference() + + @require_cache_mixin + def test_pab_cache_context_manager(self): + self._test_cache_context_manager() + + @require_cache_mixin + def test_pab_reset_stateful_cache(self): + self._test_reset_stateful_cache() + + +@is_cache +class FirstBlockCacheConfigMixin: + """ + Base mixin providing FirstBlockCache config. 
+ + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + """ + + # Default FBC config - can be overridden by subclasses + # Higher threshold makes FBC more aggressive about caching (skips more often) + FBC_CONFIG = { + "threshold": 1.0, + } + + def _get_cache_config(self): + return FirstBlockCacheConfig(**self.FBC_CONFIG) + + def _get_hook_names(self): + return [_FBC_LEADER_BLOCK_HOOK, _FBC_BLOCK_HOOK] + + +@is_cache +class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin): + """ + Mixin class for testing FirstBlockCache on models. + + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: cache + Use `pytest -m "not cache"` to skip these tests + """ + + @torch.no_grad() + def _test_cache_inference(self): + """Test that model can run inference with FBC cache enabled (requires cache_context).""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + config = self._get_cache_config() + model.enable_cache(config) + + # FBC requires cache_context to be set for inference + with model.cache_context("fbc_test"): + # First pass populates the cache + _ = model(**inputs_dict, return_dict=False)[0] + + # Create modified inputs for second pass (small perturbation keeps residuals similar) + inputs_dict_step2 = inputs_dict.copy() + if "hidden_states" in inputs_dict_step2: + inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.01 + + # Second pass - FBC should skip remaining blocks and use cached residuals + output_with_cache = model(**inputs_dict_step2, return_dict=False)[0] + + assert output_with_cache is not None, "Model output should not be None with cache enabled." + assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled." + + # Run same inputs without cache to compare + model.disable_cache() + output_without_cache = model(**inputs_dict_step2, return_dict=False)[0] + + # Cached output should be different from non-cached output (due to approximation) + assert not torch.allclose( + output_without_cache, output_with_cache, atol=1e-5 + ), "Cached output should be different from non-cached output due to cache approximation." 
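+
+    # Subclasses can tune how aggressively FBC skips blocks by overriding the config; a lower
+    # threshold recomputes more often (illustrative value, not part of this patch):
+    #
+    #     class MyTransformerFBCTests(FirstBlockCacheTesterMixin):
+    #         FBC_CONFIG = {"threshold": 0.2}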
+ + def _test_reset_stateful_cache(self): + """Test that _reset_stateful_cache resets the FBC cache state (requires cache_context).""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + config = self._get_cache_config() + model.enable_cache(config) + + # FBC requires cache_context to be set for inference + with model.cache_context("fbc_test"): + with torch.no_grad(): + _ = model(**inputs_dict, return_dict=False)[0] + + # Reset should not raise any errors + model._reset_stateful_cache() + + model.disable_cache() + + @require_cache_mixin + def test_fbc_cache_enable_disable_state(self): + self._test_cache_enable_disable_state() + + @require_cache_mixin + def test_fbc_cache_double_enable_raises_error(self): + self._test_cache_double_enable_raises_error() + + @require_cache_mixin + def test_fbc_cache_hooks_registered(self): + self._test_cache_hooks_registered() + + @require_cache_mixin + def test_fbc_cache_inference(self): + self._test_cache_inference() + + @require_cache_mixin + def test_fbc_cache_context_manager(self): + self._test_cache_context_manager() + + @require_cache_mixin + def test_fbc_reset_stateful_cache(self): + self._test_reset_stateful_cache() + + +@is_cache +class FasterCacheConfigMixin: + """ + Base mixin providing FasterCache config. + + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + """ + + # Default FasterCache config - can be overridden by subclasses + FASTER_CACHE_CONFIG = { + "spatial_attention_block_skip_range": 2, + "spatial_attention_timestep_skip_range": (-1, 901), + "tensor_format": "BCHW", + } + + # Store timestep for callback - use a list so it can be mutated during test + # Starts outside skip range so first pass computes; changed to inside range for subsequent passes + _current_timestep = [1000] + + def _get_cache_config(self): + config_kwargs = self.FASTER_CACHE_CONFIG.copy() + config_kwargs["current_timestep_callback"] = lambda: self._current_timestep[0] + return FasterCacheConfig(**config_kwargs) + + def _get_hook_names(self): + return [_FASTER_CACHE_DENOISER_HOOK, _FASTER_CACHE_BLOCK_HOOK] + + +@is_cache +class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin): + """ + Mixin class for testing FasterCache on models. + + Note: FasterCache is designed for pipeline-level inference with proper CFG batch handling + and timestep management. Inference tests are skipped at model level - FasterCache should + be tested via pipeline tests (e.g., FluxPipeline, HunyuanVideoPipeline). 
+ + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: cache + Use `pytest -m "not cache"` to skip these tests + """ + + @torch.no_grad() + def _test_cache_inference(self): + """Test that model can run inference with FasterCache enabled.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + config = self._get_cache_config() + + model.enable_cache(config) + + # First pass with timestep outside skip range - computes and populates cache + self._current_timestep[0] = 1000 + _ = model(**inputs_dict, return_dict=False)[0] + + # Move timestep inside skip range so subsequent passes use cache + self._current_timestep[0] = 500 + + # Create modified inputs for second pass + inputs_dict_step2 = inputs_dict.copy() + if "hidden_states" in inputs_dict_step2: + inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1 + + # Second pass uses cached attention with different hidden_states + output_with_cache = model(**inputs_dict_step2, return_dict=False)[0] + + assert output_with_cache is not None, "Model output should not be None with cache enabled." + assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled." + + # Run same inputs without cache to compare + model.disable_cache() + output_without_cache = model(**inputs_dict_step2, return_dict=False)[0] + + # Cached output should be different from non-cached output (due to approximation) + assert not torch.allclose( + output_without_cache, output_with_cache, atol=1e-5 + ), "Cached output should be different from non-cached output due to cache approximation." 
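+
+    # Subclasses can override the FasterCache window in the same way (illustrative values); the
+    # timestep callback reads from the mutable `_current_timestep` list so tests can move in and
+    # out of the skip range without rebuilding the config:
+    #
+    #     class MyTransformerFasterCacheTests(FasterCacheTesterMixin):
+    #         FASTER_CACHE_CONFIG = {
+    #             "spatial_attention_block_skip_range": 2,
+    #             "spatial_attention_timestep_skip_range": (-1, 681),
+    #             "tensor_format": "BCHW",
+    #         }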
+ + def _test_reset_stateful_cache(self): + """Test that _reset_stateful_cache resets the FasterCache state.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + config = self._get_cache_config() + model.enable_cache(config) + + # First pass with timestep outside skip range + self._current_timestep[0] = 1000 + with torch.no_grad(): + _ = model(**inputs_dict, return_dict=False)[0] + + # Reset should not raise any errors + model._reset_stateful_cache() + + model.disable_cache() + + @require_cache_mixin + def test_faster_cache_enable_disable_state(self): + self._test_cache_enable_disable_state() + + @require_cache_mixin + def test_faster_cache_double_enable_raises_error(self): + self._test_cache_double_enable_raises_error() + + @require_cache_mixin + def test_faster_cache_hooks_registered(self): + self._test_cache_hooks_registered() + + @require_cache_mixin + def test_faster_cache_inference(self): + self._test_cache_inference() + + @require_cache_mixin + def test_faster_cache_context_manager(self): + self._test_cache_context_manager() + + @require_cache_mixin + def test_faster_cache_reset_stateful_cache(self): + self._test_reset_stateful_cache() diff --git a/tests/models/testing_utils/compile.py b/tests/models/testing_utils/compile.py index 8ff4c097b4..7a479ed5b4 100644 --- a/tests/models/testing_utils/compile.py +++ b/tests/models/testing_utils/compile.py @@ -15,7 +15,6 @@ import gc import os -import tempfile import pytest import torch @@ -140,7 +139,7 @@ class TorchCompileTesterMixin: inputs_dict = self.get_dummy_inputs(height=height, width=width) _ = model(**inputs_dict) - def test_compile_works_with_aot(self): + def test_compile_works_with_aot(self, tmp_path): from torch._inductor.package import load_package init_dict = self.get_init_dict() @@ -149,11 +148,10 @@ class TorchCompileTesterMixin: model = self.model_class(**init_dict).to(torch_device) exported_model = torch.export.export(model, args=(), kwargs=inputs_dict) - with tempfile.TemporaryDirectory() as tmpdir: - package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2") - _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path) - assert os.path.exists(package_path), f"Package file not created at {package_path}" - loaded_binary = load_package(package_path, run_single_threaded=True) + package_path = os.path.join(str(tmp_path), f"{self.model_class.__name__}.pt2") + _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path) + assert os.path.exists(package_path), f"Package file not created at {package_path}" + loaded_binary = load_package(package_path, run_single_threaded=True) model.forward = loaded_binary diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py index 5486bbb0cd..6886ece75b 100644 --- a/tests/models/testing_utils/memory.py +++ b/tests/models/testing_utils/memory.py @@ -16,7 +16,6 @@ import gc import glob import inspect -import tempfile from functools import wraps import pytest @@ -97,7 +96,7 @@ class CPUOffloadTesterMixin: model_split_percents = [0.5, 0.7] @require_offload_support - def test_cpu_offload(self): + def test_cpu_offload(self, tmp_path): config = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**config).eval() @@ -110,25 +109,24 @@ class CPUOffloadTesterMixin: model_size = compute_module_sizes(model)[""] # We test several splits of sizes to make sure it works 
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] - with tempfile.TemporaryDirectory() as tmp_dir: - model.cpu().save_pretrained(tmp_dir) + model.cpu().save_pretrained(str(tmp_path)) - for max_size in max_gpu_sizes: - max_memory = {0: max_size, "cpu": model_size * 2} - new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) - # Making sure part of the model will actually end up offloaded - assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU" + for max_size in max_gpu_sizes: + max_memory = {0: max_size, "cpu": model_size * 2} + new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory) + # Making sure part of the model will actually end up offloaded + assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU" - check_device_map_is_respected(new_model, new_model.hf_device_map) - torch.manual_seed(0) - new_output = new_model(**inputs_dict) + check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) - assert_tensors_close( - base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading" - ) + assert_tensors_close( + base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading" + ) @require_offload_support - def test_disk_offload_without_safetensors(self): + def test_disk_offload_without_safetensors(self, tmp_path): config = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**config).eval() @@ -143,26 +141,25 @@ class CPUOffloadTesterMixin: # Force disk offload by setting very small CPU memory max_memory = {0: max_size, "cpu": int(0.1 * max_size)} - with tempfile.TemporaryDirectory() as tmp_dir: - model.cpu().save_pretrained(tmp_dir, safe_serialization=False) - # This errors out because it's missing an offload folder - with pytest.raises(ValueError): - new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + model.cpu().save_pretrained(str(tmp_path), safe_serialization=False) + # This errors out because it's missing an offload folder + with pytest.raises(ValueError): + new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory) - new_model = self.model_class.from_pretrained( - tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir - ) + new_model = self.model_class.from_pretrained( + str(tmp_path), device_map="auto", max_memory=max_memory, offload_folder=str(tmp_path) + ) - check_device_map_is_respected(new_model, new_model.hf_device_map) - torch.manual_seed(0) - new_output = new_model(**inputs_dict) + check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) - assert_tensors_close( - base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading" - ) + assert_tensors_close( + base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading" + ) @require_offload_support - def test_disk_offload_with_safetensors(self): + def test_disk_offload_with_safetensors(self, tmp_path): config = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**config).eval() @@ -173,26 +170,25 @@ class CPUOffloadTesterMixin: base_output = model(**inputs_dict) model_size = 
compute_module_sizes(model)[""] - with tempfile.TemporaryDirectory() as tmp_dir: - model.cpu().save_pretrained(tmp_dir) + model.cpu().save_pretrained(str(tmp_path)) - max_size = int(self.model_split_percents[0] * model_size) - max_memory = {0: max_size, "cpu": max_size} - new_model = self.model_class.from_pretrained( - tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory - ) + max_size = int(self.model_split_percents[0] * model_size) + max_memory = {0: max_size, "cpu": max_size} + new_model = self.model_class.from_pretrained( + str(tmp_path), device_map="auto", offload_folder=str(tmp_path), max_memory=max_memory + ) - check_device_map_is_respected(new_model, new_model.hf_device_map) - torch.manual_seed(0) - new_output = new_model(**inputs_dict) + check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) - assert_tensors_close( - base_output[0], - new_output[0], - atol=1e-5, - rtol=0, - msg="Output should match with disk offloading (safetensors)", - ) + assert_tensors_close( + base_output[0], + new_output[0], + atol=1e-5, + rtol=0, + msg="Output should match with disk offloading (safetensors)", + ) @is_group_offload @@ -312,7 +308,7 @@ class GroupOffloadTesterMixin: @require_group_offload_support @torch.no_grad() @torch.inference_mode() - def test_group_offloading_with_disk(self, offload_type="block_level", record_stream=False, atol=1e-5): + def test_group_offloading_with_disk(self, tmp_path, offload_type="block_level", record_stream=False, atol=1e-5): def _has_generator_arg(model): sig = inspect.signature(model.forward) params = sig.parameters @@ -340,41 +336,41 @@ class GroupOffloadTesterMixin: num_blocks_per_group = None if offload_type == "leaf_level" else 1 additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group} - with tempfile.TemporaryDirectory() as tmpdir: - model.enable_group_offload( - torch_device, - offload_type=offload_type, + tmpdir = str(tmp_path) + model.enable_group_offload( + torch_device, + offload_type=offload_type, + offload_to_disk_path=tmpdir, + use_stream=True, + record_stream=record_stream, + **additional_kwargs, + ) + has_safetensors = glob.glob(f"{tmpdir}/*.safetensors") + assert has_safetensors, "No safetensors found in the directory." + + # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic + # in nature. So, skip it. + if offload_type != "leaf_level": + is_correct, extra_files, missing_files = _check_safetensors_serialization( + module=model, offload_to_disk_path=tmpdir, - use_stream=True, - record_stream=record_stream, - **additional_kwargs, + offload_type=offload_type, + num_blocks_per_group=num_blocks_per_group, ) - has_safetensors = glob.glob(f"{tmpdir}/*.safetensors") - assert has_safetensors, "No safetensors found in the directory." + if not is_correct: + if extra_files: + raise ValueError(f"Found extra files: {', '.join(extra_files)}") + elif missing_files: + raise ValueError(f"Following files are missing: {', '.join(missing_files)}") - # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic - # in nature. So, skip it. 
- if offload_type != "leaf_level": - is_correct, extra_files, missing_files = _check_safetensors_serialization( - module=model, - offload_to_disk_path=tmpdir, - offload_type=offload_type, - num_blocks_per_group=num_blocks_per_group, - ) - if not is_correct: - if extra_files: - raise ValueError(f"Found extra files: {', '.join(extra_files)}") - elif missing_files: - raise ValueError(f"Following files are missing: {', '.join(missing_files)}") - - output_with_group_offloading = _run_forward(model, inputs_dict) - assert_tensors_close( - output_without_group_offloading, - output_with_group_offloading, - atol=atol, - rtol=0, - msg="Output should match with disk-based group offloading", - ) + output_with_group_offloading = _run_forward(model, inputs_dict) + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading, + atol=atol, + rtol=0, + msg="Output should match with disk-based group offloading", + ) class LayerwiseCastingTesterMixin: diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 26904e8cf9..4d3bde21b7 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -14,7 +14,6 @@ # limitations under the License. import gc -import tempfile import pytest import torch @@ -34,7 +33,9 @@ from ...testing_utils import ( is_bitsandbytes, is_gguf, is_modelopt, + is_quantization, is_quanto, + is_torch_compile, is_torchao, require_accelerate, require_accelerator, @@ -64,6 +65,29 @@ if is_torchao_available(): pass +class LoRALayer(torch.nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only. + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 + """ + + def __init__(self, module: torch.nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = torch.nn.Sequential( + torch.nn.Linear(module.in_features, rank, bias=False), + torch.nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + torch.nn.init.normal_(self.adapter[0].weight, std=small_std) + torch.nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) + + @require_accelerator class QuantizationTesterMixin: """ @@ -128,9 +152,9 @@ class QuantizationTesterMixin: model_quantized = self._create_quantized_model(config_kwargs) num_params_quantized = model_quantized.num_parameters() - assert num_params == num_params_quantized, ( - f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}" - ) + assert ( + num_params == num_params_quantized + ), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}" def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2): model = self._load_unquantized_model() @@ -140,9 +164,9 @@ class QuantizationTesterMixin: mem_quantized = model_quantized.get_memory_footprint() ratio = mem / mem_quantized - assert ratio >= expected_memory_reduction, ( - f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}" - ) + assert ( + ratio >= expected_memory_reduction + ), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). 
unquantized={mem}, quantized={mem_quantized}" def _test_quantization_inference(self, config_kwargs): model_quantized = self._create_quantized_model(config_kwargs) @@ -200,18 +224,17 @@ class QuantizationTesterMixin: assert output is not None, "Model output is None with LoRA" assert not torch.isnan(output).any(), "Model output contains NaN with LoRA" - def _test_quantization_serialization(self, config_kwargs): + def _test_quantization_serialization(self, config_kwargs, tmp_path): model = self._create_quantized_model(config_kwargs) - with tempfile.TemporaryDirectory() as tmpdir: - model.save_pretrained(tmpdir, safe_serialization=True) + model.save_pretrained(str(tmp_path), safe_serialization=True) - model_loaded = self.model_class.from_pretrained(tmpdir) + model_loaded = self.model_class.from_pretrained(str(tmp_path)) - with torch.no_grad(): - inputs = self.get_dummy_inputs() - output = model_loaded(**inputs, return_dict=False)[0] - assert not torch.isnan(output).any(), "Loaded model output contains NaN" + with torch.no_grad(): + inputs = self.get_dummy_inputs() + output = model_loaded(**inputs, return_dict=False)[0] + assert not torch.isnan(output).any(), "Loaded model output contains NaN" def _test_quantized_layers(self, config_kwargs): model_fp = self._load_unquantized_model() @@ -237,12 +260,12 @@ class QuantizationTesterMixin: self._verify_if_layer_quantized(name, module, config_kwargs) num_quantized_layers += 1 - assert num_quantized_layers > 0, ( - f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" - ) - assert num_quantized_layers == expected_quantized_layers, ( - f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" - ) + assert ( + num_quantized_layers > 0 + ), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" + assert ( + num_quantized_layers == expected_quantized_layers + ), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert): """ @@ -266,9 +289,9 @@ class QuantizationTesterMixin: if any(excluded in name for excluded in modules_to_not_convert): found_excluded = True # This module should NOT be quantized - assert not self._is_module_quantized(module), ( - f"Module {name} should not be quantized but was found to be quantized" - ) + assert not self._is_module_quantized( + module + ), f"Module {name} should not be quantized but was found to be quantized" assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}" @@ -290,9 +313,9 @@ class QuantizationTesterMixin: mem_with_exclusion = model_with_exclusion.get_memory_footprint() mem_fully_quantized = model_fully_quantized.get_memory_footprint() - assert mem_with_exclusion > mem_fully_quantized, ( - f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" - ) + assert ( + mem_with_exclusion > mem_fully_quantized + ), f"Model with exclusions should be larger. 
With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" def _test_quantization_device_map(self, config_kwargs): """ @@ -342,32 +365,75 @@ class QuantizationTesterMixin: assert output is not None, "Model output is None after dequantization" assert not torch.isnan(output).any(), "Model output contains NaN after dequantization" + def _test_quantization_training(self, config_kwargs): + """ + Test that quantized models can be used for training with LoRA-like adapters. + This test: + 1. Freezes all model parameters + 2. Casts small parameters (e.g., layernorm) to fp32 for stability + 3. Adds LoRA adapters to attention layers + 4. Runs forward and backward passes + 5. Verifies gradients are computed correctly + + Args: + config_kwargs: Quantization config parameters + """ + model = self._create_quantized_model(config_kwargs) + + # Step 1: freeze all parameters + for param in model.parameters(): + param.requires_grad = False + if param.ndim == 1: + # cast small parameters (e.g. layernorm) to fp32 for stability + param.data = param.data.to(torch.float32) + + # Step 2: add adapters to attention layers + adapter_count = 0 + for _, module in model.named_modules(): + if "Attention" in repr(type(module)): + if hasattr(module, "to_k"): + module.to_k = LoRALayer(module.to_k, rank=4) + adapter_count += 1 + if hasattr(module, "to_q"): + module.to_q = LoRALayer(module.to_q, rank=4) + adapter_count += 1 + if hasattr(module, "to_v"): + module.to_v = LoRALayer(module.to_v, rank=4) + adapter_count += 1 + + if adapter_count == 0: + pytest.skip("No attention layers found in model for adapter training test") + + # Step 3: run forward and backward pass + inputs = self.get_dummy_inputs() + + with torch.amp.autocast(torch_device, dtype=torch.float16): + out = model(**inputs, return_dict=False)[0] + out.norm().backward() + + # Step 4: verify gradients are computed + for module in model.modules(): + if isinstance(module, LoRALayer): + assert module.adapter[1].weight.grad is not None, "LoRA adapter gradient is None" + assert module.adapter[1].weight.grad.norm().item() > 0, "LoRA adapter gradient norm is zero" + + +@is_quantization @is_bitsandbytes @require_accelerator @require_bitsandbytes_version_greater("0.43.2") @require_accelerate -class BitsAndBytesTesterMixin(QuantizationTesterMixin): +class BitsAndBytesConfigMixin: """ - Mixin class for testing BitsAndBytes quantization on models. + Base mixin providing BitsAndBytes quantization config and model creation. 
Expected class attributes: - model_class: The model class to test - pretrained_model_name_or_path: Hub repository ID for the pretrained model - - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) - - Expected methods to be implemented by subclasses: - - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass - - Optional class attributes: - - BNB_CONFIGS: Dict of config name -> BitsAndBytesConfig kwargs to test - - Pytest mark: bitsandbytes - Use `pytest -m "not bitsandbytes"` to skip these tests + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained """ - # Standard BnB configs tested for all models - # Subclasses can override to add or modify configs BNB_CONFIGS = { "4bit_nf4": { "load_in_4bit": True, @@ -399,42 +465,88 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin): def _verify_if_layer_quantized(self, name, module, config_kwargs): expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params - assert module.weight.__class__ == expected_weight_class, ( - f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" + assert ( + module.weight.__class__ == expected_weight_class + ), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" + + +@is_bitsandbytes +@require_accelerator +@require_bitsandbytes_version_greater("0.43.2") +@require_accelerate +class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin): + """ + Mixin class for testing BitsAndBytes quantization on models. + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Optional class attributes: + - BNB_CONFIGS: Dict of config name -> BitsAndBytesConfig kwargs to test + + Pytest mark: bitsandbytes + Use `pytest -m "not bitsandbytes"` to skip these tests + """ + + @pytest.mark.parametrize( + "config_name", + list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ) + def test_bnb_quantization_num_parameters(self, config_name): + self._test_quantization_num_parameters(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize( + "config_name", + list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ) + def test_bnb_quantization_memory_footprint(self, config_name): + expected = BitsAndBytesConfigMixin.BNB_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2) + self._test_quantization_memory_footprint( + BitsAndBytesConfigMixin.BNB_CONFIGS[config_name], expected_memory_reduction=expected ) - @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys())) - def test_bnb_quantization_num_parameters(self, config_name): - self._test_quantization_num_parameters(self.BNB_CONFIGS[config_name]) - - @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys())) - def test_bnb_quantization_memory_footprint(self, config_name): - expected = self.BNB_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2) - 
self._test_quantization_memory_footprint(self.BNB_CONFIGS[config_name], expected_memory_reduction=expected) - - @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys())) + @pytest.mark.parametrize( + "config_name", + list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ) def test_bnb_quantization_inference(self, config_name): - self._test_quantization_inference(self.BNB_CONFIGS[config_name]) + self._test_quantization_inference(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) def test_bnb_quantization_dtype_assignment(self, config_name): - self._test_quantization_dtype_assignment(self.BNB_CONFIGS[config_name]) + self._test_quantization_dtype_assignment(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) def test_bnb_quantization_lora_inference(self, config_name): - self._test_quantization_lora_inference(self.BNB_CONFIGS[config_name]) + self._test_quantization_lora_inference(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) - def test_bnb_quantization_serialization(self, config_name): - self._test_quantization_serialization(self.BNB_CONFIGS[config_name]) + def test_bnb_quantization_serialization(self, config_name, tmp_path): + self._test_quantization_serialization(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name], tmp_path) - @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys())) + @pytest.mark.parametrize( + "config_name", + list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ) def test_bnb_quantized_layers(self, config_name): - self._test_quantized_layers(self.BNB_CONFIGS[config_name]) + self._test_quantized_layers(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) - @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys())) + @pytest.mark.parametrize( + "config_name", + list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ) def test_bnb_quantization_config_serialization(self, config_name): - model = self._create_quantized_model(self.BNB_CONFIGS[config_name]) + model = self._create_quantized_model(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) assert "quantization_config" in model.config, "Missing quantization_config" _ = model.config["quantization_config"].to_dict() @@ -442,8 +554,8 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin): _ = model.config["quantization_config"].to_json_string() def test_bnb_original_dtype(self): - config_name = list(self.BNB_CONFIGS.keys())[0] - config_kwargs = self.BNB_CONFIGS[config_name] + config_name = list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys())[0] + config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS[config_name] model = self._create_quantized_model(config_kwargs) @@ -458,7 +570,7 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin): if not hasattr(self.model_class, "_keep_in_fp32_modules"): pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules") - config_kwargs = self.BNB_CONFIGS["4bit_nf4"] + config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"] original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None) self.model_class._keep_in_fp32_modules = ["proj_out"] @@ -469,13 +581,13 @@ class 
BitsAndBytesTesterMixin(QuantizationTesterMixin): for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear): if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules): - assert module.weight.dtype == torch.float32, ( - f"Module {name} should be FP32 but is {module.weight.dtype}" - ) + assert ( + module.weight.dtype == torch.float32 + ), f"Module {name} should be FP32 but is {module.weight.dtype}" else: - assert module.weight.dtype == torch.uint8, ( - f"Module {name} should be uint8 but is {module.weight.dtype}" - ) + assert ( + module.weight.dtype == torch.uint8 + ), f"Module {name} should be uint8 but is {module.weight.dtype}" with torch.no_grad(): inputs = self.get_dummy_inputs() @@ -490,39 +602,37 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin): if modules_to_exclude is None: pytest.skip("modules_to_not_convert_for_test not defined for this model") - self._test_quantization_modules_to_not_convert(self.BNB_CONFIGS["4bit_nf4"], modules_to_exclude) + self._test_quantization_modules_to_not_convert( + BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude + ) @pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"]) def test_bnb_device_map(self, config_name): """Test that device_map='auto' works correctly with quantization.""" - self._test_quantization_device_map(self.BNB_CONFIGS[config_name]) + self._test_quantization_device_map(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) def test_bnb_dequantize(self): """Test that dequantize() works correctly.""" - self._test_dequantize(self.BNB_CONFIGS["4bit_nf4"]) + self._test_dequantize(BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]) + + def test_bnb_training(self): + """Test that quantized models can be used for training with adapters.""" + self._test_quantization_training(BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]) +@is_quantization @is_quanto @require_quanto @require_accelerate @require_accelerator -class QuantoTesterMixin(QuantizationTesterMixin): +class QuantoConfigMixin: """ - Mixin class for testing Quanto quantization on models. + Base mixin providing Quanto quantization config and model creation. 
Expected class attributes: - model_class: The model class to test - pretrained_model_name_or_path: Hub repository ID for the pretrained model - - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) - - Expected methods to be implemented by subclasses: - - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass - - Optional class attributes: - - QUANTO_WEIGHT_TYPES: Dict of weight_type_name -> qtype - - Pytest mark: quanto - Use `pytest -m "not quanto"` to skip these tests + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained """ QUANTO_WEIGHT_TYPES = { @@ -549,62 +659,14 @@ class QuantoTesterMixin(QuantizationTesterMixin): def _verify_if_layer_quantized(self, name, module, config_kwargs): assert isinstance(module, QLinear), f"Layer {name} is not QLinear, got {type(module)}" - @pytest.mark.parametrize( - "weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()), ids=list(QUANTO_WEIGHT_TYPES.keys()) - ) - def test_quanto_quantization_num_parameters(self, weight_type_name): - self._test_quantization_num_parameters(self.QUANTO_WEIGHT_TYPES[weight_type_name]) - @pytest.mark.parametrize( - "weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()), ids=list(QUANTO_WEIGHT_TYPES.keys()) - ) - def test_quanto_quantization_memory_footprint(self, weight_type_name): - expected = self.QUANTO_EXPECTED_MEMORY_REDUCTIONS.get(weight_type_name, 1.2) - self._test_quantization_memory_footprint( - self.QUANTO_WEIGHT_TYPES[weight_type_name], expected_memory_reduction=expected - ) - - @pytest.mark.parametrize( - "weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()), ids=list(QUANTO_WEIGHT_TYPES.keys()) - ) - def test_quanto_quantization_inference(self, weight_type_name): - self._test_quantization_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name]) - - @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) - def test_quanto_quantized_layers(self, weight_type_name): - self._test_quantized_layers(self.QUANTO_WEIGHT_TYPES[weight_type_name]) - - @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) - def test_quanto_quantization_lora_inference(self, weight_type_name): - self._test_quantization_lora_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name]) - - @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) - def test_quanto_quantization_serialization(self, weight_type_name): - self._test_quantization_serialization(self.QUANTO_WEIGHT_TYPES[weight_type_name]) - - def test_quanto_modules_to_not_convert(self): - """Test that modules_to_not_convert parameter works correctly.""" - modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) - if modules_to_exclude is None: - pytest.skip("modules_to_not_convert_for_test not defined for this model") - - self._test_quantization_modules_to_not_convert(self.QUANTO_WEIGHT_TYPES["int8"], modules_to_exclude) - - def test_quanto_device_map(self): - """Test that device_map='auto' works correctly with quantization.""" - self._test_quantization_device_map(self.QUANTO_WEIGHT_TYPES["int8"]) - - def test_quanto_dequantize(self): - """Test that dequantize() works correctly.""" - self._test_dequantize(self.QUANTO_WEIGHT_TYPES["int8"]) - - -@is_torchao +@is_quanto +@require_quanto +@require_accelerate @require_accelerator -@require_torchao_version_greater_or_equal("0.7.0") -class TorchAoTesterMixin(QuantizationTesterMixin): +class QuantoTesterMixin(QuantoConfigMixin, QuantizationTesterMixin): """ - Mixin class for testing TorchAO 
quantization on models. + Mixin class for testing Quanto quantization on models. Expected class attributes: - model_class: The model class to test @@ -615,10 +677,82 @@ class TorchAoTesterMixin(QuantizationTesterMixin): - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass Optional class attributes: - - TORCHAO_QUANT_TYPES: Dict of quantization type strings to test + - QUANTO_WEIGHT_TYPES: Dict of weight_type_name -> qtype - Pytest mark: torchao - Use `pytest -m "not torchao"` to skip these tests + Pytest mark: quanto + Use `pytest -m "not quanto"` to skip these tests + """ + + @pytest.mark.parametrize( + "weight_type_name", + list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()), + ids=list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()), + ) + def test_quanto_quantization_num_parameters(self, weight_type_name): + self._test_quantization_num_parameters(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize( + "weight_type_name", + list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()), + ids=list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()), + ) + def test_quanto_quantization_memory_footprint(self, weight_type_name): + expected = QuantoConfigMixin.QUANTO_EXPECTED_MEMORY_REDUCTIONS.get(weight_type_name, 1.2) + self._test_quantization_memory_footprint( + QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name], expected_memory_reduction=expected + ) + + @pytest.mark.parametrize( + "weight_type_name", + list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()), + ids=list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()), + ) + def test_quanto_quantization_inference(self, weight_type_name): + self._test_quantization_inference(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) + def test_quanto_quantized_layers(self, weight_type_name): + self._test_quantized_layers(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) + def test_quanto_quantization_lora_inference(self, weight_type_name): + self._test_quantization_lora_inference(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) + def test_quanto_quantization_serialization(self, weight_type_name, tmp_path): + self._test_quantization_serialization(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name], tmp_path) + + def test_quanto_modules_to_not_convert(self): + """Test that modules_to_not_convert parameter works correctly.""" + modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) + if modules_to_exclude is None: + pytest.skip("modules_to_not_convert_for_test not defined for this model") + + self._test_quantization_modules_to_not_convert( + QuantoConfigMixin.QUANTO_WEIGHT_TYPES["int8"], modules_to_exclude + ) + + def test_quanto_device_map(self): + """Test that device_map='auto' works correctly with quantization.""" + self._test_quantization_device_map(QuantoConfigMixin.QUANTO_WEIGHT_TYPES["int8"]) + + def test_quanto_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(QuantoConfigMixin.QUANTO_WEIGHT_TYPES["int8"]) + + +@is_quantization +@is_torchao +@require_accelerator +@require_torchao_version_greater_or_equal("0.7.0") +class TorchAoConfigMixin: + """ + Base mixin providing TorchAO quantization config and model creation. 
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained """ TORCHAO_QUANT_TYPES = { @@ -643,69 +777,103 @@ class TorchAoTesterMixin(QuantizationTesterMixin): def _verify_if_layer_quantized(self, name, module, config_kwargs): assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}" - @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()), ids=list(TORCHAO_QUANT_TYPES.keys())) - def test_torchao_quantization_num_parameters(self, quant_type): - self._test_quantization_num_parameters(self.TORCHAO_QUANT_TYPES[quant_type]) - @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()), ids=list(TORCHAO_QUANT_TYPES.keys())) - def test_torchao_quantization_memory_footprint(self, quant_type): - expected = self.TORCHAO_EXPECTED_MEMORY_REDUCTIONS.get(quant_type, 1.2) - self._test_quantization_memory_footprint( - self.TORCHAO_QUANT_TYPES[quant_type], expected_memory_reduction=expected - ) - - @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()), ids=list(TORCHAO_QUANT_TYPES.keys())) - def test_torchao_quantization_inference(self, quant_type): - self._test_quantization_inference(self.TORCHAO_QUANT_TYPES[quant_type]) - - @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) - def test_torchao_quantized_layers(self, quant_type): - self._test_quantized_layers(self.TORCHAO_QUANT_TYPES[quant_type]) - - @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) - def test_torchao_quantization_lora_inference(self, quant_type): - self._test_quantization_lora_inference(self.TORCHAO_QUANT_TYPES[quant_type]) - - @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) - def test_torchao_quantization_serialization(self, quant_type): - self._test_quantization_serialization(self.TORCHAO_QUANT_TYPES[quant_type]) - - def test_torchao_modules_to_not_convert(self): - """Test that modules_to_not_convert parameter works correctly.""" - # Get a module name that exists in the model - this needs to be set by test classes - # For now, use a generic pattern that should work with transformer models - modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) - if modules_to_exclude is None: - pytest.skip("modules_to_not_convert_for_test not defined for this model") - - self._test_quantization_modules_to_not_convert(self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude) - - def test_torchao_device_map(self): - """Test that device_map='auto' works correctly with quantization.""" - self._test_quantization_device_map(self.TORCHAO_QUANT_TYPES["int8wo"]) - - def test_torchao_dequantize(self): - """Test that dequantize() works correctly.""" - self._test_dequantize(self.TORCHAO_QUANT_TYPES["int8wo"]) - - -@is_gguf -@require_accelerate +@is_torchao @require_accelerator -@require_gguf_version_greater_or_equal("0.10.0") -class GGUFTesterMixin(QuantizationTesterMixin): +@require_torchao_version_greater_or_equal("0.7.0") +class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin): """ - Mixin class for testing GGUF quantization on models. + Mixin class for testing TorchAO quantization on models. 
Expected class attributes: - model_class: The model class to test - - gguf_filename: URL or path to the GGUF file + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) Expected methods to be implemented by subclasses: - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass - Pytest mark: gguf - Use `pytest -m "not gguf"` to skip these tests + Optional class attributes: + - TORCHAO_QUANT_TYPES: Dict of quantization type strings to test + + Pytest mark: torchao + Use `pytest -m "not torchao"` to skip these tests + """ + + @pytest.mark.parametrize( + "quant_type", + list(TorchAoConfigMixin.TORCHAO_QUANT_TYPES.keys()), + ids=list(TorchAoConfigMixin.TORCHAO_QUANT_TYPES.keys()), + ) + def test_torchao_quantization_num_parameters(self, quant_type): + self._test_quantization_num_parameters(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize( + "quant_type", + list(TorchAoConfigMixin.TORCHAO_QUANT_TYPES.keys()), + ids=list(TorchAoConfigMixin.TORCHAO_QUANT_TYPES.keys()), + ) + def test_torchao_quantization_memory_footprint(self, quant_type): + expected = TorchAoConfigMixin.TORCHAO_EXPECTED_MEMORY_REDUCTIONS.get(quant_type, 1.2) + self._test_quantization_memory_footprint( + TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type], expected_memory_reduction=expected + ) + + @pytest.mark.parametrize( + "quant_type", + list(TorchAoConfigMixin.TORCHAO_QUANT_TYPES.keys()), + ids=list(TorchAoConfigMixin.TORCHAO_QUANT_TYPES.keys()), + ) + def test_torchao_quantization_inference(self, quant_type): + self._test_quantization_inference(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) + def test_torchao_quantized_layers(self, quant_type): + self._test_quantized_layers(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) + def test_torchao_quantization_lora_inference(self, quant_type): + self._test_quantization_lora_inference(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) + def test_torchao_quantization_serialization(self, quant_type, tmp_path): + self._test_quantization_serialization(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type], tmp_path) + + def test_torchao_modules_to_not_convert(self): + """Test that modules_to_not_convert parameter works correctly.""" + modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) + if modules_to_exclude is None: + pytest.skip("modules_to_not_convert_for_test not defined for this model") + + self._test_quantization_modules_to_not_convert( + TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude + ) + + def test_torchao_device_map(self): + """Test that device_map='auto' works correctly with quantization.""" + self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) + + def test_torchao_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) + + def test_torchao_training(self): + """Test that quantized models can be used for training with adapters.""" + self._test_quantization_training(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) + + +@is_quantization +@is_gguf +@require_accelerate +@require_accelerator 
+@require_gguf_version_greater_or_equal("0.10.0") +class GGUFConfigMixin: + """ + Base mixin providing GGUF quantization config and model creation. + + Expected class attributes: + - model_class: The model class to test + - gguf_filename: URL or path to the GGUF file """ gguf_filename = None @@ -729,6 +897,26 @@ class GGUFTesterMixin(QuantizationTesterMixin): assert hasattr(module.weight, "quant_type"), f"{name} weight missing quant_type" assert module.weight.dtype == torch.uint8, f"{name} weight dtype should be uint8" + +@is_gguf +@require_accelerate +@require_accelerator +@require_gguf_version_greater_or_equal("0.10.0") +class GGUFTesterMixin(GGUFConfigMixin, QuantizationTesterMixin): + """ + Mixin class for testing GGUF quantization on models. + + Expected class attributes: + - model_class: The model class to test + - gguf_filename: URL or path to the GGUF file + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: gguf + Use `pytest -m "not gguf"` to skip these tests + """ + def test_gguf_quantization_inference(self): self._test_quantization_inference({"compute_dtype": torch.bfloat16}) @@ -763,27 +951,19 @@ class GGUFTesterMixin(QuantizationTesterMixin): self._test_quantized_layers({"compute_dtype": torch.bfloat16}) +@is_quantization @is_modelopt @require_accelerator @require_accelerate @require_modelopt_version_greater_or_equal("0.33.1") -class ModelOptTesterMixin(QuantizationTesterMixin): +class ModelOptConfigMixin: """ - Mixin class for testing NVIDIA ModelOpt quantization on models. + Base mixin providing NVIDIA ModelOpt quantization config and model creation. Expected class attributes: - model_class: The model class to test - pretrained_model_name_or_path: Hub repository ID for the pretrained model - - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) - - Expected methods to be implemented by subclasses: - - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass - - Optional class attributes: - - MODELOPT_CONFIGS: Dict of config name -> NVIDIAModelOptConfig kwargs to test - - Pytest mark: modelopt - Use `pytest -m "not modelopt"` to skip these tests + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained """ MODELOPT_CONFIGS = { @@ -808,36 +988,68 @@ class ModelOptTesterMixin(QuantizationTesterMixin): def _verify_if_layer_quantized(self, name, module, config_kwargs): assert mtq.utils.is_quantized(module), f"Layer {name} does not have weight_quantizer attribute (not quantized)" + +@is_modelopt +@require_accelerator +@require_accelerate +@require_modelopt_version_greater_or_equal("0.33.1") +class ModelOptTesterMixin(ModelOptConfigMixin, QuantizationTesterMixin): + """ + Mixin class for testing NVIDIA ModelOpt quantization on models. 
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Optional class attributes: + - MODELOPT_CONFIGS: Dict of config name -> NVIDIAModelOptConfig kwargs to test + + Pytest mark: modelopt + Use `pytest -m "not modelopt"` to skip these tests + """ + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) def test_modelopt_quantization_num_parameters(self, config_name): - self._test_quantization_num_parameters(self.MODELOPT_CONFIGS[config_name]) + self._test_quantization_num_parameters(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) - @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()), ids=list(MODELOPT_CONFIGS.keys())) + @pytest.mark.parametrize( + "config_name", + list(ModelOptConfigMixin.MODELOPT_CONFIGS.keys()), + ids=list(ModelOptConfigMixin.MODELOPT_CONFIGS.keys()), + ) def test_modelopt_quantization_memory_footprint(self, config_name): - expected = self.MODELOPT_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2) + expected = ModelOptConfigMixin.MODELOPT_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2) self._test_quantization_memory_footprint( - self.MODELOPT_CONFIGS[config_name], expected_memory_reduction=expected + ModelOptConfigMixin.MODELOPT_CONFIGS[config_name], expected_memory_reduction=expected ) - @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()), ids=list(MODELOPT_CONFIGS.keys())) + @pytest.mark.parametrize( + "config_name", + list(ModelOptConfigMixin.MODELOPT_CONFIGS.keys()), + ids=list(ModelOptConfigMixin.MODELOPT_CONFIGS.keys()), + ) def test_modelopt_quantization_inference(self, config_name): - self._test_quantization_inference(self.MODELOPT_CONFIGS[config_name]) + self._test_quantization_inference(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) def test_modelopt_quantization_dtype_assignment(self, config_name): - self._test_quantization_dtype_assignment(self.MODELOPT_CONFIGS[config_name]) + self._test_quantization_dtype_assignment(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) def test_modelopt_quantization_lora_inference(self, config_name): - self._test_quantization_lora_inference(self.MODELOPT_CONFIGS[config_name]) + self._test_quantization_lora_inference(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) - def test_modelopt_quantization_serialization(self, config_name): - self._test_quantization_serialization(self.MODELOPT_CONFIGS[config_name]) + def test_modelopt_quantization_serialization(self, config_name, tmp_path): + self._test_quantization_serialization(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name], tmp_path) @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) def test_modelopt_quantized_layers(self, config_name): - self._test_quantized_layers(self.MODELOPT_CONFIGS[config_name]) + self._test_quantized_layers(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) def test_modelopt_modules_to_not_convert(self): """Test that modules_to_not_convert parameter works correctly.""" @@ -845,12 +1057,235 @@ class ModelOptTesterMixin(QuantizationTesterMixin): if modules_to_exclude 
is None: pytest.skip("modules_to_not_convert_for_test not defined for this model") - self._test_quantization_modules_to_not_convert(self.MODELOPT_CONFIGS["fp8"], modules_to_exclude) + self._test_quantization_modules_to_not_convert(ModelOptConfigMixin.MODELOPT_CONFIGS["fp8"], modules_to_exclude) def test_modelopt_device_map(self): """Test that device_map='auto' works correctly with quantization.""" - self._test_quantization_device_map(self.MODELOPT_CONFIGS["fp8"]) + self._test_quantization_device_map(ModelOptConfigMixin.MODELOPT_CONFIGS["fp8"]) def test_modelopt_dequantize(self): """Test that dequantize() works correctly.""" - self._test_dequantize(self.MODELOPT_CONFIGS["fp8"]) + self._test_dequantize(ModelOptConfigMixin.MODELOPT_CONFIGS["fp8"]) + + +@is_torch_compile +class QuantizationCompileTesterMixin: + """ + Base mixin class providing common test implementations for torch.compile with quantized models. + + Backend-specific compile mixins should: + 1. Inherit from their respective config mixin (e.g., BitsAndBytesConfigMixin) + 2. Inherit from this mixin + 3. Define the config to use for compile tests + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + + Expected methods in test classes: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + """ + + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + torch.compiler.reset() + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + torch.compiler.reset() + + def _test_torch_compile(self, config_kwargs): + """ + Test that torch.compile works correctly with a quantized model. + + Args: + config_kwargs: Quantization config parameters + """ + model = self._create_quantized_model(config_kwargs) + model.to(torch_device) + model.eval() + + # Compile the model with fullgraph=True to ensure no graph breaks + model = torch.compile(model, fullgraph=True) + + # Run inference with error_on_recompile to detect recompilation issues + with torch.no_grad(), torch._dynamo.config.patch(error_on_recompile=True): + inputs = self.get_dummy_inputs() + output = model(**inputs, return_dict=False)[0] + assert output is not None, "Model output is None" + assert not torch.isnan(output).any(), "Model output contains NaN" + + def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False): + """ + Test that torch.compile works correctly with a quantized model and group offloading. 
+ + Args: + config_kwargs: Quantization config parameters + use_stream: Whether to use CUDA streams for offloading + """ + torch._dynamo.config.cache_size_limit = 1000 + + model = self._create_quantized_model(config_kwargs) + model.eval() + + if not hasattr(model, "enable_group_offload"): + pytest.skip("Model does not support group offloading") + + group_offload_kwargs = { + "onload_device": torch.device(torch_device), + "offload_device": torch.device("cpu"), + "offload_type": "leaf_level", + "use_stream": use_stream, + } + model.enable_group_offload(**group_offload_kwargs) + model = torch.compile(model) + + with torch.no_grad(): + inputs = self.get_dummy_inputs() + output = model(**inputs, return_dict=False)[0] + assert output is not None, "Model output is None" + assert not torch.isnan(output).any(), "Model output contains NaN" + + +@is_bitsandbytes +@require_accelerator +@require_bitsandbytes_version_greater("0.43.2") +@require_accelerate +class BitsAndBytesCompileTesterMixin(BitsAndBytesConfigMixin, QuantizationCompileTesterMixin): + """ + Mixin class for testing torch.compile with BitsAndBytes quantized models. + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: bitsandbytes + Use `pytest -m "not bitsandbytes"` to skip these tests + """ + + @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) + def test_bnb_torch_compile(self, config_name): + self._test_torch_compile(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) + def test_bnb_torch_compile_with_group_offload(self, config_name): + self._test_torch_compile_with_group_offload(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) + + +@is_quanto +@require_quanto +@require_accelerate +@require_accelerator +class QuantoCompileTesterMixin(QuantoConfigMixin, QuantizationCompileTesterMixin): + """ + Mixin class for testing torch.compile with Quanto quantized models. + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: quanto + Use `pytest -m "not quanto"` to skip these tests + """ + + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) + def test_quanto_torch_compile(self, weight_type_name): + self._test_torch_compile(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) + def test_quanto_torch_compile_with_group_offload(self, weight_type_name): + self._test_torch_compile_with_group_offload(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name]) + + +@is_torchao +@require_accelerator +@require_torchao_version_greater_or_equal("0.7.0") +class TorchAoCompileTesterMixin(TorchAoConfigMixin, QuantizationCompileTesterMixin): + """ + Mixin class for testing torch.compile with TorchAO quantized models. 
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: torchao + Use `pytest -m "not torchao"` to skip these tests + """ + + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) + def test_torchao_torch_compile(self, quant_type): + self._test_torch_compile(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) + def test_torchao_torch_compile_with_group_offload(self, quant_type): + self._test_torch_compile_with_group_offload(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) + + +@is_gguf +@require_accelerate +@require_accelerator +@require_gguf_version_greater_or_equal("0.10.0") +class GGUFCompileTesterMixin(GGUFConfigMixin, QuantizationCompileTesterMixin): + """ + Mixin class for testing torch.compile with GGUF quantized models. + + Expected class attributes: + - model_class: The model class to test + - gguf_filename: URL or path to the GGUF file + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: gguf + Use `pytest -m "not gguf"` to skip these tests + """ + + def test_gguf_torch_compile(self): + self._test_torch_compile({"compute_dtype": torch.bfloat16}) + + def test_gguf_torch_compile_with_group_offload(self): + self._test_torch_compile_with_group_offload({"compute_dtype": torch.bfloat16}) + + +@is_modelopt +@require_accelerator +@require_accelerate +@require_modelopt_version_greater_or_equal("0.33.1") +class ModelOptCompileTesterMixin(ModelOptConfigMixin, QuantizationCompileTesterMixin): + """ + Mixin class for testing torch.compile with NVIDIA ModelOpt quantized models. + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: modelopt + Use `pytest -m "not modelopt"` to skip these tests + """ + + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) + def test_modelopt_torch_compile(self, config_name): + self._test_torch_compile(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) + def test_modelopt_torch_compile_with_group_offload(self, config_name): + self._test_torch_compile_with_group_offload(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) diff --git a/tests/models/testing_utils/single_file.py b/tests/models/testing_utils/single_file.py index 992e6dd8d9..52890fc3c6 100644 --- a/tests/models/testing_utils/single_file.py +++ b/tests/models/testing_utils/single_file.py @@ -14,7 +14,6 @@ # limitations under the License. 
import gc -import tempfile import torch from huggingface_hub import hf_hub_download, snapshot_download @@ -151,21 +150,20 @@ class SingleFileTesterMixin: param, param_single_file, atol=1e-5, rtol=1e-5, msg=f"Parameter values differ for {key}" ) - def test_single_file_loading_local_files_only(self): + def test_single_file_loading_local_files_only(self, tmp_path): single_file_kwargs = {} if hasattr(self, "torch_dtype") and self.torch_dtype: single_file_kwargs["torch_dtype"] = self.torch_dtype - with tempfile.TemporaryDirectory() as tmpdir: - pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir) + pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path)) - model_single_file = self.model_class.from_single_file( - local_ckpt_path, local_files_only=True, **single_file_kwargs - ) + model_single_file = self.model_class.from_single_file( + local_ckpt_path, local_files_only=True, **single_file_kwargs + ) - assert model_single_file is not None, "Failed to load model with local_files_only=True" + assert model_single_file is not None, "Failed to load model with local_files_only=True" def test_single_file_loading_with_diffusers_config(self): single_file_kwargs = {} @@ -196,22 +194,21 @@ class SingleFileTesterMixin: f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}" ) - def test_single_file_loading_with_diffusers_config_local_files_only(self): + def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path): single_file_kwargs = {} if hasattr(self, "torch_dtype") and self.torch_dtype: single_file_kwargs["torch_dtype"] = self.torch_dtype - with tempfile.TemporaryDirectory() as tmpdir: - pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir) - local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, tmpdir) + pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path)) + local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, str(tmp_path)) - model_single_file = self.model_class.from_single_file( - local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs - ) + model_single_file = self.model_class.from_single_file( + local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs + ) - assert model_single_file is not None, "Failed to load model with config and local_files_only=True" + assert model_single_file is not None, "Failed to load model with config and local_files_only=True" def test_single_file_loading_dtype(self): for dtype in [torch.float32, torch.float16]: diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index e0b38eda7f..b264eda22c 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -26,17 +26,25 @@ from ...testing_utils import enable_full_determinism, torch_device from 
..testing_utils import ( AttentionTesterMixin, BaseModelTesterConfig, + BitsAndBytesCompileTesterMixin, BitsAndBytesTesterMixin, ContextParallelTesterMixin, + FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + GGUFCompileTesterMixin, GGUFTesterMixin, IPAdapterTesterMixin, LoraHotSwappingForModelTesterMixin, LoraTesterMixin, MemoryTesterMixin, + ModelOptCompileTesterMixin, ModelOptTesterMixin, ModelTesterMixin, + PyramidAttentionBroadcastTesterMixin, + QuantoCompileTesterMixin, QuantoTesterMixin, SingleFileTesterMixin, + TorchAoCompileTesterMixin, TorchAoTesterMixin, TorchCompileTesterMixin, TrainingTesterMixin, @@ -353,3 +361,94 @@ class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMix "txt_ids": randn_tensor((512, 3)), "guidance": torch.tensor([3.5]).to(torch_device), } + + +class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompileTesterMixin): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin): + gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf" + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin): + 
"""PyramidAttentionBroadcast cache tests for Flux Transformer.""" + + pass + + +class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin): + """FirstBlockCache tests for Flux Transformer.""" + + pass + + +class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin): + """FasterCache tests for Flux Transformer.""" + + # Flux is guidance distilled, so we can test at model level without CFG batch handling + FASTER_CACHE_CONFIG = { + "spatial_attention_block_skip_range": 2, + "spatial_attention_timestep_skip_range": (-1, 901), + "tensor_format": "BCHW", + "is_guidance_distilled": True, + } diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 4c97bbc14c..e7f1d10c05 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -414,6 +414,15 @@ def is_group_offload(test_case): return pytest.mark.group_offload(test_case) +def is_quantization(test_case): + """ + Decorator marking a test as a quantization test. These tests can be filtered using: + pytest -m "not quantization" to skip + pytest -m quantization to run only these tests + """ + return pytest.mark.quantization(test_case) + + def is_bitsandbytes(test_case): """ Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using: @@ -468,6 +477,15 @@ def is_context_parallel(test_case): return pytest.mark.context_parallel(test_case) +def is_cache(test_case): + """ + Decorator marking a test as a cache test. These tests can be filtered using: + pytest -m "not cache" to skip + pytest -m cache to run only these tests + """ + return pytest.mark.cache(test_case) + + def require_torch(test_case): """ Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. @@ -835,7 +853,7 @@ def require_modelopt_version_greater_or_equal(modelopt_version): ) >= version.parse(modelopt_version) return pytest.mark.skipif( not correct_nvidia_modelopt_version, - f"Test requires modelopt with version greater than {modelopt_version}.", + reason=f"Test requires modelopt with version greater than {modelopt_version}.", )(test_case) return decorator