From eae75437128e407cf1593223256f3f999553bedb Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 15 Dec 2025 16:02:38 +0530 Subject: [PATCH] update --- tests/models/testing_utils/__init__.py | 3 +- tests/models/testing_utils/attention.py | 95 +-------------- tests/models/testing_utils/common.py | 70 ++++------- tests/models/testing_utils/hub.py | 109 ----------------- tests/models/testing_utils/ip_adapter.py | 40 ------ tests/models/testing_utils/lora.py | 2 - tests/models/testing_utils/memory.py | 5 +- tests/models/testing_utils/quantization.py | 135 ++++++++++++--------- tests/models/testing_utils/training.py | 28 +---- utils/generate_model_tests.py | 36 ++++-- 10 files changed, 140 insertions(+), 383 deletions(-) delete mode 100644 tests/models/testing_utils/hub.py diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py index 229179737a..6dfb77c713 100644 --- a/tests/models/testing_utils/__init__.py +++ b/tests/models/testing_utils/__init__.py @@ -1,9 +1,10 @@ -from .attention import AttentionTesterMixin, ContextParallelTesterMixin +from .attention import AttentionTesterMixin from .common import BaseModelTesterConfig, ModelTesterMixin from .compile import TorchCompileTesterMixin from .ip_adapter import IPAdapterTesterMixin from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin +from .parallelism import ContextParallelTesterMixin from .quantization import ( BitsAndBytesTesterMixin, GGUFTesterMixin, diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py index 45443046fb..d732195c7e 100644 --- a/tests/models/testing_utils/attention.py +++ b/tests/models/testing_utils/attention.py @@ -13,13 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os - import pytest import torch -import torch.multiprocessing as mp -from diffusers.models._modeling_parallel import ContextParallelConfig from diffusers.models.attention import AttentionModuleMixin from diffusers.models.attention_processor import ( AttnProcessor, @@ -28,8 +24,6 @@ from diffusers.models.attention_processor import ( from ...testing_utils import ( assert_tensors_close, is_attention, - is_context_parallel, - require_torch_multi_accelerator, torch_device, ) @@ -71,9 +65,7 @@ class AttentionTesterMixin: # Get output before fusion with torch.no_grad(): - output_before_fusion = model(**inputs_dict) - if isinstance(output_before_fusion, dict): - output_before_fusion = output_before_fusion.to_tuple()[0] + output_before_fusion = model(**inputs_dict, return_dict=False)[0] # Fuse projections model.fuse_qkv_projections() @@ -90,9 +82,7 @@ class AttentionTesterMixin: if has_fused_projections: # Get output after fusion with torch.no_grad(): - output_after_fusion = model(**inputs_dict) - if isinstance(output_after_fusion, dict): - output_after_fusion = output_after_fusion.to_tuple()[0] + output_after_fusion = model(**inputs_dict, return_dict=False)[0] # Verify outputs match assert_tensors_close( @@ -115,9 +105,7 @@ class AttentionTesterMixin: # Get output after unfusion with torch.no_grad(): - output_after_unfusion = model(**inputs_dict) - if isinstance(output_after_unfusion, dict): - output_after_unfusion = output_after_unfusion.to_tuple()[0] + output_after_unfusion = model(**inputs_dict, return_dict=False)[0] # Verify outputs still match assert_tensors_close( @@ -195,80 +183,3 @@ class AttentionTesterMixin: model.set_attn_processor(wrong_processors) assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch" - - -def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue): - try: - # Setup distributed environment - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - - torch.distributed.init_process_group( - backend="nccl", - init_method="env://", - world_size=world_size, - rank=rank, - ) - torch.cuda.set_device(rank) - device = torch.device(f"cuda:{rank}") - - model = model_class(**init_dict) - model.to(device) - model.eval() - - inputs_on_device = {} - for key, value in inputs_dict.items(): - if isinstance(value, torch.Tensor): - inputs_on_device[key] = value.to(device) - else: - inputs_on_device[key] = value - - cp_config = ContextParallelConfig(**cp_dict) - model.enable_parallelism(config=cp_config) - - with torch.no_grad(): - output = model(**inputs_on_device) - if isinstance(output, dict): - output = output.to_tuple()[0] - - if rank == 0: - result_queue.put(("success", output.shape)) - - except Exception as e: - if rank == 0: - result_queue.put(("error", str(e))) - finally: - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - - -@is_context_parallel -@require_torch_multi_accelerator -class ContextParallelTesterMixin: - base_precision = 1e-3 - - @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) - def test_context_parallel_inference(self, cp_type): - if not torch.distributed.is_available(): - pytest.skip("torch.distributed is not available.") - - if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: - pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") - - world_size = 2 - init_dict = self.get_init_dict() - 
inputs_dict = self.get_dummy_inputs() - cp_dict = {cp_type: world_size} - - ctx = mp.get_context("spawn") - result_queue = ctx.Queue() - - mp.spawn( - _context_parallel_worker, - args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue), - nprocs=world_size, - join=True, - ) - - status, result = result_queue.get(timeout=60) - assert status == "success", f"Context parallel inference failed: {result}" diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 11c10c4557..3611eff7ef 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -259,7 +259,7 @@ class ModelTesterMixin: pass """ - def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=0): + def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5): torch.manual_seed(0) model = self.model_class(**self.get_init_dict()) model.to(torch_device) @@ -278,15 +278,8 @@ class ModelTesterMixin: ) with torch.no_grad(): - image = model(**self.get_dummy_inputs()) - - if isinstance(image, dict): - image = image.to_tuple()[0] - - new_image = new_model(**self.get_dummy_inputs()) - - if isinstance(new_image, dict): - new_image = new_image.to_tuple()[0] + image = model(**self.get_dummy_inputs(), return_dict=False)[0] + new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") @@ -308,14 +301,8 @@ class ModelTesterMixin: new_model.to(torch_device) with torch.no_grad(): - image = model(**self.get_dummy_inputs()) - if isinstance(image, dict): - image = image.to_tuple()[0] - - new_image = new_model(**self.get_dummy_inputs()) - - if isinstance(new_image, dict): - new_image = new_image.to_tuple()[0] + image = model(**self.get_dummy_inputs(), return_dict=False)[0] + new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") @@ -343,13 +330,8 @@ class ModelTesterMixin: model.eval() with torch.no_grad(): - first = model(**self.get_dummy_inputs()) - if isinstance(first, dict): - first = first.to_tuple()[0] - - second = model(**self.get_dummy_inputs()) - if isinstance(second, dict): - second = second.to_tuple()[0] + first = model(**self.get_dummy_inputs(), return_dict=False)[0] + second = model(**self.get_dummy_inputs(), return_dict=False)[0] # Filter out NaN values before comparison first_flat = first.flatten() @@ -369,10 +351,7 @@ class ModelTesterMixin: inputs_dict = self.get_dummy_inputs() with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**inputs_dict, return_dict=False)[0] assert output is not None, "Model output is None" assert output[0].shape == expected_output_shape or self.output_shape, ( @@ -501,13 +480,8 @@ class ModelTesterMixin: assert param.data.dtype == dtype with torch.no_grad(): - output = model(**self.get_dummy_inputs()) - if isinstance(output, dict): - output = output.to_tuple()[0] - - output_loaded = model_loaded(**self.get_dummy_inputs()) - if isinstance(output_loaded, dict): - output_loaded = output_loaded.to_tuple()[0] + output = model(**self.get_dummy_inputs(), return_dict=False)[0] + output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}") @@ -519,7 +493,7 @@ class 
ModelTesterMixin: model = self.model_class(**config).eval() model = model.to(torch_device) - base_output = model(**inputs_dict) + base_output = model(**inputs_dict, return_dict=False)[0] model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small @@ -539,10 +513,10 @@ class ModelTesterMixin: torch.manual_seed(0) inputs_dict_new = self.get_dummy_inputs() - new_output = new_model(**inputs_dict_new) + new_output = new_model(**inputs_dict_new, return_dict=False)[0] assert_tensors_close( - base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after sharded save/load" + base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after sharded save/load" ) @require_accelerator @@ -553,7 +527,7 @@ class ModelTesterMixin: model = self.model_class(**config).eval() model = model.to(torch_device) - base_output = model(**inputs_dict) + base_output = model(**inputs_dict, return_dict=False)[0] model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small @@ -578,10 +552,10 @@ class ModelTesterMixin: torch.manual_seed(0) inputs_dict_new = self.get_dummy_inputs() - new_output = new_model(**inputs_dict_new) + new_output = new_model(**inputs_dict_new, return_dict=False)[0] assert_tensors_close( - base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load" + base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load" ) def test_sharded_checkpoints_with_parallel_loading(self, tmp_path): @@ -593,7 +567,7 @@ class ModelTesterMixin: model = self.model_class(**config).eval() model = model.to(torch_device) - base_output = model(**inputs_dict) + base_output = model(**inputs_dict, return_dict=False)[0] model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small @@ -628,10 +602,10 @@ class ModelTesterMixin: torch.manual_seed(0) inputs_dict_parallel = self.get_dummy_inputs() - output_parallel = model_parallel(**inputs_dict_parallel) + output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0] assert_tensors_close( - base_output[0], output_parallel[0], atol=1e-5, rtol=0, msg="Output should match with parallel loading" + base_output, output_parallel, atol=1e-5, rtol=0, msg="Output should match with parallel loading" ) finally: @@ -652,7 +626,7 @@ class ModelTesterMixin: model = model.to(torch_device) torch.manual_seed(0) - base_output = model(**inputs_dict) + base_output = model(**inputs_dict, return_dict=False)[0] model_size = compute_module_sizes(model)[""] max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] @@ -668,8 +642,8 @@ class ModelTesterMixin: check_device_map_is_respected(new_model, new_model.hf_device_map) torch.manual_seed(0) - new_output = new_model(**inputs_dict) + new_output = new_model(**inputs_dict, return_dict=False)[0] assert_tensors_close( - base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with model parallelism" + base_output, new_output, atol=1e-5, rtol=0, msg="Output should match with model parallelism" ) diff --git a/tests/models/testing_utils/hub.py b/tests/models/testing_utils/hub.py deleted file mode 100644 index 40d8777c33..0000000000 --- a/tests/models/testing_utils/hub.py +++ /dev/null @@ -1,109 +0,0 @@ -# coding=utf-8 -# 
Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tempfile -import uuid - -import pytest -import torch -from huggingface_hub.utils import ModelCard, delete_repo, is_jinja_available - -from ...others.test_utils import TOKEN, USER, is_staging_test - - -@is_staging_test -class ModelPushToHubTesterMixin: - """ - Mixin class for testing push_to_hub functionality on models. - - Expected class attributes to be set by subclasses: - - model_class: The model class to test - - Expected methods to be implemented by subclasses: - - get_init_dict(): Returns dict of arguments to initialize the model - """ - - identifier = uuid.uuid4() - repo_id = f"test-model-{identifier}" - org_repo_id = f"valid_org/{repo_id}-org" - - def test_push_to_hub(self): - """Test pushing model to hub and loading it back.""" - init_dict = self.get_init_dict() - model = self.model_class(**init_dict) - model.push_to_hub(self.repo_id, token=TOKEN) - - new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}") - for p1, p2 in zip(model.parameters(), new_model.parameters()): - assert torch.equal(p1, p2), "Parameters don't match after push_to_hub and from_pretrained" - - # Reset repo - delete_repo(token=TOKEN, repo_id=self.repo_id) - - # Push to hub via save_pretrained - with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN) - - new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}") - for p1, p2 in zip(model.parameters(), new_model.parameters()): - assert torch.equal(p1, p2), ( - "Parameters don't match after save_pretrained with push_to_hub and from_pretrained" - ) - - # Reset repo - delete_repo(self.repo_id, token=TOKEN) - - def test_push_to_hub_in_organization(self): - """Test pushing model to hub in organization namespace.""" - init_dict = self.get_init_dict() - model = self.model_class(**init_dict) - model.push_to_hub(self.org_repo_id, token=TOKEN) - - new_model = self.model_class.from_pretrained(self.org_repo_id) - for p1, p2 in zip(model.parameters(), new_model.parameters()): - assert torch.equal(p1, p2), "Parameters don't match after push_to_hub to org and from_pretrained" - - # Reset repo - delete_repo(token=TOKEN, repo_id=self.org_repo_id) - - # Push to hub via save_pretrained - with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id) - - new_model = self.model_class.from_pretrained(self.org_repo_id) - for p1, p2 in zip(model.parameters(), new_model.parameters()): - assert torch.equal(p1, p2), ( - "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained" - ) - - # Reset repo - delete_repo(self.org_repo_id, token=TOKEN) - - def test_push_to_hub_library_name(self): - """Test that library_name in model card is set to 'diffusers'.""" - if not is_jinja_available(): - pytest.skip("Model card tests cannot be performed without Jinja installed.") - - init_dict = self.get_init_dict() 
- model = self.model_class(**init_dict) - model.push_to_hub(self.repo_id, token=TOKEN) - - model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data - assert model_card.library_name == "diffusers", ( - f"Expected library_name 'diffusers', got {model_card.library_name}" - ) - - # Reset repo - delete_repo(self.repo_id, token=TOKEN) diff --git a/tests/models/testing_utils/ip_adapter.py b/tests/models/testing_utils/ip_adapter.py index 13e141869c..891a23d330 100644 --- a/tests/models/testing_utils/ip_adapter.py +++ b/tests/models/testing_utils/ip_adapter.py @@ -17,49 +17,9 @@ import pytest import torch -from diffusers.models.attention_processor import IPAdapterAttnProcessor - from ...testing_utils import is_ip_adapter, torch_device -def create_ip_adapter_state_dict(model): - """ - Create a dummy IP Adapter state dict for testing. - - Args: - model: The model to create IP adapter weights for - - Returns: - dict: IP adapter state dict with to_k_ip and to_v_ip weights - """ - ip_state_dict = {} - key_id = 1 - - for name in model.attn_processors.keys(): - # Skip self-attention processors - cross_attention_dim = getattr(model.config, "cross_attention_dim", None) - if cross_attention_dim is None: - continue - - # Get hidden size based on model architecture - hidden_size = getattr(model.config, "hidden_size", cross_attention_dim) - - # Create IP adapter processor to get state dict structure - sd = IPAdapterAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 - ).state_dict() - - ip_state_dict.update( - { - f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], - f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], - } - ) - key_id += 2 - - return {"ip_adapter": ip_state_dict} - - def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool: """ Check if IP Adapter processors are correctly set in the model. 
diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py index b790e3ea26..5777f99278 100644 --- a/tests/models/testing_utils/lora.py +++ b/tests/models/testing_utils/lora.py @@ -79,8 +79,6 @@ class LoraTesterMixin: """ def setup_method(self): - from diffusers.loaders.peft import PeftAdapterMixin - if not issubclass(self.model_class, PeftAdapterMixin): pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).") diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py index ebd76656f0..5486bbb0cd 100644 --- a/tests/models/testing_utils/memory.py +++ b/tests/models/testing_utils/memory.py @@ -455,10 +455,7 @@ class LayerwiseCastingTesterMixin: inputs_dict = self.get_inputs_dict() inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) with torch.amp.autocast(device_type=torch.device(torch_device).type): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**inputs_dict, return_dict=False)[0] input_tensor = inputs_dict[self.main_input_name] noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 140f6db6d2..b7f960a135 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -128,9 +128,9 @@ class QuantizationTesterMixin: model_quantized = self._create_quantized_model(config_kwargs) num_params_quantized = model_quantized.num_parameters() - assert num_params == num_params_quantized, ( - f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}" - ) + assert ( + num_params == num_params_quantized + ), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}" def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2): model = self._load_unquantized_model() @@ -140,19 +140,17 @@ class QuantizationTesterMixin: mem_quantized = model_quantized.get_memory_footprint() ratio = mem / mem_quantized - assert ratio >= expected_memory_reduction, ( - f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}" - ) + assert ( + ratio >= expected_memory_reduction + ), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). 
unquantized={mem}, quantized={mem_quantized}" def _test_quantization_inference(self, config_kwargs): model_quantized = self._create_quantized_model(config_kwargs) with torch.no_grad(): inputs = self.get_dummy_inputs() - output = model_quantized(**inputs) + output = model_quantized(**inputs, return_dict=False)[0] - if isinstance(output, tuple): - output = output[0] assert output is not None, "Model output is None" assert not torch.isnan(output).any(), "Model output contains NaN" @@ -197,10 +195,8 @@ class QuantizationTesterMixin: with torch.no_grad(): inputs = self.get_dummy_inputs() - output = model(**inputs) + output = model(**inputs, return_dict=False)[0] - if isinstance(output, tuple): - output = output[0] assert output is not None, "Model output is None with LoRA" assert not torch.isnan(output).any(), "Model output contains NaN with LoRA" @@ -214,9 +210,7 @@ class QuantizationTesterMixin: with torch.no_grad(): inputs = self.get_dummy_inputs() - output = model_loaded(**inputs) - if isinstance(output, tuple): - output = output[0] + output = model_loaded(**inputs, return_dict=False)[0] assert not torch.isnan(output).any(), "Loaded model output contains NaN" def _test_quantized_layers(self, config_kwargs): @@ -243,12 +237,12 @@ class QuantizationTesterMixin: self._verify_if_layer_quantized(name, module, config_kwargs) num_quantized_layers += 1 - assert num_quantized_layers > 0, ( - f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" - ) - assert num_quantized_layers == expected_quantized_layers, ( - f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" - ) + assert ( + num_quantized_layers > 0 + ), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" + assert ( + num_quantized_layers == expected_quantized_layers + ), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert): """ @@ -272,9 +266,9 @@ class QuantizationTesterMixin: if any(excluded in name for excluded in modules_to_not_convert): found_excluded = True # This module should NOT be quantized - assert not self._is_module_quantized(module), ( - f"Module {name} should not be quantized but was found to be quantized" - ) + assert not self._is_module_quantized( + module + ), f"Module {name} should not be quantized but was found to be quantized" assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}" @@ -296,9 +290,9 @@ class QuantizationTesterMixin: mem_with_exclusion = model_with_exclusion.get_memory_footprint() mem_fully_quantized = model_fully_quantized.get_memory_footprint() - assert mem_with_exclusion > mem_fully_quantized, ( - f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" - ) + assert ( + mem_with_exclusion > mem_fully_quantized + ), f"Model with exclusions should be larger. 
With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" def _test_quantization_device_map(self, config_kwargs): """ @@ -316,12 +310,38 @@ class QuantizationTesterMixin: # Verify inference works with torch.no_grad(): inputs = self.get_dummy_inputs() - output = model(**inputs) - if isinstance(output, tuple): - output = output[0] + output = model(**inputs, return_dict=False)[0] assert output is not None, "Model output is None" assert not torch.isnan(output).any(), "Model output contains NaN" + def _test_dequantize(self, config_kwargs): + """ + Test that dequantize() converts quantized model back to standard linear layers. + + Args: + config_kwargs: Quantization config parameters + """ + model = self._create_quantized_model(config_kwargs) + + # Verify model has dequantize method + if not hasattr(model, "dequantize"): + pytest.skip("Model does not have dequantize method") + + # Dequantize the model + model.dequantize() + + # Verify no modules are quantized after dequantization + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()" + + # Verify inference still works after dequantization + with torch.no_grad(): + inputs = self.get_dummy_inputs() + output = model(**inputs, return_dict=False)[0] + assert output is not None, "Model output is None after dequantization" + assert not torch.isnan(output).any(), "Model output contains NaN after dequantization" + @is_bitsandbytes @require_accelerator @@ -379,9 +399,9 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin): def _verify_if_layer_quantized(self, name, module, config_kwargs): expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params - assert module.weight.__class__ == expected_weight_class, ( - f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" - ) + assert ( + module.weight.__class__ == expected_weight_class + ), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) def test_bnb_quantization_num_parameters(self, config_name): @@ -449,13 +469,13 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin): for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear): if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules): - assert module.weight.dtype == torch.float32, ( - f"Module {name} should be FP32 but is {module.weight.dtype}" - ) + assert ( + module.weight.dtype == torch.float32 + ), f"Module {name} should be FP32 but is {module.weight.dtype}" else: - assert module.weight.dtype == torch.uint8, ( - f"Module {name} should be uint8 but is {module.weight.dtype}" - ) + assert ( + module.weight.dtype == torch.uint8 + ), f"Module {name} should be uint8 but is {module.weight.dtype}" with torch.no_grad(): inputs = self.get_dummy_inputs() @@ -476,6 +496,10 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin): """Test that device_map='auto' works correctly with quantization.""" self._test_quantization_device_map(self.BNB_CONFIGS["4bit_nf4"]) + def test_bnb_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(self.BNB_CONFIGS["4bit_nf4"]) + @is_quanto @require_quanto @@ -563,6 +587,10 @@ class QuantoTesterMixin(QuantizationTesterMixin): """Test that device_map='auto' works correctly with quantization.""" 
self._test_quantization_device_map(self.QUANTO_WEIGHT_TYPES["int8"]) + def test_quanto_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(self.QUANTO_WEIGHT_TYPES["int8"]) + @is_torchao @require_accelerator @@ -649,6 +677,10 @@ class TorchAoTesterMixin(QuantizationTesterMixin): """Test that device_map='auto' works correctly with quantization.""" self._test_quantization_device_map(self.TORCHAO_QUANT_TYPES["int8wo"]) + def test_torchao_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(self.TORCHAO_QUANT_TYPES["int8wo"]) + @is_gguf @require_accelerate @@ -716,24 +748,9 @@ class GGUFTesterMixin(QuantizationTesterMixin): def test_gguf_quantization_lora_inference(self): self._test_quantization_lora_inference({"compute_dtype": torch.bfloat16}) - def test_gguf_dequantize_model(self): - from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter - - model = self._create_quantized_model() - model.dequantize() - - def _check_for_gguf_linear(model): - has_children = list(model.children()) - if not has_children: - return - - for name, module in model.named_children(): - if isinstance(module, torch.nn.Linear): - assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear" - assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter" - - for name, module in model.named_children(): - _check_for_gguf_linear(module) + def test_gguf_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize({"compute_dtype": torch.bfloat16}) def test_gguf_quantized_layers(self): self._test_quantized_layers({"compute_dtype": torch.bfloat16}) @@ -826,3 +843,7 @@ class ModelOptTesterMixin(QuantizationTesterMixin): def test_modelopt_device_map(self): """Test that device_map='auto' works correctly with quantization.""" self._test_quantization_device_map(self.MODELOPT_CONFIGS["fp8"]) + + def test_modelopt_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(self.MODELOPT_CONFIGS["fp8"]) diff --git a/tests/models/testing_utils/training.py b/tests/models/testing_utils/training.py index 7e4193d59e..f6612dd3be 100644 --- a/tests/models/testing_utils/training.py +++ b/tests/models/testing_utils/training.py @@ -50,10 +50,7 @@ class TrainingTesterMixin: model = self.model_class(**init_dict) model.to(torch_device) model.train() - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**inputs_dict, return_dict=False)[0] noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) @@ -68,10 +65,7 @@ class TrainingTesterMixin: model.train() ema_model = EMAModel(model.parameters()) - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**inputs_dict, return_dict=False)[0] noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) @@ -137,9 +131,7 @@ class TrainingTesterMixin: assert not model.is_gradient_checkpointing and model.training - out = model(**inputs_dict) - if isinstance(out, dict): - out = out.sample if hasattr(out, "sample") else out.to_tuple()[0] + out = model(**inputs_dict, return_dict=False)[0] # run the backwards pass on the model model.zero_grad() @@ -158,9 +150,7 @@ class TrainingTesterMixin: assert model_2.is_gradient_checkpointing and model_2.training - out_2 = 
model_2(**inputs_dict_copy) - if isinstance(out_2, dict): - out_2 = out_2.sample if hasattr(out_2, "sample") else out_2.to_tuple()[0] + out_2 = model_2(**inputs_dict_copy, return_dict=False)[0] # run the backwards pass on the model model_2.zero_grad() @@ -198,10 +188,7 @@ class TrainingTesterMixin: # Test with float16 if torch.device(torch_device).type != "cpu": with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**inputs_dict, return_dict=False)[0] noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) @@ -212,10 +199,7 @@ class TrainingTesterMixin: if torch.device(torch_device).type != "cpu": model.zero_grad() with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**inputs_dict, return_dict=False)[0] noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) diff --git a/utils/generate_model_tests.py b/utils/generate_model_tests.py index ffd600dfdf..f3860f4b9a 100644 --- a/utils/generate_model_tests.py +++ b/utils/generate_model_tests.py @@ -43,10 +43,15 @@ ATTRIBUTE_TO_TESTER = { ALWAYS_INCLUDE_TESTERS = [ "ModelTesterMixin", "MemoryTesterMixin", - "AttentionTesterMixin", "TorchCompileTesterMixin", ] +# Attention-related class names that indicate the model uses attention +ATTENTION_INDICATORS = { + "AttentionMixin", + "AttentionModuleMixin", +} + OPTIONAL_TESTERS = [ ("BitsAndBytesTesterMixin", "bnb"), ("QuantoTesterMixin", "quanto"), @@ -62,6 +67,17 @@ class ModelAnalyzer(ast.NodeVisitor): def __init__(self): self.model_classes = [] self.current_class = None + self.imports = set() + + def visit_Import(self, node: ast.Import): + for alias in node.names: + self.imports.add(alias.name.split(".")[-1]) + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom): + for alias in node.names: + self.imports.add(alias.name) + self.generic_visit(node) def visit_ClassDef(self, node: ast.ClassDef): base_names = [] @@ -164,7 +180,7 @@ class ModelAnalyzer(ast.NodeVisitor): return "" -def analyze_model_file(filepath: str) -> list[dict]: +def analyze_model_file(filepath: str) -> tuple[list[dict], set[str]]: with open(filepath) as f: source = f.read() @@ -172,10 +188,10 @@ def analyze_model_file(filepath: str) -> list[dict]: analyzer = ModelAnalyzer() analyzer.visit(tree) - return analyzer.model_classes + return analyzer.model_classes, analyzer.imports -def determine_testers(model_info: dict, include_optional: list[str]) -> list[str]: +def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]: testers = list(ALWAYS_INCLUDE_TESTERS) for base in model_info["bases"]: @@ -195,6 +211,10 @@ def determine_testers(model_info: dict, include_optional: list[str]) -> list[str if "ContextParallelTesterMixin" not in testers: testers.append("ContextParallelTesterMixin") + # Include AttentionTesterMixin if the model imports attention-related classes + if imports & ATTENTION_INDICATORS: + testers.append("AttentionTesterMixin") + for tester, flag in OPTIONAL_TESTERS: if flag in include_optional: if tester not in testers: @@ -335,9 +355,9 @@ def generate_test_class(model_name: str, config_class: str, tester: str) -> str: 
return "\n".join(lines) -def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str]) -> str: +def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str: model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "") - testers = determine_testers(model_info, include_optional) + testers = determine_testers(model_info, include_optional, imports) tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"}) lines = [ @@ -446,7 +466,7 @@ def main(): print(f"Error: File not found: {args.model_filepath}", file=sys.stderr) sys.exit(1) - model_classes = analyze_model_file(args.model_filepath) + model_classes, imports = analyze_model_file(args.model_filepath) if not model_classes: print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr) @@ -468,7 +488,7 @@ def main(): if "all" in include_optional: include_optional = [flag for _, flag in OPTIONAL_TESTERS] - generated_code = generate_test_file(model_info, args.model_filepath, include_optional) + generated_code = generate_test_file(model_info, args.model_filepath, include_optional, imports) if args.dry_run: print(generated_code)
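
Note (illustrative, not part of the patch): the utils/generate_model_tests.py changes above stop always emitting AttentionTesterMixin and instead include it only when the scanned model file imports an attention-related mixin, by collecting imported names in the AST visitor and intersecting them with ATTENTION_INDICATORS. Below is a minimal standalone sketch of that detection logic; ImportCollector, needs_attention_tester, and the sample string are hypothetical names for illustration (the patch itself uses ModelAnalyzer and determine_testers).

    # Sketch only: mirrors how the patch's visit_Import / visit_ImportFrom hooks
    # gather imported names and how determine_testers() checks them against
    # ATTENTION_INDICATORS. Not the actual generate_model_tests.py code.
    import ast

    ATTENTION_INDICATORS = {"AttentionMixin", "AttentionModuleMixin"}


    class ImportCollector(ast.NodeVisitor):
        """Collect the trailing component of every name a module imports."""

        def __init__(self):
            self.imports = set()

        def visit_Import(self, node):
            for alias in node.names:
                self.imports.add(alias.name.split(".")[-1])
            self.generic_visit(node)

        def visit_ImportFrom(self, node):
            for alias in node.names:
                self.imports.add(alias.name)
            self.generic_visit(node)


    def needs_attention_tester(source: str) -> bool:
        """Return True if the module imports an attention-related mixin."""
        collector = ImportCollector()
        collector.visit(ast.parse(source))
        return bool(collector.imports & ATTENTION_INDICATORS)


    if __name__ == "__main__":
        sample = "from diffusers.models.attention import AttentionModuleMixin\n"
        # True -> the generated test file would include AttentionTesterMixin
        print(needs_attention_tester(sample))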