From c70de2bc3753820567c8097b7f433e8099d09437 Mon Sep 17 00:00:00 2001
From: DN6
Date: Thu, 18 Dec 2025 13:18:54 +0530
Subject: [PATCH] update

---
 tests/models/testing_utils/cache.py        | 20 ++++----
 tests/models/testing_utils/quantization.py | 54 +++++++++++-----------
 2 files changed, 37 insertions(+), 37 deletions(-)

diff --git a/tests/models/testing_utils/cache.py b/tests/models/testing_utils/cache.py
index c1b916d34e..8fc31ee969 100644
--- a/tests/models/testing_utils/cache.py
+++ b/tests/models/testing_utils/cache.py
@@ -24,7 +24,7 @@ from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK
 from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
 from diffusers.models.cache_utils import CacheMixin
 
-from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device
+from ...testing_utils import backend_empty_cache, is_cache, torch_device
 
 
 def require_cache_mixin(func):
@@ -177,9 +177,9 @@ class CacheTesterMixin:
         output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
 
         # Cached output should be different from non-cached output (due to approximation)
-        assert not torch.allclose(
-            output_without_cache, output_with_cache, atol=1e-5
-        ), "Cached output should be different from non-cached output due to cache approximation."
+        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
+            "Cached output should be different from non-cached output due to cache approximation."
+        )
 
     def _test_cache_context_manager(self):
         """Test the cache_context context manager."""
@@ -354,9 +354,9 @@ class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
         output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
 
         # Cached output should be different from non-cached output (due to approximation)
-        assert not torch.allclose(
-            output_without_cache, output_with_cache, atol=1e-5
-        ), "Cached output should be different from non-cached output due to cache approximation."
+        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
+            "Cached output should be different from non-cached output due to cache approximation."
+        )
 
     def _test_reset_stateful_cache(self):
         """Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
@@ -487,9 +487,9 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
         output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
 
         # Cached output should be different from non-cached output (due to approximation)
-        assert not torch.allclose(
-            output_without_cache, output_with_cache, atol=1e-5
-        ), "Cached output should be different from non-cached output due to cache approximation."
+        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
+            "Cached output should be different from non-cached output due to cache approximation."
+        )
 
     def _test_reset_stateful_cache(self):
         """Test that _reset_stateful_cache resets the FasterCache state."""
diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 4d3bde21b7..c199ec5062 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -152,9 +152,9 @@ class QuantizationTesterMixin:
         model_quantized = self._create_quantized_model(config_kwargs)
         num_params_quantized = model_quantized.num_parameters()
 
-        assert (
-            num_params == num_params_quantized
-        ), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
+        assert num_params == num_params_quantized, (
+            f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
+        )
 
     def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
         model = self._load_unquantized_model()
@@ -164,9 +164,9 @@ class QuantizationTesterMixin:
         mem_quantized = model_quantized.get_memory_footprint()
 
         ratio = mem / mem_quantized
-        assert (
-            ratio >= expected_memory_reduction
-        ), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
+        assert ratio >= expected_memory_reduction, (
+            f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
+        )
 
     def _test_quantization_inference(self, config_kwargs):
         model_quantized = self._create_quantized_model(config_kwargs)
@@ -260,12 +260,12 @@ class QuantizationTesterMixin:
                 self._verify_if_layer_quantized(name, module, config_kwargs)
                 num_quantized_layers += 1
 
-        assert (
-            num_quantized_layers > 0
-        ), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
-        assert (
-            num_quantized_layers == expected_quantized_layers
-        ), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
+        assert num_quantized_layers > 0, (
+            f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
+        )
+        assert num_quantized_layers == expected_quantized_layers, (
+            f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
+        )
 
     def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
         """
@@ -289,9 +289,9 @@ class QuantizationTesterMixin:
                 if any(excluded in name for excluded in modules_to_not_convert):
                     found_excluded = True
                     # This module should NOT be quantized
-                    assert not self._is_module_quantized(
-                        module
-                    ), f"Module {name} should not be quantized but was found to be quantized"
+                    assert not self._is_module_quantized(module), (
+                        f"Module {name} should not be quantized but was found to be quantized"
+                    )
 
         assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}"
 
@@ -313,9 +313,9 @@ class QuantizationTesterMixin:
         mem_with_exclusion = model_with_exclusion.get_memory_footprint()
         mem_fully_quantized = model_fully_quantized.get_memory_footprint()
 
-        assert (
-            mem_with_exclusion > mem_fully_quantized
-        ), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
+        assert mem_with_exclusion > mem_fully_quantized, (
+            f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
+        )
 
     def _test_quantization_device_map(self, config_kwargs):
         """
@@ -465,9 +465,9 @@ class BitsAndBytesConfigMixin:
 
     def _verify_if_layer_quantized(self, name, module, config_kwargs):
         expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params
-        assert (
-            module.weight.__class__ == expected_weight_class
-        ), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
+        assert module.weight.__class__ == expected_weight_class, (
+            f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
+        )
 
 
 @is_bitsandbytes
@@ -581,13 +581,13 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
                 if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
-                    assert (
-                        module.weight.dtype == torch.float32
-                    ), f"Module {name} should be FP32 but is {module.weight.dtype}"
+                    assert module.weight.dtype == torch.float32, (
+                        f"Module {name} should be FP32 but is {module.weight.dtype}"
+                    )
                 else:
-                    assert (
-                        module.weight.dtype == torch.uint8
-                    ), f"Module {name} should be uint8 but is {module.weight.dtype}"
+                    assert module.weight.dtype == torch.uint8, (
+                        f"Module {name} should be uint8 but is {module.weight.dtype}"
+                    )
 
         with torch.no_grad():
             inputs = self.get_dummy_inputs()