From d08e0bb545741c118f7f3eb5864164c733ea788e Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 15 Dec 2025 14:19:27 +0530 Subject: [PATCH] update --- tests/models/testing_utils/__init__.py | 8 +- tests/models/testing_utils/attention.py | 27 +- tests/models/testing_utils/common.py | 451 ++++++++++-------- tests/models/testing_utils/lora.py | 442 ++++++++++++++--- tests/models/testing_utils/memory.py | 57 ++- tests/models/testing_utils/single_file.py | 6 +- .../test_models_transformer_flux.py | 35 +- tests/testing_utils.py | 53 ++ 8 files changed, 793 insertions(+), 286 deletions(-) diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py index e72a3c928b..229179737a 100644 --- a/tests/models/testing_utils/__init__.py +++ b/tests/models/testing_utils/__init__.py @@ -1,8 +1,8 @@ from .attention import AttentionTesterMixin, ContextParallelTesterMixin -from .common import ModelTesterMixin +from .common import BaseModelTesterConfig, ModelTesterMixin from .compile import TorchCompileTesterMixin from .ip_adapter import IPAdapterTesterMixin -from .lora import LoraTesterMixin +from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin from .quantization import ( BitsAndBytesTesterMixin, @@ -17,14 +17,16 @@ from .training import TrainingTesterMixin __all__ = [ - "ContextParallelTesterMixin", "AttentionTesterMixin", + "BaseModelTesterConfig", "BitsAndBytesTesterMixin", + "ContextParallelTesterMixin", "CPUOffloadTesterMixin", "GGUFTesterMixin", "GroupOffloadTesterMixin", "IPAdapterTesterMixin", "LayerwiseCastingTesterMixin", + "LoraHotSwappingForModelTesterMixin", "LoraTesterMixin", "MemoryTesterMixin", "ModelOptTesterMixin", diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py index f794a7a0aa..45443046fb 100644 --- a/tests/models/testing_utils/attention.py +++ b/tests/models/testing_utils/attention.py @@ -25,7 +25,13 @@ from diffusers.models.attention_processor import ( AttnProcessor, ) -from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device +from ...testing_utils import ( + assert_tensors_close, + is_attention, + is_context_parallel, + require_torch_multi_accelerator, + torch_device, +) @is_attention @@ -89,8 +95,12 @@ class AttentionTesterMixin: output_after_fusion = output_after_fusion.to_tuple()[0] # Verify outputs match - assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), ( - "Output should not change after fusing projections" + assert_tensors_close( + output_before_fusion, + output_after_fusion, + atol=self.base_precision, + rtol=0, + msg="Output should not change after fusing projections", ) # Unfuse projections @@ -110,8 +120,12 @@ class AttentionTesterMixin: output_after_unfusion = output_after_unfusion.to_tuple()[0] # Verify outputs still match - assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), ( - "Output should match original after unfusing projections" + assert_tensors_close( + output_before_fusion, + output_after_unfusion, + atol=self.base_precision, + rtol=0, + msg="Output should match original after unfusing projections", ) def test_get_set_processor(self): @@ -238,9 +252,6 @@ class ContextParallelTesterMixin: if not torch.distributed.is_available(): pytest.skip("torch.distributed is not available.") - if not torch.cuda.is_available() or 
torch.cuda.device_count() < 2: - pytest.skip("Context parallel requires at least 2 CUDA devices.") - if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 9f4ae271f9..11c10c4557 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -15,8 +15,8 @@ import json import os -import tempfile from collections import defaultdict +from typing import Any, Dict, Optional, Type import pytest import torch @@ -26,7 +26,7 @@ from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, d from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator -from ...testing_utils import CaptureLogger, torch_device +from ...testing_utils import assert_tensors_close, torch_device def named_persistent_module_tensors( @@ -130,40 +130,144 @@ def check_device_map_is_respected(model, device_map): ) +class BaseModelTesterConfig: + """ + Base class defining the configuration interface for model testing. + + This class defines the contract that all model test classes must implement. + It provides a consistent interface for accessing model configuration, initialization + parameters, and test inputs across all testing mixins. + + Required properties (must be implemented by subclasses): + - model_class: The model class to test + + Optional properties (can be overridden, have sensible defaults): + - pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None) + - pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {}) + - output_shape: Expected output shape for output validation tests (default: None) + - base_precision: Default tolerance for floating point comparisons (default: 1e-3) + - model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7]) + + Required methods (must be implemented by subclasses): + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Example usage: + class MyModelTestConfig(BaseModelTesterConfig): + @property + def model_class(self): + return MyModel + + @property + def pretrained_model_name_or_path(self): + return "org/my-model" + + @property + def output_shape(self): + return (1, 3, 32, 32) + + def get_init_dict(self): + return {"in_channels": 3, "out_channels": 3} + + def get_dummy_inputs(self): + return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)} + + class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin): + pass + """ + + # ==================== Required Properties ==================== + + @property + def model_class(self) -> Type[nn.Module]: + """The model class to test. 
Must be implemented by subclasses.""" + raise NotImplementedError("Subclasses must implement the `model_class` property.") + + # ==================== Optional Properties ==================== + + @property + def pretrained_model_name_or_path(self) -> Optional[str]: + """Hub repository ID for the pretrained model (used for quantization and hub tests).""" + return None + + @property + def pretrained_model_kwargs(self) -> Dict[str, Any]: + """Additional kwargs to pass to from_pretrained (e.g., subfolder, variant).""" + return {} + + @property + def output_shape(self) -> Optional[tuple]: + """Expected output shape for output validation tests.""" + return None + + @property + def model_split_percents(self) -> list: + """Percentages for model parallelism tests.""" + return [0.5, 0.7] + + # ==================== Required Methods ==================== + + def get_init_dict(self) -> Dict[str, Any]: + """ + Returns dict of arguments to initialize the model. + + Returns: + Dict[str, Any]: Initialization arguments for the model constructor. + + Example: + return { + "in_channels": 3, + "out_channels": 3, + "sample_size": 32, + } + """ + raise NotImplementedError("Subclasses must implement `get_init_dict()`.") + + def get_dummy_inputs(self) -> Dict[str, Any]: + """ + Returns dict of inputs to pass to the model forward pass. + + Returns: + Dict[str, Any]: Input tensors/values for model.forward(). + + Example: + return { + "sample": torch.randn(1, 3, 32, 32, device=torch_device), + "timestep": torch.tensor([1], device=torch_device), + } + """ + raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.") + + class ModelTesterMixin: """ Base mixin class for model testing with common test methods. - Expected class attributes to be set by subclasses: + This mixin expects the test class to also inherit from BaseModelTesterConfig + (or implement its interface) which provides: - model_class: The model class to test - - main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states") - - base_precision: Default tolerance for floating point comparisons (default: 1e-3) - - Expected methods to be implemented by subclasses: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Example: + class MyModelTestConfig(BaseModelTesterConfig): + model_class = MyModel + def get_init_dict(self): ... + def get_dummy_inputs(self): ... + + class TestMyModel(MyModelTestConfig, ModelTesterMixin): + pass """ - model_class = None - base_precision = 1e-3 - model_split_percents = [0.5, 0.7] - - def get_init_dict(self): - raise NotImplementedError("get_init_dict must be implemented by subclasses. ") - - def get_dummy_inputs(self): - raise NotImplementedError("get_dummy_inputs must be implemented by subclasses. 
It should return inputs_dict.") - - def test_from_save_pretrained(self, expected_max_diff=5e-5): + def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=0): torch.manual_seed(0) model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname) - new_model.to(torch_device) + model.save_pretrained(tmp_path) + new_model = self.model_class.from_pretrained(tmp_path) + new_model.to(torch_device) # check if all parameters shape are the same for param_name in model.state_dict().keys(): @@ -184,28 +288,24 @@ class ModelTesterMixin: if isinstance(new_image, dict): new_image = new_image.to_tuple()[0] - max_diff = (image - new_image).abs().max().item() - assert max_diff <= expected_max_diff, ( - f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" - ) + assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") - def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): + def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0): model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname, variant="fp16") - new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16") + model.save_pretrained(tmp_path, variant="fp16") + new_model = self.model_class.from_pretrained(tmp_path, variant="fp16") - # non-variant cannot be loaded - with pytest.raises(OSError) as exc_info: - self.model_class.from_pretrained(tmpdirname) + # non-variant cannot be loaded + with pytest.raises(OSError) as exc_info: + self.model_class.from_pretrained(tmp_path) - # make sure that error message states what keys are missing - assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value) + # make sure that error message states what keys are missing + assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value) - new_model.to(torch_device) + new_model.to(torch_device) with torch.no_grad(): image = model(**self.get_dummy_inputs()) @@ -217,35 +317,27 @@ class ModelTesterMixin: if isinstance(new_image, dict): new_image = new_image.to_tuple()[0] - max_diff = (image - new_image).abs().max().item() - assert max_diff <= expected_max_diff, ( - f"Models give different forward passes. 
Max diff: {max_diff}, expected: {expected_max_diff}" - ) + assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") - def test_from_save_pretrained_dtype(self): + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) + def test_from_save_pretrained_dtype(self, tmp_path, dtype): model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() - for dtype in [torch.float32, torch.float16, torch.bfloat16]: - if torch_device == "mps" and dtype == torch.bfloat16: - continue - with tempfile.TemporaryDirectory() as tmpdirname: - model.to(dtype) - model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype) - assert new_model.dtype == dtype - if ( - hasattr(self.model_class, "_keep_in_fp32_modules") - and self.model_class._keep_in_fp32_modules is None - ): - # When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None - new_model = self.model_class.from_pretrained( - tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype - ) - assert new_model.dtype == dtype + if torch_device == "mps" and dtype == torch.bfloat16: + pytest.skip(reason=f"{dtype} is not supported on {torch_device}") - def test_determinism(self, expected_max_diff=1e-5): + model.to(dtype) + model.save_pretrained(tmp_path) + new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype) + assert new_model.dtype == dtype + if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None: + # When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None + new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype) + assert new_model.dtype == dtype + + def test_determinism(self, atol=1e-5, rtol=0): model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() @@ -259,18 +351,15 @@ class ModelTesterMixin: if isinstance(second, dict): second = second.to_tuple()[0] - # Remove NaN values and compute max difference + # Filter out NaN values before comparison first_flat = first.flatten() second_flat = second.flatten() - - # Filter out NaN values mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat)) first_filtered = first_flat[mask] second_filtered = second_flat[mask] - max_diff = torch.abs(first_filtered - second_filtered).max().item() - assert max_diff <= expected_max_diff, ( - f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}" + assert_tensors_close( + first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic" ) def test_output(self, expected_output_shape=None): @@ -310,13 +399,12 @@ class ModelTesterMixin: elif tuple_object is None: return else: - assert torch.allclose( - set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 - ), ( - "Tuple and dict output are not equal. Difference:" - f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" - f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." 
+ assert_tensors_close( + set_nan_tensor_to_zero(tuple_object), + set_nan_tensor_to_zero(dict_object), + atol=1e-5, + rtol=0, + msg="Tuple and dict output are not equal", ) model = self.model_class(**self.get_init_dict()) @@ -329,7 +417,7 @@ class ModelTesterMixin: recursive_check(outputs_tuple, outputs_dict) - def test_getattr_is_correct(self): + def test_getattr_is_correct(self, caplog): init_dict = self.get_init_dict() model = self.model_class(**init_dict) @@ -337,28 +425,26 @@ class ModelTesterMixin: model.dummy_attribute = 5 model.register_to_config(test_attribute=5) - logger = logging.get_logger("diffusers.models.modeling_utils") - # 30 for warning - logger.setLevel(30) - with CaptureLogger(logger) as cap_logger: + logger_name = "diffusers.models.modeling_utils" + with caplog.at_level(logging.WARNING, logger=logger_name): + caplog.clear() assert hasattr(model, "dummy_attribute") assert getattr(model, "dummy_attribute") == 5 assert model.dummy_attribute == 5 # no warning should be thrown - assert cap_logger.out == "" + assert caplog.text == "" - logger = logging.get_logger("diffusers.models.modeling_utils") - # 30 for warning - logger.setLevel(30) - with CaptureLogger(logger) as cap_logger: + with caplog.at_level(logging.WARNING, logger=logger_name): + caplog.clear() assert hasattr(model, "save_pretrained") fn = model.save_pretrained fn_1 = getattr(model, "save_pretrained") assert fn == fn_1 + # no warning should be thrown - assert cap_logger.out == "" + assert caplog.text == "" # warning should be thrown for config attributes accessed directly with pytest.warns(FutureWarning): @@ -399,32 +485,34 @@ class ModelTesterMixin: torch_device not in ["cuda", "xpu"], reason="float16 and bfloat16 can only be use for inference with an accelerator", ) - def test_from_save_pretrained_float16_bfloat16(self): + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): model = self.model_class(**self.get_init_dict()) model.to(torch_device) fp32_modules = model._keep_in_fp32_modules - with tempfile.TemporaryDirectory() as tmp_dir: - for torch_dtype in [torch.bfloat16, torch.float16]: - model.to(torch_dtype).save_pretrained(tmp_dir) - model_loaded = self.model_class.from_pretrained(tmp_dir, torch_dtype=torch_dtype).to(torch_device) + model.to(dtype).save_pretrained(tmp_path) + model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device) - for name, param in model_loaded.named_parameters(): - if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules): - assert param.data.dtype == torch.float32 - else: - assert param.data.dtype == torch_dtype + for name, param in model_loaded.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules): + assert param.data.dtype == torch.float32 + else: + assert param.data.dtype == dtype - with torch.no_grad(): - output = model(**self.get_dummy_inputs()) - output_loaded = model_loaded(**self.get_dummy_inputs()) + with torch.no_grad(): + output = model(**self.get_dummy_inputs()) + if isinstance(output, dict): + output = output.to_tuple()[0] - assert torch.allclose(output, output_loaded, atol=1e-4), ( - f"Loaded model output differs for {torch_dtype}" - ) + output_loaded = model_loaded(**self.get_dummy_inputs()) + if isinstance(output_loaded, dict): + output_loaded = output_loaded.to_tuple()[0] + + assert_tensors_close(output, output_loaded, atol=1e-4, 
rtol=0, msg=f"Loaded model output differs for {dtype}") @require_accelerator - def test_sharded_checkpoints(self): + def test_sharded_checkpoints(self, tmp_path): torch.manual_seed(0) config = self.get_init_dict() inputs_dict = self.get_dummy_inputs() @@ -435,30 +523,30 @@ class ModelTesterMixin: model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small - with tempfile.TemporaryDirectory() as tmp_dir: - model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") - assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist" - # Check if the right number of shards exists - expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) - actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - assert actual_num_shards == expected_num_shards, ( - f"Expected {expected_num_shards} shards, got {actual_num_shards}" - ) + model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB") + assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist" - new_model = self.model_class.from_pretrained(tmp_dir).eval() - new_model = new_model.to(torch_device) + # Check if the right number of shards exists + expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)) + actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")]) + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) - torch.manual_seed(0) - inputs_dict_new = self.get_dummy_inputs() - new_output = new_model(**inputs_dict_new) + new_model = self.model_class.from_pretrained(tmp_path).eval() + new_model = new_model.to(torch_device) - assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( - "Output should match after sharded save/load" - ) + torch.manual_seed(0) + inputs_dict_new = self.get_dummy_inputs() + new_output = new_model(**inputs_dict_new) + + assert_tensors_close( + base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after sharded save/load" + ) @require_accelerator - def test_sharded_checkpoints_with_variant(self): + def test_sharded_checkpoints_with_variant(self, tmp_path): torch.manual_seed(0) config = self.get_init_dict() inputs_dict = self.get_dummy_inputs() @@ -470,35 +558,33 @@ class ModelTesterMixin: model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small variant = "fp16" - with tempfile.TemporaryDirectory() as tmp_dir: - model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant) - index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) - assert os.path.exists(os.path.join(tmp_dir, index_filename)), ( - f"Variant index file {index_filename} should exist" - ) + model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant) - # Check if the right number of shards exists - expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename)) - actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - assert actual_num_shards == expected_num_shards, ( - f"Expected {expected_num_shards} shards, got {actual_num_shards}" - ) + index_filename = 
_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + assert os.path.exists(os.path.join(tmp_path, index_filename)), ( + f"Variant index file {index_filename} should exist" + ) - new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval() - new_model = new_model.to(torch_device) + # Check if the right number of shards exists + expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename)) + actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")]) + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) - torch.manual_seed(0) - inputs_dict_new = self.get_dummy_inputs() - new_output = new_model(**inputs_dict_new) + new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval() + new_model = new_model.to(torch_device) - assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( - "Output should match after variant sharded save/load" - ) + torch.manual_seed(0) + inputs_dict_new = self.get_dummy_inputs() + new_output = new_model(**inputs_dict_new) - def test_sharded_checkpoints_with_parallel_loading(self): - import time + assert_tensors_close( + base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load" + ) + def test_sharded_checkpoints_with_parallel_loading(self, tmp_path): from diffusers.utils import constants torch.manual_seed(0) @@ -517,47 +603,37 @@ class ModelTesterMixin: original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None) try: - with tempfile.TemporaryDirectory() as tmp_dir: - model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") - assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist" + model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB") + assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist" - # Check if the right number of shards exists - expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) - actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - assert actual_num_shards == expected_num_shards, ( - f"Expected {expected_num_shards} shards, got {actual_num_shards}" - ) + # Check if the right number of shards exists + expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)) + actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")]) + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) - # Load without parallel loading - constants.HF_ENABLE_PARALLEL_LOADING = False - start_time = time.time() - model_sequential = self.model_class.from_pretrained(tmp_dir).eval() - sequential_load_time = time.time() - start_time - model_sequential = model_sequential.to(torch_device) + # Load without parallel loading + constants.HF_ENABLE_PARALLEL_LOADING = False + model_sequential = self.model_class.from_pretrained(tmp_path).eval() + model_sequential = model_sequential.to(torch_device) - torch.manual_seed(0) + # Load with parallel loading + constants.HF_ENABLE_PARALLEL_LOADING = True + constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2 - # Load with parallel loading - constants.HF_ENABLE_PARALLEL_LOADING = True - constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2 + torch.manual_seed(0) + 
model_parallel = self.model_class.from_pretrained(tmp_path).eval() + model_parallel = model_parallel.to(torch_device) - start_time = time.time() - model_parallel = self.model_class.from_pretrained(tmp_dir).eval() - parallel_load_time = time.time() - start_time - model_parallel = model_parallel.to(torch_device) + torch.manual_seed(0) + inputs_dict_parallel = self.get_dummy_inputs() + output_parallel = model_parallel(**inputs_dict_parallel) - torch.manual_seed(0) - inputs_dict_parallel = self.get_dummy_inputs() - output_parallel = model_parallel(**inputs_dict_parallel) + assert_tensors_close( + base_output[0], output_parallel[0], atol=1e-5, rtol=0, msg="Output should match with parallel loading" + ) - assert torch.allclose(base_output[0], output_parallel[0], atol=1e-5), ( - "Output should match with parallel loading" - ) - - # Verify parallel loading is faster or at least not significantly slower - assert parallel_load_time < sequential_load_time, ( - f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s" - ) finally: # Restore original values constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading @@ -565,7 +641,7 @@ class ModelTesterMixin: constants.HF_PARALLEL_WORKERS = original_parallel_workers @require_torch_multi_accelerator - def test_model_parallelism(self): + def test_model_parallelism(self, tmp_path): if self.model_class._no_split_modules is None: pytest.skip("Test not supported for this model as `_no_split_modules` is not set.") @@ -581,20 +657,19 @@ class ModelTesterMixin: model_size = compute_module_sizes(model)[""] max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] - with tempfile.TemporaryDirectory() as tmp_dir: - model.cpu().save_pretrained(tmp_dir) + model.cpu().save_pretrained(tmp_path) - for max_size in max_gpu_sizes: - max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2} - new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) - # Making sure part of the model will be on GPU 0 and GPU 1 - assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs" + for max_size in max_gpu_sizes: + max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2} + new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory) + # Making sure part of the model will be on GPU 0 and GPU 1 + assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs" - check_device_map_is_respected(new_model, new_model.hf_device_map) + check_device_map_is_respected(new_model, new_model.hf_device_map) - torch.manual_seed(0) - new_output = new_model(**inputs_dict) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) - assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( - "Output should match with model parallelism" - ) + assert_tensors_close( + base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with model parallelism" + ) diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py index 6777c164f2..b790e3ea26 100644 --- a/tests/models/testing_utils/lora.py +++ b/tests/models/testing_utils/lora.py @@ -13,17 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import gc import json import os -import tempfile +import re import pytest import safetensors.torch import torch +import torch.nn as nn +from diffusers.utils.import_utils import is_peft_available from diffusers.utils.testing_utils import check_if_dicts_are_equal -from ...testing_utils import is_lora, require_peft_backend, torch_device +from ...testing_utils import ( + assert_tensors_close, + backend_empty_cache, + is_lora, + is_torch_compile, + require_peft_backend, + require_peft_version_greater, + require_torch_accelerator, + require_torch_version_greater, + torch_device, +) + + +if is_peft_available(): + from diffusers.loaders.peft import PeftAdapterMixin def check_if_lora_correctly_set(model) -> bool: @@ -67,7 +84,7 @@ class LoraTesterMixin: if not issubclass(self.model_class, PeftAdapterMixin): pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).") - def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False): + def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False): from peft import LoraConfig from peft.utils import get_peft_model_state_dict @@ -95,26 +112,25 @@ class LoraTesterMixin: "Output should differ with LoRA enabled" ) - with tempfile.TemporaryDirectory() as tmpdir: - model.save_lora_adapter(tmpdir) - assert os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")), ( - "LoRA weights file not created" - ) + model.save_lora_adapter(tmp_path) + assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), ( + "LoRA weights file not created" + ) - state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")) - model.unload_lora() - assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" + model.unload_lora() + assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" - model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) - state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0") + model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True) + state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0") - for k in state_dict_loaded: - loaded_v = state_dict_loaded[k] - retrieved_v = state_dict_retrieved[k].to(loaded_v.device) - assert torch.allclose(loaded_v, retrieved_v), f"Mismatch in LoRA weight {k}" + for k in state_dict_loaded: + loaded_v = state_dict_loaded[k] + retrieved_v = state_dict_retrieved[k].to(loaded_v.device) + assert_tensors_close(loaded_v, retrieved_v, atol=1e-5, rtol=0, msg=f"Mismatch in LoRA weight {k}") - assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload" + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload" torch.manual_seed(0) outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0] @@ -122,11 +138,15 @@ class LoraTesterMixin: assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), ( "Output should differ with LoRA enabled" ) - assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), ( - "Outputs should match before and after save/load" + assert_tensors_close( + outputs_with_lora, + outputs_with_lora_2, + atol=1e-4, + rtol=1e-4, + msg="Outputs should match before and after save/load", ) - def test_lora_wrong_adapter_name_raises_error(self): + def 
test_lora_wrong_adapter_name_raises_error(self, tmp_path): from peft import LoraConfig init_dict = self.get_init_dict() @@ -142,14 +162,13 @@ class LoraTesterMixin: model.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" - with tempfile.TemporaryDirectory() as tmpdir: - wrong_name = "foo" - with pytest.raises(ValueError) as exc_info: - model.save_lora_adapter(tmpdir, adapter_name=wrong_name) + wrong_name = "foo" + with pytest.raises(ValueError) as exc_info: + model.save_lora_adapter(tmp_path, adapter_name=wrong_name) - assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value) + assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value) - def test_lora_adapter_metadata_is_loaded_correctly(self, rank=4, lora_alpha=4, use_dora=False): + def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False): from peft import LoraConfig init_dict = self.get_init_dict() @@ -166,19 +185,18 @@ class LoraTesterMixin: metadata = model.peft_config["default"].to_dict() assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" - with tempfile.TemporaryDirectory() as tmpdir: - model.save_lora_adapter(tmpdir) - model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") - assert os.path.isfile(model_file), "LoRA weights file not created" + model.save_lora_adapter(tmp_path) + model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors") + assert os.path.isfile(model_file), "LoRA weights file not created" - model.unload_lora() - assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" + model.unload_lora() + assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" - model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) - parsed_metadata = model.peft_config["default_0"].to_dict() - check_if_dicts_are_equal(metadata, parsed_metadata) + model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True) + parsed_metadata = model.peft_config["default_0"].to_dict() + check_if_dicts_are_equal(metadata, parsed_metadata) - def test_lora_adapter_wrong_metadata_raises_error(self): + def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path): from peft import LoraConfig from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY @@ -196,25 +214,337 @@ class LoraTesterMixin: model.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" - with tempfile.TemporaryDirectory() as tmpdir: - model.save_lora_adapter(tmpdir) - model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") - assert os.path.isfile(model_file), "LoRA weights file not created" + model.save_lora_adapter(tmp_path) + model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors") + assert os.path.isfile(model_file), "LoRA weights file not created" - # Perturb the metadata in the state dict - loaded_state_dict = safetensors.torch.load_file(model_file) - metadata = {"format": "pt"} - lora_adapter_metadata = denoiser_lora_config.to_dict() - lora_adapter_metadata.update({"foo": 1, "bar": 2}) - for key, value in lora_adapter_metadata.items(): - if isinstance(value, set): - lora_adapter_metadata[key] = list(value) - metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) - safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata) + # Perturb the metadata in the state dict + loaded_state_dict = 
safetensors.torch.load_file(model_file) + metadata = {"format": "pt"} + lora_adapter_metadata = denoiser_lora_config.to_dict() + lora_adapter_metadata.update({"foo": 1, "bar": 2}) + for key, value in lora_adapter_metadata.items(): + if isinstance(value, set): + lora_adapter_metadata[key] = list(value) + metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata) - model.unload_lora() - assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" + model.unload_lora() + assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" - with pytest.raises(TypeError) as exc_info: - model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) - assert "`LoraConfig` class could not be instantiated" in str(exc_info.value) + with pytest.raises(TypeError) as exc_info: + model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True) + assert "`LoraConfig` class could not be instantiated" in str(exc_info.value) + + +@is_lora +@is_torch_compile +@require_peft_backend +@require_peft_version_greater("0.14.0") +@require_torch_version_greater("2.7.1") +@require_torch_accelerator +class LoraHotSwappingForModelTesterMixin: + """ + Mixin class for testing LoRA hot swapping functionality on models. + + Test that hotswapping does not result in recompilation on the model directly. + We're not extensively testing the hotswapping functionality since it is implemented in PEFT + and is extensively tested there. The goal of this test is specifically to ensure that + hotswapping with diffusers does not require recompilation. + + See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252 + for the analogous PEFT test. + + Expected class attributes to be set by subclasses: + - model_class: The model class to test + - different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic compilation tests + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest marks: lora, torch_compile + Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests + """ + + different_shapes_for_compilation = None + + def setup_method(self): + if not issubclass(self.model_class, PeftAdapterMixin): + pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).") + + def teardown_method(self): + # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model, + # there will be recompilation errors, as torch caches the model when run in the same process. + torch.compiler.reset() + gc.collect() + backend_empty_cache(torch_device) + + def _get_lora_config(self, lora_rank, lora_alpha, target_modules): + from peft import LoraConfig + + lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + target_modules=target_modules, + init_lora_weights=False, + use_dora=False, + ) + return lora_config + + def _get_linear_module_name_other_than_attn(self, model): + linear_names = [ + name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name + ] + return linear_names[0] + + def _check_model_hotswap(self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None): + """ + Check that hotswapping works on a model. 
+ + Steps: + - create 2 LoRA adapters and save them + - load the first adapter + - hotswap the second adapter + - check that the outputs are correct + - optionally compile the model + - optionally check if recompilations happen on different shapes + + Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would + fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is + fine. + """ + different_shapes = self.different_shapes_for_compilation + # create 2 adapters with different ranks and alphas + torch.manual_seed(0) + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + alpha0, alpha1 = rank0, rank1 + max_rank = max([rank0, rank1]) + if target_modules1 is None: + target_modules1 = target_modules0[:] + lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0) + lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1) + + model.add_adapter(lora_config0, adapter_name="adapter0") + with torch.inference_mode(): + torch.manual_seed(0) + output0_before = model(**inputs_dict)["sample"] + + model.add_adapter(lora_config1, adapter_name="adapter1") + model.set_adapter("adapter1") + with torch.inference_mode(): + torch.manual_seed(0) + output1_before = model(**inputs_dict)["sample"] + + # sanity checks: + tol = 5e-3 + assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol) + assert not (output0_before == 0).all() + assert not (output1_before == 0).all() + + # save the adapter checkpoints + model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0") + model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1") + del model + + # load the first adapter + torch.manual_seed(0) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + if do_compile or (rank0 != rank1): + # no need to prepare if the model is not compiled or if the ranks are identical + model.enable_lora_hotswap(target_rank=max_rank) + + file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors") + file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors") + model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None) + + if do_compile: + model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None) + + with torch.inference_mode(): + # additionally check if dynamic compilation works. 
+ if different_shapes is not None: + for height, width in different_shapes: + new_inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**new_inputs_dict) + else: + output0_after = model(**inputs_dict)["sample"] + assert_tensors_close( + output0_before, output0_after, atol=tol, rtol=tol, msg="Output mismatch after loading adapter0" + ) + + # hotswap the 2nd adapter + model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None) + + # we need to call forward to potentially trigger recompilation + with torch.inference_mode(): + if different_shapes is not None: + for height, width in different_shapes: + new_inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**new_inputs_dict) + else: + output1_after = model(**inputs_dict)["sample"] + assert_tensors_close( + output1_before, + output1_after, + atol=tol, + rtol=tol, + msg="Output mismatch after hotswapping to adapter1", + ) + + # check error when not passing valid adapter name + name = "does-not-exist" + msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name" + with pytest.raises(ValueError, match=re.escape(msg)): + model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_model(self, tmp_path, rank0, rank1): + self._check_model_hotswap( + tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"] + ) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1): + # It's important to add this context to raise an error on recompilation + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache(): + self._check_model_hotswap( + tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules + ) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1): + if "unet" not in self.model_class.__name__.lower(): + pytest.skip("Test only applies to UNet.") + + # It's important to add this context to raise an error on recompilation + target_modules = ["conv", "conv1", "conv2"] + with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache(): + self._check_model_hotswap( + tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules + ) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1): + if "unet" not in self.model_class.__name__.lower(): + pytest.skip("Test only applies to UNet.") + + # It's important to add this context to raise an error on recompilation + target_modules = ["to_q", "conv"] + with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache(): + self._check_model_hotswap( + tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules + ) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1): + # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping + # with `torch.compile()` for 
models that have both linear and conv layers. In this test, we check + # if we can target a linear layer from the transformer blocks and another linear layer from non-attention + # block. + target_modules = ["to_q"] + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + + target_modules.append(self._get_linear_module_name_other_than_attn(model)) + del model + + # It's important to add this context to raise an error on recompilation + with torch._dynamo.config.patch(error_on_recompile=True): + self._check_model_hotswap( + tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules + ) + + def test_enable_lora_hotswap_called_after_adapter_added_raises(self): + # ensure that enable_lora_hotswap is called before loading the first adapter + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + + msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.") + with pytest.raises(RuntimeError, match=msg): + model.enable_lora_hotswap(target_rank=32) + + def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog): + # ensure that enable_lora_hotswap is called before loading the first adapter + import logging + + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + msg = ( + "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." + ) + with caplog.at_level(logging.WARNING): + model.enable_lora_hotswap(target_rank=32, check_compiled="warn") + assert any(msg in record.message for record in caplog.records) + + def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog): + # check possibility to ignore the error/warning + import logging + + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + with caplog.at_level(logging.WARNING): + model.enable_lora_hotswap(target_rank=32, check_compiled="ignore") + assert len(caplog.records) == 0 + + def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): + # check that wrong argument value raises an error + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.") + with pytest.raises(ValueError, match=msg): + model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument") + + def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog): + # check the error and log + import logging + + # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers + target_modules0 = ["to_q"] + target_modules1 = ["to_q", "to_k"] + with pytest.raises(RuntimeError): # peft raises RuntimeError + with caplog.at_level(logging.ERROR): + self._check_model_hotswap( + tmp_path, + do_compile=True, + rank0=8, + rank1=8, + target_modules0=target_modules0, + target_modules1=target_modules1, + ) + assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records) + + 
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + @require_torch_version_greater("2.7.1") + def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1): + different_shapes_for_compilation = self.different_shapes_for_compilation + if different_shapes_for_compilation is None: + pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") + # Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic + # variable to represent input sizes that are the same. For more details, + # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790). + torch.fx.experimental._config.use_duck_shape = False + + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + with torch._dynamo.config.patch(error_on_recompile=True): + self._check_model_hotswap( + tmp_path, + do_compile=True, + rank0=rank0, + rank1=rank1, + target_modules0=target_modules, + ) diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py index 6cdc72b004..ebd76656f0 100644 --- a/tests/models/testing_utils/memory.py +++ b/tests/models/testing_utils/memory.py @@ -27,6 +27,7 @@ from diffusers.utils.testing_utils import _check_safetensors_serialization from diffusers.utils.torch_utils import get_torch_cuda_device_capability from ...testing_utils import ( + assert_tensors_close, backend_empty_cache, backend_max_memory_allocated, backend_reset_peak_memory_stats, @@ -122,8 +123,8 @@ class CPUOffloadTesterMixin: torch.manual_seed(0) new_output = new_model(**inputs_dict) - assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( - "Output should match with CPU offloading" + assert_tensors_close( + base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading" ) @require_offload_support @@ -156,7 +157,9 @@ class CPUOffloadTesterMixin: torch.manual_seed(0) new_output = new_model(**inputs_dict) - assert torch.allclose(base_output[0], new_output[0], atol=1e-5), "Output should match with disk offloading" + assert_tensors_close( + base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading" + ) @require_offload_support def test_disk_offload_with_safetensors(self): @@ -183,8 +186,12 @@ class CPUOffloadTesterMixin: torch.manual_seed(0) new_output = new_model(**inputs_dict) - assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( - "Output should match with disk offloading (safetensors)" + assert_tensors_close( + base_output[0], + new_output[0], + atol=1e-5, + rtol=0, + msg="Output should match with disk offloading (safetensors)", ) @@ -247,17 +254,33 @@ class GroupOffloadTesterMixin: ) output_with_group_offloading4 = run_forward(model) - assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5), ( - "Output should match with block-level offloading" + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading1, + atol=1e-5, + rtol=0, + msg="Output should match with block-level offloading", ) - assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5), ( - "Output should match with non-blocking block-level offloading" + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading2, + atol=1e-5, + rtol=0, + msg="Output should match with non-blocking block-level offloading", ) - assert torch.allclose(output_without_group_offloading, 
output_with_group_offloading3, atol=1e-5), ( - "Output should match with leaf-level offloading" + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading3, + atol=1e-5, + rtol=0, + msg="Output should match with leaf-level offloading", ) - assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5), ( - "Output should match with leaf-level offloading with stream" + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading4, + atol=1e-5, + rtol=0, + msg="Output should match with leaf-level offloading with stream", ) @require_group_offload_support @@ -345,8 +368,12 @@ class GroupOffloadTesterMixin: raise ValueError(f"Following files are missing: {', '.join(missing_files)}") output_with_group_offloading = _run_forward(model, inputs_dict) - assert torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol), ( - "Output should match with disk-based group offloading" + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading, + atol=atol, + rtol=0, + msg="Output should match with disk-based group offloading", ) diff --git a/tests/models/testing_utils/single_file.py b/tests/models/testing_utils/single_file.py index 67d770849f..992e6dd8d9 100644 --- a/tests/models/testing_utils/single_file.py +++ b/tests/models/testing_utils/single_file.py @@ -22,6 +22,7 @@ from huggingface_hub import hf_hub_download, snapshot_download from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from ...testing_utils import ( + assert_tensors_close, backend_empty_cache, is_single_file, nightly, @@ -146,9 +147,8 @@ class SingleFileTesterMixin: f"pretrained {param.shape} vs single file {param_single_file.shape}" ) - assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), ( - f"Parameter values differ for {key}: " - f"max difference {torch.max(torch.abs(param - param_single_file)).item()}" + assert_tensors_close( + param, param_single_file, atol=1e-5, rtol=1e-5, msg=f"Parameter values differ for {key}" ) def test_single_file_loading_local_files_only(self): diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 3019308831..e0b38eda7f 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -23,13 +23,14 @@ from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProc from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin from ..testing_utils import ( AttentionTesterMixin, + BaseModelTesterConfig, BitsAndBytesTesterMixin, ContextParallelTesterMixin, GGUFTesterMixin, IPAdapterTesterMixin, + LoraHotSwappingForModelTesterMixin, LoraTesterMixin, MemoryTesterMixin, ModelOptTesterMixin, @@ -94,10 +95,26 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]: return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict} -class FluxTransformerTesterConfig: - model_class = FluxTransformer2DModel - pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe" - pretrained_model_kwargs = {"subfolder": "transformer"} +class FluxTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return FluxTransformer2DModel + + @property + def 
pretrained_model_name_or_path(self): + return "hf-internal-testing/tiny-flux-pipe" + + @property + def pretrained_model_kwargs(self): + return {"subfolder": "transformer"} + + @property + def output_shape(self) -> tuple[int, int]: + return (16, 4) + + @property + def input_shape(self) -> tuple[int, int]: + return (16, 4) @property def generator(self): @@ -136,14 +153,6 @@ class FluxTransformerTesterConfig: "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), } - @property - def input_shape(self) -> tuple[int, int]: - return (16, 4) - - @property - def output_shape(self) -> tuple[int, int]: - return (16, 4) - class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin): def test_deprecated_inputs_img_txt_ids_3d(self): diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 9860d64dc1..4c97bbc14c 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -131,6 +131,59 @@ def torch_all_close(a, b, *args, **kwargs): return True +def assert_tensors_close( + actual: "torch.Tensor", + expected: "torch.Tensor", + atol: float = 1e-5, + rtol: float = 1e-5, + msg: str = "", +) -> None: + """ + Assert that two tensors are close within tolerance. + + Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected| + Provides concise, actionable error messages without dumping full tensors. + + Args: + actual: The actual tensor from the computation. + expected: The expected tensor to compare against. + atol: Absolute tolerance. + rtol: Relative tolerance. + msg: Optional message prefix for the assertion error. + + Raises: + AssertionError: If tensors have different shapes or values exceed tolerance. + + Example: + >>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass") + """ + if not is_torch_available(): + raise ValueError("PyTorch needs to be installed to use this function.") + + if actual.shape != expected.shape: + raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}") + + if not torch.allclose(actual, expected, atol=atol, rtol=rtol): + abs_diff = (actual - expected).abs() + max_diff = abs_diff.max().item() + + flat_idx = abs_diff.argmax().item() + max_idx = tuple(torch.unravel_index(torch.tensor(flat_idx), actual.shape).tolist()) + + threshold = atol + rtol * expected.abs() + mismatched = (abs_diff > threshold).sum().item() + total = actual.numel() + + raise AssertionError( + f"{msg}\n" + f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n" + f" Max diff: {max_diff:.6e} at index {max_idx}\n" + f" Actual: {actual.flatten()[flat_idx].item():.6e}\n" + f" Expected: {expected.flatten()[flat_idx].item():.6e}\n" + f" atol: {atol:.6e}, rtol: {rtol:.6e}" + ) + + def numpy_cosine_similarity_distance(a, b): similarity = np.dot(a, b) / (norm(a) * norm(b)) distance = 1.0 - similarity.mean()
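# ---------------------------------------------------------------------------
# Illustrative usage sketch (annotation only, not part of the patch itself).
# A minimal example of how the pieces introduced above are intended to compose:
# `BaseModelTesterConfig` supplies the model/config/inputs contract, the test
# mixins consume it, and `assert_tensors_close` replaces bare
# `torch.allclose(...)` asserts. The absolute import paths mirror this patch's
# layout (the test modules themselves use relative imports); `MyTinyModel` and
# the shapes below are hypothetical placeholders.

import torch

from tests.testing_utils import assert_tensors_close, torch_device
from tests.models.testing_utils import BaseModelTesterConfig, ModelTesterMixin


class MyTinyModelTesterConfig(BaseModelTesterConfig):
    @property
    def model_class(self):
        return MyTinyModel  # hypothetical model class under test

    def get_init_dict(self):
        return {"in_channels": 3, "out_channels": 3}

    def get_dummy_inputs(self):
        return {"sample": torch.randn(1, 3, 8, 8, device=torch_device)}


class TestMyTinyModel(MyTinyModelTesterConfig, ModelTesterMixin):
    # ModelTesterMixin contributes test_from_save_pretrained, test_determinism,
    # test_sharded_checkpoints, etc. Custom tests compare tensors through the
    # shared helper so failures report max diff and mismatch counts consistently:
    def test_matches_reference(self):
        output = torch.zeros(2, 4)
        reference = torch.zeros(2, 4)
        assert_tensors_close(output, reference, atol=1e-5, rtol=0, msg="Reference check")


# LoRA hotswap flow exercised by LoraHotSwappingForModelTesterMixin, assuming a
# PeftAdapterMixin model and two adapter checkpoints already saved under the
# placeholder paths ckpt/0 and ckpt/1: enable hotswap before loading the first
# adapter, optionally compile, then swap weights in place without recompiling.
def hotswap_example(model):
    model.enable_lora_hotswap(target_rank=8)
    model.load_lora_adapter("ckpt/0/pytorch_lora_weights.safetensors", adapter_name="adapter0", prefix=None)
    model = torch.compile(model, mode="reduce-overhead")
    model.load_lora_adapter("ckpt/1/pytorch_lora_weights.safetensors", adapter_name="adapter0", hotswap=True, prefix=None)
    return model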