From ba475eee8d6bfe72a40c51e560779a865485f113 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 7 Jan 2026 12:26:44 +0530 Subject: [PATCH] update --- tests/models/testing_utils/attention.py | 26 ++------ tests/models/testing_utils/cache.py | 18 ++---- tests/models/testing_utils/common.py | 41 +++++------- tests/models/testing_utils/compile.py | 15 +++-- tests/models/testing_utils/memory.py | 11 +++- tests/models/testing_utils/parallelism.py | 7 +- tests/models/testing_utils/quantization.py | 75 ++++++++++------------ 7 files changed, 81 insertions(+), 112 deletions(-) diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py index d732195c7e..3f89026dfa 100644 --- a/tests/models/testing_utils/attention.py +++ b/tests/models/testing_utils/attention.py @@ -40,7 +40,6 @@ class AttentionTesterMixin: Expected class attributes to be set by subclasses: - model_class: The model class to test - - base_precision: Tolerance for floating point comparisons (default: 1e-3) - uses_custom_attn_processor: Whether model uses custom attention processors (default: False) Expected methods to be implemented by subclasses: @@ -51,8 +50,7 @@ class AttentionTesterMixin: Use `pytest -m "not attention"` to skip these tests """ - base_precision = 1e-3 - + @torch.no_grad() def test_fuse_unfuse_qkv_projections(self): init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs() @@ -63,14 +61,10 @@ class AttentionTesterMixin: if not hasattr(model, "fuse_qkv_projections"): pytest.skip("Model does not support QKV projection fusion.") - # Get output before fusion - with torch.no_grad(): - output_before_fusion = model(**inputs_dict, return_dict=False)[0] + output_before_fusion = model(**inputs_dict, return_dict=False)[0] - # Fuse projections model.fuse_qkv_projections() - # Verify fusion occurred by checking for fused attributes has_fused_projections = False for module in model.modules(): if isinstance(module, AttentionModuleMixin): @@ -80,38 +74,30 @@ class AttentionTesterMixin: break if has_fused_projections: - # Get output after fusion - with torch.no_grad(): - output_after_fusion = model(**inputs_dict, return_dict=False)[0] + output_after_fusion = model(**inputs_dict, return_dict=False)[0] - # Verify outputs match assert_tensors_close( output_before_fusion, output_after_fusion, - atol=self.base_precision, + atol=1e-3, rtol=0, msg="Output should not change after fusing projections", ) - # Unfuse projections model.unfuse_qkv_projections() - # Verify unfusion occurred for module in model.modules(): if isinstance(module, AttentionModuleMixin): assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing" assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing" assert not module.fused_projections, "fused_projections flag should be False" - # Get output after unfusion - with torch.no_grad(): - output_after_unfusion = model(**inputs_dict, return_dict=False)[0] + output_after_unfusion = model(**inputs_dict, return_dict=False)[0] - # Verify outputs still match assert_tensors_close( output_before_fusion, output_after_unfusion, - atol=self.base_precision, + atol=1e-3, rtol=0, msg="Output should match original after unfusing projections", ) diff --git a/tests/models/testing_utils/cache.py b/tests/models/testing_utils/cache.py index 8fc31ee969..1f828ca9f8 100644 --- a/tests/models/testing_utils/cache.py +++ b/tests/models/testing_utils/cache.py @@ -196,6 +196,7 @@ class CacheTesterMixin: model.disable_cache() + @torch.no_grad() def 
_test_reset_stateful_cache(self): """Test that _reset_stateful_cache resets the cache state.""" init_dict = self.get_init_dict() @@ -207,11 +208,8 @@ class CacheTesterMixin: model.enable_cache(config) - # Run forward to populate cache state - with torch.no_grad(): - _ = model(**inputs_dict, return_dict=False)[0] + _ = model(**inputs_dict, return_dict=False)[0] - # Reset should not raise any errors model._reset_stateful_cache() model.disable_cache() @@ -358,6 +356,7 @@ class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin): "Cached output should be different from non-cached output due to cache approximation." ) + @torch.no_grad() def _test_reset_stateful_cache(self): """Test that _reset_stateful_cache resets the FBC cache state (requires cache_context).""" init_dict = self.get_init_dict() @@ -368,12 +367,9 @@ class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin): config = self._get_cache_config() model.enable_cache(config) - # FBC requires cache_context to be set for inference with model.cache_context("fbc_test"): - with torch.no_grad(): - _ = model(**inputs_dict, return_dict=False)[0] + _ = model(**inputs_dict, return_dict=False)[0] - # Reset should not raise any errors model._reset_stateful_cache() model.disable_cache() @@ -491,6 +487,7 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin): "Cached output should be different from non-cached output due to cache approximation." ) + @torch.no_grad() def _test_reset_stateful_cache(self): """Test that _reset_stateful_cache resets the FasterCache state.""" init_dict = self.get_init_dict() @@ -501,12 +498,9 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin): config = self._get_cache_config() model.enable_cache(config) - # First pass with timestep outside skip range self._current_timestep[0] = 1000 - with torch.no_grad(): - _ = model(**inputs_dict, return_dict=False)[0] + _ = model(**inputs_dict, return_dict=False)[0] - # Reset should not raise any errors model._reset_stateful_cache() model.disable_cache() diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 3611eff7ef..49930e2df5 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -145,7 +145,6 @@ class BaseModelTesterConfig: - pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None) - pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {}) - output_shape: Expected output shape for output validation tests (default: None) - - base_precision: Default tolerance for floating point comparisons (default: 1e-3) - model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7]) Required methods (must be implemented by subclasses): @@ -259,6 +258,7 @@ class ModelTesterMixin: pass """ + @torch.no_grad() def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5): torch.manual_seed(0) model = self.model_class(**self.get_init_dict()) @@ -269,7 +269,6 @@ class ModelTesterMixin: new_model = self.model_class.from_pretrained(tmp_path) new_model.to(torch_device) - # check if all parameters shape are the same for param_name in model.state_dict().keys(): param_1 = model.state_dict()[param_name] param_2 = new_model.state_dict()[param_name] @@ -277,12 +276,12 @@ class ModelTesterMixin: f"Parameter shape mismatch for {param_name}. 
Original: {param_1.shape}, loaded: {param_2.shape}" ) - with torch.no_grad(): - image = model(**self.get_dummy_inputs(), return_dict=False)[0] - new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] + image = model(**self.get_dummy_inputs(), return_dict=False)[0] + new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") + @torch.no_grad() def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0): model = self.model_class(**self.get_init_dict()) model.to(torch_device) @@ -291,18 +290,15 @@ class ModelTesterMixin: model.save_pretrained(tmp_path, variant="fp16") new_model = self.model_class.from_pretrained(tmp_path, variant="fp16") - # non-variant cannot be loaded with pytest.raises(OSError) as exc_info: self.model_class.from_pretrained(tmp_path) - # make sure that error message states what keys are missing assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value) new_model.to(torch_device) - with torch.no_grad(): - image = model(**self.get_dummy_inputs(), return_dict=False)[0] - new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] + image = model(**self.get_dummy_inputs(), return_dict=False)[0] + new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") @@ -324,16 +320,15 @@ class ModelTesterMixin: new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype) assert new_model.dtype == dtype + @torch.no_grad() def test_determinism(self, atol=1e-5, rtol=0): model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() - with torch.no_grad(): - first = model(**self.get_dummy_inputs(), return_dict=False)[0] - second = model(**self.get_dummy_inputs(), return_dict=False)[0] + first = model(**self.get_dummy_inputs(), return_dict=False)[0] + second = model(**self.get_dummy_inputs(), return_dict=False)[0] - # Filter out NaN values before comparison first_flat = first.flatten() second_flat = second.flatten() mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat)) @@ -344,24 +339,23 @@ class ModelTesterMixin: first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic" ) + @torch.no_grad() def test_output(self, expected_output_shape=None): model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() inputs_dict = self.get_dummy_inputs() - with torch.no_grad(): - output = model(**inputs_dict, return_dict=False)[0] + output = model(**inputs_dict, return_dict=False)[0] assert output is not None, "Model output is None" assert output[0].shape == expected_output_shape or self.output_shape, ( f"Output shape does not match expected. 
Expected {expected_output_shape}, got {output.shape}" ) + @torch.no_grad() def test_outputs_equivalence(self): def set_nan_tensor_to_zero(t): - # Temporary fallback until `aten::_index_put_impl_` is implemented in mps - # Track progress in https://github.com/pytorch/pytorch/issues/77764 device = t.device if device.type == "mps": t = t.to("cpu") @@ -390,9 +384,8 @@ class ModelTesterMixin: model.to(torch_device) model.eval() - with torch.no_grad(): - outputs_dict = model(**self.get_dummy_inputs()) - outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False) + outputs_dict = model(**self.get_dummy_inputs()) + outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False) recursive_check(outputs_tuple, outputs_dict) @@ -465,6 +458,7 @@ class ModelTesterMixin: reason="float16 and bfloat16 can only be use for inference with an accelerator", ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + @torch.no_grad() def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): model = self.model_class(**self.get_init_dict()) model.to(torch_device) @@ -479,9 +473,8 @@ class ModelTesterMixin: else: assert param.data.dtype == dtype - with torch.no_grad(): - output = model(**self.get_dummy_inputs(), return_dict=False)[0] - output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0] + output = model(**self.get_dummy_inputs(), return_dict=False)[0] + output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}") diff --git a/tests/models/testing_utils/compile.py b/tests/models/testing_utils/compile.py index 7a479ed5b4..4459e73d2a 100644 --- a/tests/models/testing_utils/compile.py +++ b/tests/models/testing_utils/compile.py @@ -98,6 +98,7 @@ class TorchCompileTesterMixin: _ = model(**inputs_dict) _ = model(**inputs_dict) + @torch.no_grad() def test_compile_with_group_offloading(self): if not self.model_class._supports_group_offloading: pytest.skip("Model does not support group offloading.") @@ -120,10 +121,10 @@ class TorchCompileTesterMixin: model.enable_group_offload(**group_offload_kwargs) model.compile() - with torch.no_grad(): - _ = model(**inputs_dict) - _ = model(**inputs_dict) + _ = model(**inputs_dict) + _ = model(**inputs_dict) + @torch.no_grad() def test_compile_on_different_shapes(self): if self.different_shapes_for_compilation is None: pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") @@ -135,10 +136,11 @@ class TorchCompileTesterMixin: model = torch.compile(model, fullgraph=True, dynamic=True) for height, width in self.different_shapes_for_compilation: - with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): + with torch._dynamo.config.patch(error_on_recompile=True): inputs_dict = self.get_dummy_inputs(height=height, width=width) _ = model(**inputs_dict) + @torch.no_grad() def test_compile_works_with_aot(self, tmp_path): from torch._inductor.package import load_package @@ -155,6 +157,5 @@ class TorchCompileTesterMixin: model.forward = loaded_binary - with torch.no_grad(): - _ = model(**inputs_dict) - _ = model(**inputs_dict) + _ = model(**inputs_dict) + _ = model(**inputs_dict) diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py index 6886ece75b..d5441d427a 100644 --- a/tests/models/testing_utils/memory.py +++ b/tests/models/testing_utils/memory.py @@ -208,7 +208,8 @@ class 
GroupOffloadTesterMixin: """ @require_group_offload_support - def test_group_offloading(self, record_stream=False): + @pytest.mark.parametrize("record_stream", [False, True]) + def test_group_offloading(self, record_stream): init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs() torch.manual_seed(0) @@ -280,8 +281,10 @@ class GroupOffloadTesterMixin: ) @require_group_offload_support + @pytest.mark.parametrize("record_stream", [False, True]) + @pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"]) @torch.no_grad() - def test_group_offloading_with_layerwise_casting(self, record_stream=False, offload_type="block_level"): + def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type): torch.manual_seed(0) init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs() @@ -306,9 +309,11 @@ class GroupOffloadTesterMixin: _ = model(**inputs_dict)[0] @require_group_offload_support + @pytest.mark.parametrize("record_stream", [False, True]) + @pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"]) @torch.no_grad() @torch.inference_mode() - def test_group_offloading_with_disk(self, tmp_path, offload_type="block_level", record_stream=False, atol=1e-5): + def test_group_offloading_with_disk(self, tmp_path, record_stream, offload_type, atol=1e-5): def _has_generator_arg(model): sig = inspect.signature(model.forward) params = sig.parameters diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index 3bbbfe91bb..e0a32d5ed6 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -27,9 +27,9 @@ from ...testing_utils import ( ) +@torch.no_grad() def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue): try: - # Setup distributed environment os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" @@ -56,8 +56,7 @@ def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, cp_config = ContextParallelConfig(**cp_dict) model.enable_parallelism(config=cp_config) - with torch.no_grad(): - output = model(**inputs_on_device, return_dict=False)[0] + output = model(**inputs_on_device, return_dict=False)[0] if rank == 0: result_queue.put(("success", output.shape)) @@ -73,8 +72,6 @@ def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, @is_context_parallel @require_torch_multi_accelerator class ContextParallelTesterMixin: - base_precision = 1e-3 - @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) def test_context_parallel_inference(self, cp_type): if not torch.distributed.is_available(): diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index c199ec5062..dd7db8d330 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -168,15 +168,15 @@ class QuantizationTesterMixin: f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). 
unquantized={mem}, quantized={mem_quantized}" ) + @torch.no_grad() def _test_quantization_inference(self, config_kwargs): model_quantized = self._create_quantized_model(config_kwargs) - with torch.no_grad(): - inputs = self.get_dummy_inputs() - output = model_quantized(**inputs, return_dict=False)[0] + inputs = self.get_dummy_inputs() + output = model_quantized(**inputs, return_dict=False)[0] - assert output is not None, "Model output is None" - assert not torch.isnan(output).any(), "Model output contains NaN" + assert output is not None, "Model output is None" + assert not torch.isnan(output).any(), "Model output contains NaN" def _test_quantization_dtype_assignment(self, config_kwargs): model = self._create_quantized_model(config_kwargs) @@ -196,6 +196,7 @@ class QuantizationTesterMixin: model.to(torch_device) + @torch.no_grad() def _test_quantization_lora_inference(self, config_kwargs): try: from peft import LoraConfig @@ -217,13 +218,13 @@ class QuantizationTesterMixin: ) model.add_adapter(lora_config) - with torch.no_grad(): - inputs = self.get_dummy_inputs() - output = model(**inputs, return_dict=False)[0] + inputs = self.get_dummy_inputs() + output = model(**inputs, return_dict=False)[0] - assert output is not None, "Model output is None with LoRA" - assert not torch.isnan(output).any(), "Model output contains NaN with LoRA" + assert output is not None, "Model output is None with LoRA" + assert not torch.isnan(output).any(), "Model output contains NaN with LoRA" + @torch.no_grad() def _test_quantization_serialization(self, config_kwargs, tmp_path): model = self._create_quantized_model(config_kwargs) @@ -231,10 +232,9 @@ class QuantizationTesterMixin: model_loaded = self.model_class.from_pretrained(str(tmp_path)) - with torch.no_grad(): - inputs = self.get_dummy_inputs() - output = model_loaded(**inputs, return_dict=False)[0] - assert not torch.isnan(output).any(), "Loaded model output contains NaN" + inputs = self.get_dummy_inputs() + output = model_loaded(**inputs, return_dict=False)[0] + assert not torch.isnan(output).any(), "Loaded model output contains NaN" def _test_quantized_layers(self, config_kwargs): model_fp = self._load_unquantized_model() @@ -317,6 +317,7 @@ class QuantizationTesterMixin: f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" ) + @torch.no_grad() def _test_quantization_device_map(self, config_kwargs): """ Test that quantized models work correctly with device_map="auto". @@ -326,17 +327,15 @@ class QuantizationTesterMixin: """ model = self._create_quantized_model(config_kwargs, device_map="auto") - # Verify device map is set assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute" assert model.hf_device_map is not None, "hf_device_map should not be None" - # Verify inference works - with torch.no_grad(): - inputs = self.get_dummy_inputs() - output = model(**inputs, return_dict=False)[0] - assert output is not None, "Model output is None" - assert not torch.isnan(output).any(), "Model output contains NaN" + inputs = self.get_dummy_inputs() + output = model(**inputs, return_dict=False)[0] + assert output is not None, "Model output is None" + assert not torch.isnan(output).any(), "Model output contains NaN" + @torch.no_grad() def _test_dequantize(self, config_kwargs): """ Test that dequantize() converts quantized model back to standard linear layers. 
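Note: the device_map path exercised by _test_quantization_device_map above can be reproduced outside the test harness roughly as follows. This is a minimal sketch, not part of the patch: the checkpoint id is only a placeholder, and passing device_map="auto" together with quantization_config is assumed to be supported by the installed diffusers version.

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Placeholder checkpoint; any quantizable ModelMixin checkpoint works the same way.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
    device_map="auto",
    torch_dtype=torch.float16,
)

# device_map="auto" records the final placement on the model,
# which is what the hf_device_map assertions above check.
print(model.hf_device_map)
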
@@ -346,24 +345,19 @@ class QuantizationTesterMixin: """ model = self._create_quantized_model(config_kwargs) - # Verify model has dequantize method if not hasattr(model, "dequantize"): pytest.skip("Model does not have dequantize method") - # Dequantize the model model.dequantize() - # Verify no modules are quantized after dequantization for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear): assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()" - # Verify inference still works after dequantization - with torch.no_grad(): - inputs = self.get_dummy_inputs() - output = model(**inputs, return_dict=False)[0] - assert output is not None, "Model output is None after dequantization" - assert not torch.isnan(output).any(), "Model output contains NaN after dequantization" + inputs = self.get_dummy_inputs() + output = model(**inputs, return_dict=False)[0] + assert output is not None, "Model output is None after dequantization" + assert not torch.isnan(output).any(), "Model output contains NaN after dequantization" def _test_quantization_training(self, config_kwargs): """ @@ -566,6 +560,7 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin): torch.bfloat16, ], f"Unexpected dtype: {model.config['_pre_quantization_dtype']}" + @torch.no_grad() def test_bnb_keep_modules_in_fp32(self): if not hasattr(self.model_class, "_keep_in_fp32_modules"): pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules") @@ -589,9 +584,8 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin): f"Module {name} should be uint8 but is {module.weight.dtype}" ) - with torch.no_grad(): - inputs = self.get_dummy_inputs() - _ = model(**inputs) + inputs = self.get_dummy_inputs() + _ = model(**inputs) finally: if original_fp32_modules is not None: self.model_class._keep_in_fp32_modules = original_fp32_modules @@ -1097,6 +1091,7 @@ class QuantizationCompileTesterMixin: backend_empty_cache(torch_device) torch.compiler.reset() + @torch.no_grad() def _test_torch_compile(self, config_kwargs): """ Test that torch.compile works correctly with a quantized model. @@ -1108,16 +1103,15 @@ class QuantizationCompileTesterMixin: model.to(torch_device) model.eval() - # Compile the model with fullgraph=True to ensure no graph breaks model = torch.compile(model, fullgraph=True) - # Run inference with error_on_recompile to detect recompilation issues - with torch.no_grad(), torch._dynamo.config.patch(error_on_recompile=True): + with torch._dynamo.config.patch(error_on_recompile=True): inputs = self.get_dummy_inputs() output = model(**inputs, return_dict=False)[0] assert output is not None, "Model output is None" assert not torch.isnan(output).any(), "Model output contains NaN" + @torch.no_grad() def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False): """ Test that torch.compile works correctly with a quantized model and group offloading. 
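Note: as a standalone illustration of the compile guard used in _test_torch_compile (fullgraph compilation plus a recompile trap), here is a minimal sketch with a toy module; the module, shapes, and names are hypothetical, but torch._dynamo.config.patch is used exactly the way the test uses it.

import torch

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x):
        return torch.nn.functional.gelu(self.proj(x))

model = ToyBlock().eval()
# fullgraph=True fails loudly if anything in forward causes a graph break.
compiled = torch.compile(model, fullgraph=True)

x = torch.randn(2, 64)
with torch.no_grad(), torch._dynamo.config.patch(error_on_recompile=True):
    # The first call compiles; the second must hit the cached graph,
    # otherwise error_on_recompile raises and the test fails.
    _ = compiled(x)
    _ = compiled(x)
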
@@ -1143,11 +1137,10 @@ class QuantizationCompileTesterMixin:
         model.enable_group_offload(**group_offload_kwargs)
         model = torch.compile(model)
 
-        with torch.no_grad():
-            inputs = self.get_dummy_inputs()
-            output = model(**inputs, return_dict=False)[0]
-            assert output is not None, "Model output is None"
-            assert not torch.isnan(output).any(), "Model output contains NaN"
+        inputs = self.get_dummy_inputs()
+        output = model(**inputs, return_dict=False)[0]
+        assert output is not None, "Model output is None"
+        assert not torch.isnan(output).any(), "Model output contains NaN"
 
 
 @is_bitsandbytes
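Note: the change that recurs throughout this patch replaces an inner `with torch.no_grad():` block with a `@torch.no_grad()` decorator on the test method. The two forms are interchangeable because torch.no_grad works both as a context manager and as a decorator; a minimal sketch with hypothetical function names:

import torch

linear = torch.nn.Linear(4, 4)

def forward_ctx(x):
    # Context-manager form: gradients are disabled only inside the block.
    with torch.no_grad():
        return linear(x)

@torch.no_grad()
def forward_decorated(x):
    # Decorator form: the whole function body runs with gradients disabled.
    return linear(x)

x = torch.randn(2, 4)
assert not forward_ctx(x).requires_grad
assert not forward_decorated(x).requires_grad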
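Note: the group-offloading tests in memory.py likewise drop their default keyword arguments in favour of stacked pytest.mark.parametrize decorators, so every record_stream/offload_type combination runs as its own test case. A sketch of that pattern on a hypothetical test:

import pytest

@pytest.mark.parametrize("record_stream", [False, True])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
def test_offload_combinations(record_stream, offload_type):
    # Stacked parametrize decorators produce the cross product:
    # four test cases, one per (offload_type, record_stream) combination.
    assert offload_type in {"block_level", "leaf_level"}
    assert isinstance(record_stream, bool)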