From fe451c367b7191e790e3cef6571d0cfdfd53d68e Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 11 Dec 2025 11:04:47 +0530 Subject: [PATCH] update --- tests/conftest.py | 1 + tests/models/test_modeling_common.py | 24 +- tests/models/testing_utils/__init__.py | 3 +- tests/models/testing_utils/attention.py | 98 ++++- tests/models/testing_utils/common.py | 205 ++++++--- tests/models/testing_utils/hub.py | 20 +- tests/models/testing_utils/ip_adapter.py | 95 ++-- tests/models/testing_utils/lora.py | 24 +- tests/models/testing_utils/memory.py | 84 ++-- tests/models/testing_utils/quantization.py | 55 +-- tests/models/testing_utils/single_file.py | 6 +- tests/models/testing_utils/training.py | 15 +- .../test_models_transformer_flux.py | 410 +++++++++++------- .../test_models_transformer_flux_.py | 330 -------------- tests/quantization/gguf/test_gguf.py | 6 +- tests/testing_utils.py | 9 + 16 files changed, 679 insertions(+), 706 deletions(-) delete mode 100644 tests/models/transformers/test_models_transformer_flux_.py diff --git a/tests/conftest.py b/tests/conftest.py index 3744de27f3..0f7b9ef984 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,6 +46,7 @@ def pytest_configure(config): config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality") config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality") config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality") + config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality") def pytest_addoption(parser): diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 520bd8f871..6f4c3d544b 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -317,9 +317,9 @@ class ModelUtilsTest(unittest.TestCase): repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True ) - assert all( - torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters()) - ), "Model parameters don't match!" + assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), ( + "Model parameters don't match!" + ) # Remove a shard file cached_shard_file = try_to_load_from_cache( @@ -335,9 +335,9 @@ class ModelUtilsTest(unittest.TestCase): # Verify error mentions the missing shard error_msg = str(context.exception) - assert ( - cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg - ), f"Expected error about missing shard, got: {error_msg}" + assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, ( + f"Expected error about missing shard, got: {error_msg}" + ) @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners") @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.") @@ -354,9 +354,9 @@ class ModelUtilsTest(unittest.TestCase): ) download_requests = [r.method for r in m.request_history] - assert ( - download_requests.count("HEAD") == 3 - ), "3 HEAD requests one for config, one for model, and one for shard index file." + assert download_requests.count("HEAD") == 3, ( + "3 HEAD requests one for config, one for model, and one for shard index file." 
+ ) assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model" with requests_mock.mock(real_http=True) as m: @@ -368,9 +368,9 @@ class ModelUtilsTest(unittest.TestCase): ) cache_requests = [r.method for r in m.request_history] - assert ( - "HEAD" == cache_requests[0] and len(cache_requests) == 2 - ), "We should call only `model_info` to check for commit hash and knowing if shard index is present." + assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, ( + "We should call only `model_info` to check for commit hash and knowing if shard index is present." + ) def test_weight_overwrite(self): with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context: diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py index 7955982ca9..e72a3c928b 100644 --- a/tests/models/testing_utils/__init__.py +++ b/tests/models/testing_utils/__init__.py @@ -1,4 +1,4 @@ -from .attention import AttentionTesterMixin +from .attention import AttentionTesterMixin, ContextParallelTesterMixin from .common import ModelTesterMixin from .compile import TorchCompileTesterMixin from .ip_adapter import IPAdapterTesterMixin @@ -17,6 +17,7 @@ from .training import TrainingTesterMixin __all__ = [ + "ContextParallelTesterMixin", "AttentionTesterMixin", "BitsAndBytesTesterMixin", "CPUOffloadTesterMixin", diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py index 22512c9458..f794a7a0aa 100644 --- a/tests/models/testing_utils/attention.py +++ b/tests/models/testing_utils/attention.py @@ -13,15 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest import torch +import torch.multiprocessing as mp +from diffusers.models._modeling_parallel import ContextParallelConfig from diffusers.models.attention import AttentionModuleMixin from diffusers.models.attention_processor import ( AttnProcessor, ) -from ...testing_utils import is_attention, torch_device +from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device @is_attention @@ -85,9 +89,9 @@ class AttentionTesterMixin: output_after_fusion = output_after_fusion.to_tuple()[0] # Verify outputs match - assert torch.allclose( - output_before_fusion, output_after_fusion, atol=self.base_precision - ), "Output should not change after fusing projections" + assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), ( + "Output should not change after fusing projections" + ) # Unfuse projections model.unfuse_qkv_projections() @@ -106,9 +110,9 @@ class AttentionTesterMixin: output_after_unfusion = output_after_unfusion.to_tuple()[0] # Verify outputs still match - assert torch.allclose( - output_before_fusion, output_after_unfusion, atol=self.base_precision - ), "Output should match original after unfusing projections" + assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), ( + "Output should match original after unfusing projections" + ) def test_get_set_processor(self): init_dict = self.get_init_dict() @@ -177,3 +181,83 @@ class AttentionTesterMixin: model.set_attn_processor(wrong_processors) assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch" + + +def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue): + try: + # Setup distributed 
environment
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+
+        torch.distributed.init_process_group(
+            backend="nccl",
+            init_method="env://",
+            world_size=world_size,
+            rank=rank,
+        )
+        torch.cuda.set_device(rank)
+        device = torch.device(f"cuda:{rank}")
+
+        model = model_class(**init_dict)
+        model.to(device)
+        model.eval()
+
+        inputs_on_device = {}
+        for key, value in inputs_dict.items():
+            if isinstance(value, torch.Tensor):
+                inputs_on_device[key] = value.to(device)
+            else:
+                inputs_on_device[key] = value
+
+        cp_config = ContextParallelConfig(**cp_dict)
+        model.enable_parallelism(config=cp_config)
+
+        with torch.no_grad():
+            output = model(**inputs_on_device)
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        if rank == 0:
+            result_queue.put(("success", output.shape))
+
+    except Exception as e:
+        if rank == 0:
+            result_queue.put(("error", str(e)))
+    finally:
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
+
+
+@is_context_parallel
+@require_torch_multi_accelerator
+class ContextParallelTesterMixin:
+    base_precision = 1e-3
+
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
+    def test_context_parallel_inference(self, cp_type):
+        if not torch.distributed.is_available():
+            pytest.skip("torch.distributed is not available.")
+
+        if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
+            pytest.skip("Context parallel requires at least 2 CUDA devices.")
+
+        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
+            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
+
+        world_size = 2
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
+        cp_dict = {cp_type: world_size}
+
+        ctx = mp.get_context("spawn")
+        result_queue = ctx.Queue()
+
+        mp.spawn(
+            _context_parallel_worker,
+            args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
+            nprocs=world_size,
+            join=True,
+        )
+
+        status, result = result_queue.get(timeout=60)
+        assert status == "success", f"Context parallel inference failed: {result}"
diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py
index 7ec8dbbd8b..9f4ae271f9 100644
--- a/tests/models/testing_utils/common.py
+++ b/tests/models/testing_utils/common.py
@@ -16,16 +16,45 @@
 import json
 import os
 import tempfile
+from collections import defaultdict
 
 import pytest
 import torch
 import torch.nn as nn
 from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
 
-from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant
+from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
 from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
 
-from ...testing_utils import torch_device
+from ...testing_utils import CaptureLogger, torch_device
+
+
+def named_persistent_module_tensors(
+    module: nn.Module,
+    recurse: bool = False,
+):
+    """
+    A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module we want the tensors on.
+        recurse (`bool`, *optional*, defaults to `False`):
+            Whether or not to go look in every submodule or just return the direct parameters and buffers.
+ """ + yield from module.named_parameters(recurse=recurse) + + for named_buffer in module.named_buffers(recurse=recurse): + name, _ = named_buffer + # Get parent by splitting on dots and traversing the model + parent = module + if "." in name: + parent_name = name.rsplit(".", 1)[0] + for part in parent_name.split("."): + parent = getattr(parent, part) + name = name.split(".")[-1] + if name not in parent._non_persistent_buffers_set: + yield named_buffer def compute_module_persistent_sizes( @@ -96,9 +125,9 @@ def check_device_map_is_respected(model, device_map): if param_device in ["cpu", "disk"]: assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}" else: - assert param.device == torch.device( - param_device - ), f"Expected device {param_device} for {param_name}, got {param.device}" + assert param.device == torch.device(param_device), ( + f"Expected device {param_device} for {param_name}, got {param.device}" + ) class ModelTesterMixin: @@ -123,9 +152,7 @@ class ModelTesterMixin: raise NotImplementedError("get_init_dict must be implemented by subclasses. ") def get_dummy_inputs(self): - raise NotImplementedError( - "get_dummy_inputs must be implemented by subclasses. " "It should return inputs_dict." - ) + raise NotImplementedError("get_dummy_inputs must be implemented by subclasses. It should return inputs_dict.") def test_from_save_pretrained(self, expected_max_diff=5e-5): torch.manual_seed(0) @@ -142,9 +169,9 @@ class ModelTesterMixin: for param_name in model.state_dict().keys(): param_1 = model.state_dict()[param_name] param_2 = new_model.state_dict()[param_name] - assert ( - param_1.shape == param_2.shape - ), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}" + assert param_1.shape == param_2.shape, ( + f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}" + ) with torch.no_grad(): image = model(**self.get_dummy_inputs()) @@ -158,9 +185,9 @@ class ModelTesterMixin: new_image = new_image.to_tuple()[0] max_diff = (image - new_image).abs().max().item() - assert ( - max_diff <= expected_max_diff - ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" + assert max_diff <= expected_max_diff, ( + f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" + ) def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): model = self.model_class(**self.get_init_dict()) @@ -191,9 +218,9 @@ class ModelTesterMixin: new_image = new_image.to_tuple()[0] max_diff = (image - new_image).abs().max().item() - assert ( - max_diff <= expected_max_diff - ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" + assert max_diff <= expected_max_diff, ( + f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" + ) def test_from_save_pretrained_dtype(self): model = self.model_class(**self.get_init_dict()) @@ -242,9 +269,9 @@ class ModelTesterMixin: second_filtered = second_flat[mask] max_diff = torch.abs(first_filtered - second_filtered).max().item() - assert ( - max_diff <= expected_max_diff - ), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}" + assert max_diff <= expected_max_diff, ( + f"Model outputs are not deterministic. 
Max diff: {max_diff}, expected: {expected_max_diff}"
+        )
 
     def test_output(self, expected_output_shape=None):
         model = self.model_class(**self.get_init_dict())
@@ -259,9 +286,9 @@ class ModelTesterMixin:
             output = output.to_tuple()[0]
 
         assert output is not None, "Model output is None"
-        assert (
-            output[0].shape == expected_output_shape or self.output_shape
-        ), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
+        expected_shape = expected_output_shape or self.output_shape
+        assert output[0].shape == expected_shape, (
+            f"Output shape does not match expected. Expected {expected_shape}, got {output[0].shape}"
+        )
 
     def test_outputs_equivalence(self):
         def set_nan_tensor_to_zero(t):
@@ -302,6 +329,71 @@ class ModelTesterMixin:
 
         recursive_check(outputs_tuple, outputs_dict)
 
+    def test_getattr_is_correct(self):
+        init_dict = self.get_init_dict()
+        model = self.model_class(**init_dict)
+
+        # save some things to test
+        model.dummy_attribute = 5
+        model.register_to_config(test_attribute=5)
+
+        logger = logging.get_logger("diffusers.models.modeling_utils")
+        # 30 for warning
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            assert hasattr(model, "dummy_attribute")
+            assert getattr(model, "dummy_attribute") == 5
+            assert model.dummy_attribute == 5
+
+        # no warning should be thrown
+        assert cap_logger.out == ""
+
+        logger = logging.get_logger("diffusers.models.modeling_utils")
+        # 30 for warning
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            assert hasattr(model, "save_pretrained")
+            fn = model.save_pretrained
+            fn_1 = getattr(model, "save_pretrained")
+
+            assert fn == fn_1
+        # no warning should be thrown
+        assert cap_logger.out == ""
+
+        # warning should be thrown for config attributes accessed directly
+        with pytest.warns(FutureWarning):
+            assert model.test_attribute == 5
+
+        with pytest.warns(FutureWarning):
+            assert getattr(model, "test_attribute") == 5
+
+        with pytest.raises(AttributeError) as error:
+            model.does_not_exist
+
+        assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
+
+    @require_accelerator
+    @pytest.mark.skipif(
+        torch_device not in ["cuda", "xpu"],
+        reason="float16 and bfloat16 can only be used with an accelerator",
+    )
+    def test_keep_in_fp32_modules(self):
+        model = self.model_class(**self.get_init_dict())
+        fp32_modules = model._keep_in_fp32_modules
+
+        if fp32_modules is None or len(fp32_modules) == 0:
+            pytest.skip("Model does not have _keep_in_fp32_modules defined.")
+
+        # Test with float16
+        model.to(torch_device)
+        model.to(torch.float16)
+
+        for name, param in model.named_parameters():
+            if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
+                assert param.dtype == torch.float32, f"Parameter {name} should be float32 but got {param.dtype}"
+            else:
+                assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"
+
     @require_accelerator
     @pytest.mark.skipif(
         torch_device not in ["cuda", "xpu"],
@@ -324,12 +416,12 @@ class ModelTesterMixin:
             assert param.data.dtype == torch_dtype
 
         with torch.no_grad():
-            output = model(**get_dummy_inputs())
-            output_loaded = model_loaded(**get_dummy_inputs())
+            output = model(**self.get_dummy_inputs())
+            output_loaded = model_loaded(**self.get_dummy_inputs())
 
-            assert torch.allclose(
-                output, output_loaded, atol=1e-4
-            ), f"Loaded model output differs for {torch_dtype}"
+            assert torch.allclose(output, output_loaded, atol=1e-4), (
+                f"Loaded model output differs for {torch_dtype}"
+            )
@require_accelerator def test_sharded_checkpoints(self): @@ -350,9 +442,9 @@ class ModelTesterMixin: # Check if the right number of shards exists expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - assert ( - actual_num_shards == expected_num_shards - ), f"Expected {expected_num_shards} shards, got {actual_num_shards}" + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) new_model = self.model_class.from_pretrained(tmp_dir).eval() new_model = new_model.to(torch_device) @@ -361,9 +453,9 @@ class ModelTesterMixin: inputs_dict_new = self.get_dummy_inputs() new_output = new_model(**inputs_dict_new) - assert torch.allclose( - base_output[0], new_output[0], atol=1e-5 - ), "Output should match after sharded save/load" + assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( + "Output should match after sharded save/load" + ) @require_accelerator def test_sharded_checkpoints_with_variant(self): @@ -382,16 +474,16 @@ class ModelTesterMixin: model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant) index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) - assert os.path.exists( - os.path.join(tmp_dir, index_filename) - ), f"Variant index file {index_filename} should exist" + assert os.path.exists(os.path.join(tmp_dir, index_filename)), ( + f"Variant index file {index_filename} should exist" + ) # Check if the right number of shards exists expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - assert ( - actual_num_shards == expected_num_shards - ), f"Expected {expected_num_shards} shards, got {actual_num_shards}" + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval() new_model = new_model.to(torch_device) @@ -400,11 +492,10 @@ class ModelTesterMixin: inputs_dict_new = self.get_dummy_inputs() new_output = new_model(**inputs_dict_new) - assert torch.allclose( - base_output[0], new_output[0], atol=1e-5 - ), "Output should match after variant sharded save/load" + assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( + "Output should match after variant sharded save/load" + ) - @require_accelerator def test_sharded_checkpoints_with_parallel_loading(self): import time @@ -433,9 +524,9 @@ class ModelTesterMixin: # Check if the right number of shards exists expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - assert ( - actual_num_shards == expected_num_shards - ), f"Expected {expected_num_shards} shards, got {actual_num_shards}" + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) # Load without parallel loading constants.HF_ENABLE_PARALLEL_LOADING = False @@ -459,16 +550,14 @@ class ModelTesterMixin: inputs_dict_parallel = self.get_dummy_inputs() output_parallel = model_parallel(**inputs_dict_parallel) - assert torch.allclose( - base_output[0], output_parallel[0], atol=1e-5 - ), "Output should match with parallel loading" + assert 
torch.allclose(base_output[0], output_parallel[0], atol=1e-5), (
+                "Output should match with parallel loading"
+            )
 
             # Verify parallel loading is faster or at least not significantly slower
-            # For small test models, the difference might be negligible or even slightly slower due to overhead
-            # so we just check that parallel loading completed successfully and outputs match
-            assert (
-                parallel_load_time < sequential_load_time
-            ), f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s"
+            assert parallel_load_time < sequential_load_time, (
+                f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s"
+            )
         finally:
             # Restore original values
             constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
@@ -506,6 +595,6 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         new_output = new_model(**inputs_dict)
 
-        assert torch.allclose(
-            base_output[0], new_output[0], atol=1e-5
-        ), "Output should match with model parallelism"
+        assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
+            "Output should match with model parallelism"
+        )
diff --git a/tests/models/testing_utils/hub.py b/tests/models/testing_utils/hub.py
index e20c3ab163..40d8777c33 100644
--- a/tests/models/testing_utils/hub.py
+++ b/tests/models/testing_utils/hub.py
@@ -18,7 +18,8 @@
 import uuid
 
 import pytest
 import torch
+from huggingface_hub import ModelCard, delete_repo
 from huggingface_hub.utils import is_jinja_available
 
 from ...others.test_utils import TOKEN, USER, is_staging_test
@@ -58,9 +59,9 @@ class ModelPushToHubTesterMixin:
         new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
 
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
-            assert torch.equal(
-                p1, p2
-            ), "Parameters don't match after save_pretrained with push_to_hub and from_pretrained"
+            assert torch.equal(p1, p2), (
+                "Parameters don't match after save_pretrained with push_to_hub and from_pretrained"
+            )
 
         # Reset repo
         delete_repo(self.repo_id, token=TOKEN)
@@ -84,9 +85,9 @@ class ModelPushToHubTesterMixin:
         new_model = self.model_class.from_pretrained(self.org_repo_id)
 
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
-            assert torch.equal(
-                p1, p2
-            ), "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained"
+            assert torch.equal(p1, p2), (
+                "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained"
+            )
 
         # Reset repo
         delete_repo(self.org_repo_id, token=TOKEN)
@@ -101,9 +102,9 @@ class ModelPushToHubTesterMixin:
         model.push_to_hub(self.repo_id, token=TOKEN)
 
         model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
-        assert (
-            model_card.library_name == "diffusers"
-        ), f"Expected library_name 'diffusers', got {model_card.library_name}"
+        assert model_card.library_name == "diffusers", (
+            f"Expected library_name 'diffusers', got {model_card.library_name}"
+        )
 
         # Reset repo
         delete_repo(self.repo_id, token=TOKEN)
diff --git a/tests/models/testing_utils/ip_adapter.py b/tests/models/testing_utils/ip_adapter.py
index aff2cf1864..13e141869c 100644
--- a/tests/models/testing_utils/ip_adapter.py
+++ b/tests/models/testing_utils/ip_adapter.py
@@ -13,9 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-import tempfile
 
+import pytest
 import torch
 
 from diffusers.models.attention_processor import IPAdapterAttnProcessor
@@ -61,7 +60,7 @@ def create_ip_adapter_state_dict(model):
     return {"ip_adapter": ip_state_dict}
 
 
-def check_if_ip_adapter_correctly_set(model) -> bool:
+def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
     """
     Check if IP Adapter processors are correctly set in the model.
 
@@ -72,7 +71,7 @@
         bool: True if IP Adapter is correctly set, False otherwise
     """
     for module in model.attn_processors.values():
-        if isinstance(module, IPAdapterAttnProcessor):
+        if isinstance(module, processor_cls):
             return True
     return False
@@ -93,48 +92,49 @@ class IPAdapterTesterMixin:
     Use `pytest -m "not ip_adapter"` to skip these tests
     """
 
+    ip_adapter_processor_cls = None
+
     def create_ip_adapter_state_dict(self, model):
         raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
 
+    def modify_inputs_for_ip_adapter(self, model, inputs_dict):
+        raise NotImplementedError("child class must implement method to modify inputs for IP Adapter")
+
     def test_load_ip_adapter(self):
         init_dict = self.get_init_dict()
         inputs_dict = self.get_dummy_inputs()
         model = self.model_class(**init_dict).to(torch_device)
-        self.prepare_model(model)
 
         torch.manual_seed(0)
         output_no_adapter = model(**inputs_dict, return_dict=False)[0]
 
-        # Create dummy IP adapter state dict
         ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
-
-        # Load IP adapter
         model._load_ip_adapter_weights([ip_adapter_state_dict])
-        assert check_if_ip_adapter_correctly_set(model), "IP Adapter processors not set correctly"
-
-        torch.manual_seed(0)
-        # Create dummy image embeds for IP adapter
-        cross_attention_dim = getattr(model.config, "cross_attention_dim", 32)
-        image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
-        inputs_dict_with_adapter = inputs_dict.copy()
-        inputs_dict_with_adapter["image_embeds"] = image_embeds
+        assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
+            "IP Adapter processors not set correctly"
+        )
 
+        inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
         outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
 
-        assert not torch.allclose(
-            output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4
-        ), "Output should differ with IP Adapter enabled"
+        assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), (
+            "Output should differ with IP Adapter enabled"
+        )
 
+    @pytest.mark.skip(
+        reason="Setting IP Adapter scale is not defined at the model level. 
Enable this test after refactoring" + ) def test_ip_adapter_scale(self): init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) - # self.prepare_model(model) - # Create and load dummy IP adapter state dict ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) model._load_ip_adapter_weights([ip_adapter_state_dict]) + inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy()) + # Test scale = 0.0 (no effect) model.set_ip_adapter_scale(0.0) torch.manual_seed(0) @@ -146,14 +146,16 @@ class IPAdapterTesterMixin: output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0] # Outputs should differ with different scales - assert not torch.allclose( - output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4 - ), "Output should differ with different IP Adapter scales" + assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), ( + "Output should differ with different IP Adapter scales" + ) + @pytest.mark.skip( + reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring" + ) def test_unload_ip_adapter(self): init_dict = self.get_init_dict() model = self.model_class(**init_dict).to(torch_device) - self.prepare_model(model) # Save original processors original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()} @@ -161,49 +163,16 @@ class IPAdapterTesterMixin: # Create and load IP adapter ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) model._load_ip_adapter_weights([ip_adapter_state_dict]) - assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be set" + + assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set" # Unload IP adapter model.unload_ip_adapter() - assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded" + + assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), ( + "IP Adapter should be unloaded" + ) # Verify processors are restored current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()} assert original_processors == current_processors, "Processors should be restored after unload" - - def test_ip_adapter_save_load(self): - init_dict = self.get_init_dict() - inputs_dict = self.get_dummy_inputs() - model = self.model_class(**init_dict).to(torch_device) - self.prepare_model(model) - - # Create and load IP adapter - ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) - model._load_ip_adapter_weights([ip_adapter_state_dict]) - - torch.manual_seed(0) - output_before_save = model(**inputs_dict, return_dict=False)[0] - - with tempfile.TemporaryDirectory() as tmpdir: - # Save the IP adapter weights - save_path = os.path.join(tmpdir, "ip_adapter.safetensors") - import safetensors.torch - - safetensors.torch.save_file(ip_adapter_state_dict["ip_adapter"], save_path) - - # Unload and reload - model.unload_ip_adapter() - assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded" - - # Reload from saved file - loaded_state_dict = {"ip_adapter": safetensors.torch.load_file(save_path)} - model._load_ip_adapter_weights([loaded_state_dict]) - assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be loaded" - - torch.manual_seed(0) - output_after_load = model(**inputs_dict_with_adapter, return_dict=False)[0] - - # Outputs should match before and after save/load - assert 
torch.allclose( - output_before_save, output_after_load, atol=1e-4, rtol=1e-4 - ), "Output should match before and after save/load" diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py index dfc3bd2955..6777c164f2 100644 --- a/tests/models/testing_utils/lora.py +++ b/tests/models/testing_utils/lora.py @@ -91,15 +91,15 @@ class LoraTesterMixin: torch.manual_seed(0) outputs_with_lora = model(**inputs_dict, return_dict=False)[0] - assert not torch.allclose( - output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4 - ), "Output should differ with LoRA enabled" + assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4), ( + "Output should differ with LoRA enabled" + ) with tempfile.TemporaryDirectory() as tmpdir: model.save_lora_adapter(tmpdir) - assert os.path.isfile( - os.path.join(tmpdir, "pytorch_lora_weights.safetensors") - ), "LoRA weights file not created" + assert os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")), ( + "LoRA weights file not created" + ) state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) @@ -119,12 +119,12 @@ class LoraTesterMixin: torch.manual_seed(0) outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0] - assert not torch.allclose( - output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4 - ), "Output should differ with LoRA enabled" - assert torch.allclose( - outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4 - ), "Outputs should match before and after save/load" + assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), ( + "Output should differ with LoRA enabled" + ) + assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), ( + "Outputs should match before and after save/load" + ) def test_lora_wrong_adapter_name_raises_error(self): from peft import LoraConfig diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py index d06a125dc6..6cdc72b004 100644 --- a/tests/models/testing_utils/memory.py +++ b/tests/models/testing_utils/memory.py @@ -122,9 +122,9 @@ class CPUOffloadTesterMixin: torch.manual_seed(0) new_output = new_model(**inputs_dict) - assert torch.allclose( - base_output[0], new_output[0], atol=1e-5 - ), "Output should match with CPU offloading" + assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( + "Output should match with CPU offloading" + ) @require_offload_support def test_disk_offload_without_safetensors(self): @@ -183,9 +183,9 @@ class CPUOffloadTesterMixin: torch.manual_seed(0) new_output = new_model(**inputs_dict) - assert torch.allclose( - base_output[0], new_output[0], atol=1e-5 - ), "Output should match with disk offloading (safetensors)" + assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( + "Output should match with disk offloading (safetensors)" + ) @is_group_offload @@ -247,18 +247,18 @@ class GroupOffloadTesterMixin: ) output_with_group_offloading4 = run_forward(model) - assert torch.allclose( - output_without_group_offloading, output_with_group_offloading1, atol=1e-5 - ), "Output should match with block-level offloading" - assert torch.allclose( - output_without_group_offloading, output_with_group_offloading2, atol=1e-5 - ), "Output should match with non-blocking block-level offloading" - assert torch.allclose( - output_without_group_offloading, output_with_group_offloading3, atol=1e-5 - ), "Output should match with leaf-level offloading" - assert torch.allclose( 
- output_without_group_offloading, output_with_group_offloading4, atol=1e-5 - ), "Output should match with leaf-level offloading with stream" + assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5), ( + "Output should match with block-level offloading" + ) + assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5), ( + "Output should match with non-blocking block-level offloading" + ) + assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5), ( + "Output should match with leaf-level offloading" + ) + assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5), ( + "Output should match with leaf-level offloading with stream" + ) @require_group_offload_support @torch.no_grad() @@ -345,9 +345,9 @@ class GroupOffloadTesterMixin: raise ValueError(f"Following files are missing: {', '.join(missing_files)}") output_with_group_offloading = _run_forward(model, inputs_dict) - assert torch.allclose( - output_without_group_offloading, output_with_group_offloading, atol=atol - ), "Output should match with disk-based group offloading" + assert torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol), ( + "Output should match with disk-based group offloading" + ) class LayerwiseCastingTesterMixin: @@ -396,16 +396,16 @@ class LayerwiseCastingTesterMixin: ) compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None - assert ( - fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint - ), "Memory footprint should decrease with lower precision storage" + assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint, ( + "Memory footprint should decrease with lower precision storage" + ) # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32. # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it. if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY: - assert ( - fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory - ), "Peak memory should be lower with bf16 compute on newer GPUs" + assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory, ( + "Peak memory should be lower with bf16 compute on newer GPUs" + ) # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few # bytes. This only happens for some models, so we allow a small tolerance. 
@@ -415,6 +415,36 @@ class LayerwiseCastingTesterMixin:
             or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
         ), "Peak memory should be lower or within tolerance with fp8 storage"
 
+    def test_layerwise_casting_training(self):
+        def test_fn(storage_dtype, compute_dtype):
+            if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
+                pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
+
+            model = self.model_class(**self.get_init_dict())
+            model = model.to(torch_device, dtype=compute_dtype)
+            model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+            model.train()
+
+            inputs_dict = self.get_dummy_inputs()
+            inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+            with torch.amp.autocast(device_type=torch.device(torch_device).type):
+                output = model(**inputs_dict)
+
+                if isinstance(output, dict):
+                    output = output.to_tuple()[0]
+
+                input_tensor = inputs_dict[self.main_input_name]
+                noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
+                noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
+                loss = torch.nn.functional.mse_loss(output, noise)
+
+            loss.backward()
+
+        test_fn(torch.float16, torch.float32)
+        test_fn(torch.float8_e4m3fn, torch.float32)
+        test_fn(torch.float8_e5m2, torch.float32)
+        test_fn(torch.float8_e4m3fn, torch.bfloat16)
+
 
 @is_memory
 @require_accelerator
diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 15d6a32069..866d572f9d 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -26,6 +26,7 @@ from diffusers.utils.import_utils import (
     is_nvidia_modelopt_available,
     is_optimum_quanto_available,
     is_torchao_available,
+    is_torchao_version,
 )
 
 from ...testing_utils import (
@@ -128,9 +129,9 @@ class QuantizationTesterMixin:
         model_quantized = self._create_quantized_model(config_kwargs)
         num_params_quantized = model_quantized.num_parameters()
 
-        assert (
-            num_params == num_params_quantized
-        ), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
+        assert num_params == num_params_quantized, (
+            f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
+        )
 
     def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
         model = self._load_unquantized_model()
@@ -140,9 +141,9 @@ class QuantizationTesterMixin:
         mem_quantized = model_quantized.get_memory_footprint()
 
         ratio = mem / mem_quantized
-        assert (
-            ratio >= expected_memory_reduction
-        ), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
+        assert ratio >= expected_memory_reduction, (
+            f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). 
unquantized={mem}, quantized={mem_quantized}" + ) def _test_quantization_inference(self, config_kwargs): model_quantized = self._create_quantized_model(config_kwargs) @@ -243,12 +244,12 @@ class QuantizationTesterMixin: self._verify_if_layer_quantized(name, module, config_kwargs) num_quantized_layers += 1 - assert ( - num_quantized_layers > 0 - ), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" - assert ( - num_quantized_layers == expected_quantized_layers - ), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" + assert num_quantized_layers > 0, ( + f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" + ) + assert num_quantized_layers == expected_quantized_layers, ( + f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" + ) def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert): """ @@ -272,9 +273,9 @@ class QuantizationTesterMixin: if any(excluded in name for excluded in modules_to_not_convert): found_excluded = True # This module should NOT be quantized - assert not self._is_module_quantized( - module - ), f"Module {name} should not be quantized but was found to be quantized" + assert not self._is_module_quantized(module), ( + f"Module {name} should not be quantized but was found to be quantized" + ) assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}" @@ -296,9 +297,9 @@ class QuantizationTesterMixin: mem_with_exclusion = model_with_exclusion.get_memory_footprint() mem_fully_quantized = model_fully_quantized.get_memory_footprint() - assert ( - mem_with_exclusion > mem_fully_quantized - ), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" + assert mem_with_exclusion > mem_fully_quantized, ( + f"Model with exclusions should be larger. 
With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" + ) def _test_quantization_device_map(self, config_kwargs): """ @@ -380,9 +381,9 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin): def _verify_if_layer_quantized(self, name, module, config_kwargs): expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params - assert ( - module.weight.__class__ == expected_weight_class - ), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" + assert module.weight.__class__ == expected_weight_class, ( + f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" + ) @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) def test_bnb_quantization_num_parameters(self, config_name): @@ -450,13 +451,13 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin): for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear): if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules): - assert ( - module.weight.dtype == torch.float32 - ), f"Module {name} should be FP32 but is {module.weight.dtype}" + assert module.weight.dtype == torch.float32, ( + f"Module {name} should be FP32 but is {module.weight.dtype}" + ) else: - assert ( - module.weight.dtype == torch.uint8 - ), f"Module {name} should be uint8 but is {module.weight.dtype}" + assert module.weight.dtype == torch.uint8, ( + f"Module {name} should be uint8 but is {module.weight.dtype}" + ) with torch.no_grad(): inputs = self.get_dummy_inputs() diff --git a/tests/models/testing_utils/single_file.py b/tests/models/testing_utils/single_file.py index cb0ae7a4e7..67d770849f 100644 --- a/tests/models/testing_utils/single_file.py +++ b/tests/models/testing_utils/single_file.py @@ -192,9 +192,9 @@ class SingleFileTesterMixin: for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}" + assert model.config[param_name] == param_value, ( + f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}" + ) def test_single_file_loading_with_diffusers_config_local_files_only(self): single_file_kwargs = {} diff --git a/tests/models/testing_utils/training.py b/tests/models/testing_utils/training.py index f301b5a6d0..7e4193d59e 100644 --- a/tests/models/testing_utils/training.py +++ b/tests/models/testing_utils/training.py @@ -116,8 +116,7 @@ class TrainingTesterMixin: modules_with_gc_enabled[submodule.__class__.__name__] = True assert set(modules_with_gc_enabled.keys()) == expected_set, ( - f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} " - f"do not match expected set {expected_set}" + f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} do not match expected set {expected_set}" ) assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled" @@ -169,9 +168,9 @@ class TrainingTesterMixin: loss_2.backward() # compare the output and parameters gradients - assert ( - loss - loss_2 - ).abs() < loss_tolerance, f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}" + assert (loss - loss_2).abs() < loss_tolerance, ( + f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}" + ) named_params = dict(model.named_parameters()) named_params_2 = dict(model_2.named_parameters()) @@ -184,9 
+183,9 @@ class TrainingTesterMixin: if param.grad is None: continue - assert torch_all_close( - param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol - ), f"Gradient mismatch for {name}" + assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol), ( + f"Gradient mismatch for {name}" + ) def test_mixed_precision_training(self): init_dict = self.get_init_dict() diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 3ab02f797b..43e02db448 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -13,119 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +from typing import Any import torch from diffusers import FluxTransformer2DModel -from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 from diffusers.models.embeddings import ImageProjection +from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor +from diffusers.utils.torch_utils import randn_tensor -from ...testing_utils import enable_full_determinism, is_peft_available, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import LoraHotSwappingForModelTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BitsAndBytesTesterMixin, + ContextParallelTesterMixin, + GGUFTesterMixin, + IPAdapterTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, + ModelOptTesterMixin, + ModelTesterMixin, + QuantoTesterMixin, + SingleFileTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -def create_flux_ip_adapter_state_dict(model): - # "ip_adapter" (cross-attention weights) - ip_cross_attn_state_dict = {} - key_id = 0 - - for name in model.attn_processors.keys(): - if name.startswith("single_transformer_blocks"): - continue - - joint_attention_dim = model.config["joint_attention_dim"] - hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] - sd = FluxIPAdapterJointAttnProcessor2_0( - hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 - ).state_dict() - ip_cross_attn_state_dict.update( - { - f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], - f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], - f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], - f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], - } - ) - - key_id += 1 - - # "image_proj" (ImageProjection layer weights) - - image_projection = ImageProjection( - cross_attention_dim=model.config["joint_attention_dim"], - image_embed_dim=( - model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768 - ), - num_image_text_embeds=4, - ) - - ip_image_projection_state_dict = {} - sd = image_projection.state_dict() - ip_image_projection_state_dict.update( - { - "proj.weight": sd["image_embeds.weight"], - "proj.bias": sd["image_embeds.bias"], - "norm.weight": sd["norm.weight"], - "norm.bias": sd["norm.bias"], - } - ) - - del sd - ip_state_dict = {} - ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) - return ip_state_dict - - -class FluxTransformerTests(ModelTesterMixin, 
unittest.TestCase): +class FluxTransformerTesterConfig: model_class = FluxTransformer2DModel - main_input_name = "hidden_states" - # We override the items here because the transformer under consideration is small. - model_split_percents = [0.7, 0.6, 0.6] - - # Skip setting testing with default: AttnProcessor - uses_custom_attn_processor = True + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe" + pretrained_model_kwargs = {"subfolder": "transformer"} @property - def dummy_input(self): - return self.prepare_dummy_input() - - @property - def input_shape(self): - return (16, 4) - - @property - def output_shape(self): - return (16, 4) - - def prepare_dummy_input(self, height=4, width=4): - batch_size = 1 - num_latent_channels = 4 - num_image_channels = 3 - sequence_length = 48 - embedding_dim = 32 - - hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) - pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device) - text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) - image_ids = torch.randn((height * width, num_image_channels)).to(torch_device) - timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + def generator(self): + return torch.Generator("cpu").manual_seed(0) + def get_init_dict(self) -> dict[str, int | list[int]]: + """Return Flux model initialization arguments.""" return { - "hidden_states": hidden_states, - "encoder_hidden_states": encoder_hidden_states, - "img_ids": image_ids, - "txt_ids": text_ids, - "pooled_projections": pooled_prompt_embeds, - "timestep": timestep, - } - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { "patch_size": 1, "in_channels": 4, "num_layers": 1, @@ -137,11 +68,40 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): "axes_dims_rope": [4, 4, 8], } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + batch_size = 1 + height = width = 4 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 48 + embedding_dim = 32 + return { + "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), generator=self.generator), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator + ), + "pooled_projections": randn_tensor((batch_size, embedding_dim), generator=self.generator), + "img_ids": randn_tensor((height * width, num_image_channels), generator=self.generator), + "txt_ids": randn_tensor((sequence_length, num_image_channels), generator=self.generator), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + } + + @property + def input_shape(self) -> tuple[int, int]: + return (16, 4) + + @property + def output_shape(self) -> tuple[int, int]: + return (16, 4) + + +class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin): def test_deprecated_inputs_img_txt_ids_3d(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + """Test that deprecated 3D img_ids and txt_ids still work.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) model.to(torch_device) model.eval() @@ -162,63 +122,223 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): with torch.no_grad(): output_2 = model(**inputs_dict).to_tuple()[0] - 
self.assertEqual(output_1.shape, output_2.shape)
-        self.assertTrue(
-            torch.allclose(output_1, output_2, atol=1e-5),
-            msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
+        assert output_1.shape == output_2.shape
+        assert torch.allclose(output_1, output_2, atol=1e-5), (
+            "outputs with deprecated inputs (img_ids and txt_ids as 3D tensors) "
+            "do not match outputs with the same inputs as 2D tensors"
         )
 
-    def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"FluxTransformer2DModel"}
-        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
-    # The test exists for cases like
-    # https://github.com/huggingface/diffusers/issues/11874
-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_lora_exclude_modules(self):
-        from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
+class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
+    """Memory optimization tests for Flux Transformer."""
 
-        lora_rank = 4
-        target_module = "single_transformer_blocks.0.proj_out"
-        adapter_name = "foo"
-        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict).to(torch_device)
+    pass
 
-        state_dict = model.state_dict()
-        target_mod_shape = state_dict[f"{target_module}.weight"].shape
-        lora_state_dict = {
-            f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
-            f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
+
+class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
+    """Training tests for Flux Transformer."""
+
+    pass
+
+
+class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
+    """Attention processor tests for Flux Transformer."""
+
+    pass
+
+
+class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
+    """Context parallel inference tests for Flux Transformer."""
+
+    pass
+
+
+class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
+    """IP Adapter tests for Flux Transformer."""
+
+    ip_adapter_processor_cls = FluxIPAdapterAttnProcessor
+
+    def modify_inputs_for_ip_adapter(self, model, inputs_dict):
+        torch.manual_seed(0)
+        # Create dummy image embeds for IP adapter
+        cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
+        image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
+
+        inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
+
+        return inputs_dict
+
+    def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
+        ip_cross_attn_state_dict = {}
+        key_id = 0
+
+        for name in model.attn_processors.keys():
+            if name.startswith("single_transformer_blocks"):
+                continue
+
+            joint_attention_dim = model.config["joint_attention_dim"]
+            hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
+            sd = FluxIPAdapterAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
+            ).state_dict()
+            ip_cross_attn_state_dict.update(
+                {
+                    f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+                    f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
+                    f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
+                    f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
+                }
+            )
+
+            key_id += 1
+
+        image_projection = 
ImageProjection( + cross_attention_dim=model.config["joint_attention_dim"], + image_embed_dim=( + model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768 + ), + num_image_text_embeds=4, + ) + + ip_image_projection_state_dict = {} + sd = image_projection.state_dict() + ip_image_projection_state_dict.update( + { + "proj.weight": sd["image_embeds.weight"], + "proj.bias": sd["image_embeds.bias"], + "norm.weight": sd["norm.weight"], + "norm.bias": sd["norm.bias"], + } + ) + + del sd + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + +class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin): + """LoRA adapter tests for Flux Transformer.""" + + pass + + +class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): + """LoRA hot-swapping tests for Flux Transformer.""" + + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + """Override to support dynamic height/width for LoRA hotswap tests.""" + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 24 + embedding_dim = 8 + + return { + "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)), + "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)), + "pooled_projections": randn_tensor((batch_size, embedding_dim)), + "img_ids": randn_tensor((height * width, num_image_channels)), + "txt_ids": randn_tensor((sequence_length, num_image_channels)), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), } - # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter). 
-        config = LoraConfig(
-            r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
-        )
-        inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
-        set_peft_model_state_dict(model, lora_state_dict, adapter_name)

-        retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
-        assert len(retrieved_lora_state_dict) == len(lora_state_dict)
-        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
-        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()


-class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
-    model_class = FluxTransformer2DModel
+class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
     different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

-    def prepare_init_args_and_inputs_for_common(self):
-        return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
+    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
+        """Override to support dynamic height/width for compilation tests."""
+        batch_size = 1
+        num_latent_channels = 4
+        num_image_channels = 3
+        sequence_length = 24
+        embedding_dim = 8

-    def prepare_dummy_input(self, height, width):
-        return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
+        return {
+            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
+            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
+            "pooled_projections": randn_tensor((batch_size, embedding_dim)),
+            "img_ids": randn_tensor((height * width, num_image_channels)),
+            "txt_ids": randn_tensor((sequence_length, num_image_channels)),
+            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
+        }


-class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
-    model_class = FluxTransformer2DModel
-    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
+    ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
+    alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
+    pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
+    subfolder = "transformer"
+    pass

-    def prepare_init_args_and_inputs_for_common(self):
-        return FluxTransformerTests().prepare_init_args_and_inputs_for_common()

-    def prepare_dummy_input(self, height, width):
-        return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
+class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
+        return {
+            "hidden_states": randn_tensor((1, 4096, 64)),
+            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
+            "pooled_projections": randn_tensor((1, 768)),
+            "timestep": torch.tensor([1.0]).to(torch_device),
+            "img_ids": randn_tensor((4096, 3)),
+            "txt_ids": randn_tensor((512, 3)),
+            "guidance": torch.tensor([3.5]).to(torch_device),
+        }
+
+
+class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
+        return {
+            "hidden_states": randn_tensor((1, 4096, 64)),
+            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
+            "pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin): + gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf" + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } diff --git a/tests/models/transformers/test_models_transformer_flux_.py b/tests/models/transformers/test_models_transformer_flux_.py deleted file mode 100644 index e79974459f..0000000000 --- a/tests/models/transformers/test_models_transformer_flux_.py +++ /dev/null @@ -1,330 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-from typing import Any
-
-import torch
-
-from diffusers import FluxTransformer2DModel
-from diffusers.models.embeddings import ImageProjection
-from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
-from diffusers.utils.torch_utils import randn_tensor
-
-from ...testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
-from ..testing_utils import (
-    AttentionTesterMixin,
-    BitsAndBytesTesterMixin,
-    GGUFTesterMixin,
-    IPAdapterTesterMixin,
-    LoraTesterMixin,
-    MemoryTesterMixin,
-    ModelOptTesterMixin,
-    ModelTesterMixin,
-    QuantoTesterMixin,
-    SingleFileTesterMixin,
-    TorchAoTesterMixin,
-    TorchCompileTesterMixin,
-    TrainingTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class FluxTransformerTesterConfig:
-    model_class = FluxTransformer2DModel
-    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
-    pretrained_model_kwargs = {"subfolder": "transformer"}
-
-    @property
-    def generator(self):
-        return torch.Generator("cpu").manual_seed(0)
-
-    def get_init_dict(self) -> dict[str, int | list[int]]:
-        """Return Flux model initialization arguments."""
-        return {
-            "patch_size": 1,
-            "in_channels": 4,
-            "num_layers": 1,
-            "num_single_layers": 1,
-            "attention_head_dim": 16,
-            "num_attention_heads": 2,
-            "joint_attention_dim": 32,
-            "pooled_projection_dim": 32,
-            "axes_dims_rope": [4, 4, 8],
-        }
-
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        batch_size = 1
-        height = width = 4
-        num_latent_channels = 4
-        num_image_channels = 3
-        sequence_length = 48
-        embedding_dim = 32
-
-        return {
-            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), generator=self.generator),
-            "encoder_hidden_states": randn_tensor(
-                (batch_size, sequence_length, embedding_dim), generator=self.generator
-            ),
-            "pooled_projections": randn_tensor((batch_size, embedding_dim), generator=self.generator),
-            "img_ids": randn_tensor((height * width, num_image_channels), generator=self.generator),
-            "txt_ids": randn_tensor((sequence_length, num_image_channels), generator=self.generator),
-            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
-        }
-
-    @property
-    def input_shape(self) -> tuple[int, int]:
-        return (16, 4)
-
-    @property
-    def output_shape(self) -> tuple[int, int]:
-        return (16, 4)
-
-
-class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
-    def test_deprecated_inputs_img_txt_ids_3d(self):
-        """Test that deprecated 3D img_ids and txt_ids still work."""
-        init_dict = self.get_init_dict()
-        inputs_dict = self.get_dummy_inputs()
-
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-        model.eval()
-
-        with torch.no_grad():
-            output_1 = model(**inputs_dict).to_tuple()[0]
-
-        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
-        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
-        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
-
-        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
-        assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
-
-        inputs_dict["txt_ids"] = text_ids_3d
-        inputs_dict["img_ids"] = image_ids_3d
-
-        with torch.no_grad():
-            output_2 = model(**inputs_dict).to_tuple()[0]
-
-        assert output_1.shape == output_2.shape
-        assert torch.allclose(output_1, output_2, atol=1e-5), (
-            "output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
-            "are not equal as them as 2d inputs"
-        )
-
-
-class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
-    """Memory optimization tests for Flux Transformer."""
-
-    pass
-
-
-class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
-    """Training tests for Flux Transformer."""
-
-    pass
-
-
-class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
-    """Attention processor tests for Flux Transformer."""
-
-    pass
-
-
-class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
-    """IP Adapter tests for Flux Transformer."""
-
-    def prepare_model(self, model):
-        joint_attention_dim = model.config["joint_attention_dim"]
-        hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
-        model.set_attn_processor(FluxIPAdapterAttnProcessor(hidden_size, joint_attention_dim, scale=1.0))
-
-    def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
-        from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
-
-        ip_cross_attn_state_dict = {}
-        key_id = 0
-
-        for name in model.attn_processors.keys():
-            if name.startswith("single_transformer_blocks"):
-                continue
-
-            joint_attention_dim = model.config["joint_attention_dim"]
-            hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
-            sd = FluxIPAdapterAttnProcessor(
-                hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
-            ).state_dict()
-            ip_cross_attn_state_dict.update(
-                {
-                    f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
-                    f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
-                    f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
-                    f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
-                }
-            )
-
-            key_id += 1
-
-        image_projection = ImageProjection(
-            cross_attention_dim=model.config["joint_attention_dim"],
-            image_embed_dim=(
-                model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
-            ),
-            num_image_text_embeds=4,
-        )
-
-        ip_image_projection_state_dict = {}
-        sd = image_projection.state_dict()
-        ip_image_projection_state_dict.update(
-            {
-                "proj.weight": sd["image_embeds.weight"],
-                "proj.bias": sd["image_embeds.bias"],
-                "norm.weight": sd["norm.weight"],
-                "norm.bias": sd["norm.bias"],
-            }
-        )
-
-        del sd
-        ip_state_dict = {}
-        ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
-        return ip_state_dict
-
-
-class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
-    """LoRA adapter tests for Flux Transformer."""
-
-    pass
-
-
-class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
-    """LoRA hot-swapping tests for Flux Transformer."""
-
-    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
-
-    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
-        """Override to support dynamic height/width for LoRA hotswap tests."""
-        batch_size = 1
-        num_latent_channels = 4
-        num_image_channels = 3
-        sequence_length = 24
-        embedding_dim = 8
-
-        return {
-            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
-            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
-            "pooled_projections": randn_tensor((batch_size, embedding_dim)),
-            "img_ids": randn_tensor((height * width, num_image_channels)),
-            "txt_ids": randn_tensor((sequence_length, num_image_channels)),
-            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
-        }
-
-
-class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
-    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
-
-    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
-        """Override to support dynamic height/width for compilation tests."""
-        batch_size = 1
-        num_latent_channels = 4
-        num_image_channels = 3
-        sequence_length = 24
-        embedding_dim = 8
-
-        return {
-            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
-            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
-            "pooled_projections": randn_tensor((batch_size, embedding_dim)),
-            "img_ids": randn_tensor((height * width, num_image_channels)),
-            "txt_ids": randn_tensor((sequence_length, num_image_channels)),
-            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
-        }
-
-
-class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
-    ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
-    alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
-    pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
-    subfolder = "transformer"
-    pass
-
-
-class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        return {
-            "hidden_states": randn_tensor((1, 4096, 64)),
-            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
-            "pooled_projections": randn_tensor((1, 768)),
-            "timestep": torch.tensor([1.0]).to(torch_device),
-            "img_ids": randn_tensor((4096, 3)),
-            "txt_ids": randn_tensor((512, 3)),
-            "guidance": torch.tensor([3.5]).to(torch_device),
-        }
-
-
-class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        return {
-            "hidden_states": randn_tensor((1, 4096, 64)),
-            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
-            "pooled_projections": randn_tensor((1, 768)),
-            "timestep": torch.tensor([1.0]).to(torch_device),
-            "img_ids": randn_tensor((4096, 3)),
-            "txt_ids": randn_tensor((512, 3)),
-            "guidance": torch.tensor([3.5]).to(torch_device),
-        }
-
-
-class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        return {
-            "hidden_states": randn_tensor((1, 4096, 64)),
-            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
-            "pooled_projections": randn_tensor((1, 768)),
-            "timestep": torch.tensor([1.0]).to(torch_device),
-            "img_ids": randn_tensor((4096, 3)),
-            "txt_ids": randn_tensor((512, 3)),
-            "guidance": torch.tensor([3.5]).to(torch_device),
-        }
-
-
-class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
-    gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
-
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        return {
-            "hidden_states": randn_tensor((1, 4096, 64)),
-            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
-            "pooled_projections": randn_tensor((1, 768)),
-            "timestep": torch.tensor([1.0]).to(torch_device),
-            "img_ids": randn_tensor((4096, 3)),
-            "txt_ids": randn_tensor((512, 3)),
-            "guidance": torch.tensor([3.5]).to(torch_device),
-        }
-
-
-class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        return {
-            "hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)), - "pooled_projections": randn_tensor((1, 768)), - "timestep": torch.tensor([1.0]).to(torch_device), - "img_ids": randn_tensor((4096, 3)), - "txt_ids": randn_tensor((512, 3)), - "guidance": torch.tensor([3.5]).to(torch_device), - } diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index c3bb71e794..0f4fd408a7 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -98,9 +98,9 @@ class GGUFCudaKernelsTests(unittest.TestCase): output_native = linear.forward_native(x) output_cuda = linear.forward_cuda(x) - assert torch.allclose( - output_native, output_cuda, 1e-2 - ), f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}" + assert torch.allclose(output_native, output_cuda, 1e-2), ( + f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}" + ) @nightly diff --git a/tests/testing_utils.py b/tests/testing_utils.py index ae69a21cf8..9860d64dc1 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -406,6 +406,15 @@ def is_modelopt(test_case): return pytest.mark.modelopt(test_case) +def is_context_parallel(test_case): + """ + Decorator marking a test as a context parallel inference test. These tests can be filtered using: + pytest -m "not context_parallel" to skip + pytest -m context_parallel to run only these tests + """ + return pytest.mark.context_parallel(test_case) + + def require_torch(test_case): """ Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.