diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 6f4c3d544b..520bd8f871 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -317,9 +317,9 @@ class ModelUtilsTest(unittest.TestCase): repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True ) - assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), ( - "Model parameters don't match!" - ) + assert all( + torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters()) + ), "Model parameters don't match!" # Remove a shard file cached_shard_file = try_to_load_from_cache( @@ -335,9 +335,9 @@ class ModelUtilsTest(unittest.TestCase): # Verify error mentions the missing shard error_msg = str(context.exception) - assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, ( - f"Expected error about missing shard, got: {error_msg}" - ) + assert ( + cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg + ), f"Expected error about missing shard, got: {error_msg}" @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners") @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.") @@ -354,9 +354,9 @@ class ModelUtilsTest(unittest.TestCase): ) download_requests = [r.method for r in m.request_history] - assert download_requests.count("HEAD") == 3, ( - "3 HEAD requests one for config, one for model, and one for shard index file." - ) + assert ( + download_requests.count("HEAD") == 3 + ), "3 HEAD requests one for config, one for model, and one for shard index file." assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model" with requests_mock.mock(real_http=True) as m: @@ -368,9 +368,9 @@ class ModelUtilsTest(unittest.TestCase): ) cache_requests = [r.method for r in m.request_history] - assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, ( - "We should call only `model_info` to check for commit hash and knowing if shard index is present." - ) + assert ( + "HEAD" == cache_requests[0] and len(cache_requests) == 2 + ), "We should call only `model_info` to check for commit hash and knowing if shard index is present." def test_weight_overwrite(self): with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context: diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py new file mode 100644 index 0000000000..7c64e0a04b --- /dev/null +++ b/tests/models/testing_utils/__init__.py @@ -0,0 +1,2 @@ +from .common import ModelTesterMixin +from .single_file import SingleFileTesterMixin diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py new file mode 100644 index 0000000000..ec4245af12 --- /dev/null +++ b/tests/models/testing_utils/common.py @@ -0,0 +1,304 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+from typing import Dict, List, Tuple
+
+import pytest
+import torch
+
+from ...testing_utils import torch_device
+
+
+class ModelTesterMixin:
+    """
+    Base mixin class for model testing with common test methods.
+
+    Expected class attributes to be set by subclasses:
+    - model_class: The model class to test
+    - main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states")
+    - base_precision: Default tolerance for floating point comparisons (default: 1e-3)
+    - pretrained_model_name_or_path: Hub repo id, required by tests that load pretrained weights
+      (e.g. test_keep_in_fp32_modules)
+
+    Expected methods to be implemented by subclasses:
+    - get_init_dict(): Returns a dict of arguments to initialize the model
+    - get_dummy_inputs(): Returns a dict of inputs to pass to the model forward pass
+    """
+
+    model_class = None
+    base_precision = 1e-3
+
+    def get_init_dict(self):
+        raise NotImplementedError(
+            "get_init_dict must be implemented by subclasses. It should return a dict of model init arguments."
+        )
+
+    def get_dummy_inputs(self):
+        raise NotImplementedError(
+            "get_dummy_inputs must be implemented by subclasses. It should return inputs_dict."
+        )
+
+    def check_device_map_is_respected(self, model, device_map):
+        """Helper method to check if device map is correctly applied to model parameters."""
+        for param_name, param in model.named_parameters():
+            # Find device in device_map
+            while len(param_name) > 0 and param_name not in device_map:
+                param_name = ".".join(param_name.split(".")[:-1])
+            if param_name not in device_map:
+                raise ValueError(f"device map is incomplete, it does not contain any device for {param_name}.")
+
+            param_device = device_map[param_name]
+            if param_device in ["cpu", "disk"]:
+                assert param.device == torch.device(
+                    "meta"
+                ), f"Expected device 'meta' for {param_name}, got {param.device}"
+            else:
+                assert param.device == torch.device(
+                    param_device
+                ), f"Expected device {param_device} for {param_name}, got {param.device}"
+
+    def test_from_save_pretrained(self, expected_max_diff=5e-5):
+        """Test that model can be saved and loaded with save_pretrained/from_pretrained."""
+        model = self.model_class(**self.get_init_dict())
+        model.to(torch_device)
+        model.eval()
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            new_model = self.model_class.from_pretrained(tmpdirname)
+            new_model.to(torch_device)
+
+        with torch.no_grad():
+            image = model(**self.get_dummy_inputs())
+
+            if isinstance(image, dict):
+                image = image.to_tuple()[0]
+
+            new_image = new_model(**self.get_dummy_inputs())
+
+            if isinstance(new_image, dict):
+                new_image = new_image.to_tuple()[0]
+
+        max_diff = (image - new_image).abs().max().item()
+        assert (
+            max_diff <= expected_max_diff
+        ), f"Models give different forward passes. 
Max diff: {max_diff}, expected: {expected_max_diff}" + + def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): + """Test save_pretrained/from_pretrained with variant parameter.""" + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, variant="fp16") + new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16") + + # non-variant cannot be loaded + with pytest.raises(OSError) as exc_info: + self.model_class.from_pretrained(tmpdirname) + + # make sure that error message states what keys are missing + assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value) + + new_model.to(torch_device) + + with torch.no_grad(): + image = model(**self.get_dummy_inputs()) + if isinstance(image, dict): + image = image.to_tuple()[0] + + new_image = new_model(**self.get_dummy_inputs()) + + if isinstance(new_image, dict): + new_image = new_image.to_tuple()[0] + + max_diff = (image - new_image).abs().max().item() + assert ( + max_diff <= expected_max_diff + ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" + + def test_from_save_pretrained_dtype(self): + """Test save_pretrained/from_pretrained preserves dtype correctly.""" + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + if torch_device == "mps" and dtype == torch.bfloat16: + continue + with tempfile.TemporaryDirectory() as tmpdirname: + model.to(dtype) + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype) + assert new_model.dtype == dtype + if ( + hasattr(self.model_class, "_keep_in_fp32_modules") + and self.model_class._keep_in_fp32_modules is None + ): + new_model = self.model_class.from_pretrained( + tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype + ) + assert new_model.dtype == dtype + + def test_determinism(self, expected_max_diff=1e-5): + """Test that model outputs are deterministic across multiple forward passes.""" + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + first = model(**self.get_dummy_inputs()) + if isinstance(first, dict): + first = first.to_tuple()[0] + + second = model(**self.get_dummy_inputs()) + if isinstance(second, dict): + second = second.to_tuple()[0] + + # Remove NaN values and compute max difference + first_flat = first.flatten() + second_flat = second.flatten() + + # Filter out NaN values + mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat)) + first_filtered = first_flat[mask] + second_filtered = second_flat[mask] + + max_diff = torch.abs(first_filtered - second_filtered).max().item() + assert ( + max_diff <= expected_max_diff + ), f"Model outputs are not deterministic. 
Max diff: {max_diff}, expected: {expected_max_diff}"
+
+    def test_output(self, expected_output_shape=None):
+        """Test that model produces output with expected shape."""
+        model = self.model_class(**self.get_init_dict())
+        model.to(torch_device)
+        model.eval()
+
+        inputs_dict = self.get_dummy_inputs()
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+        if isinstance(output, dict):
+            output = output.to_tuple()[0]
+
+        assert output is not None, "Model output is None"
+        if expected_output_shape is None:
+            # Fall back to the main input's shape when no explicit shape is given
+            expected_output_shape = inputs_dict[self.main_input_name].shape
+        assert (
+            output.shape == expected_output_shape
+        ), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
+
+    def test_model_from_pretrained(self):
+        """Test that model loaded from pretrained matches original model."""
+        model = self.model_class(**self.get_init_dict())
+        model.to(torch_device)
+        model.eval()
+
+        # test that the model can be loaded from the saved config/weights
+        # and has all the expected parameter shapes
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname, safe_serialization=False)
+            new_model = self.model_class.from_pretrained(tmpdirname)
+            new_model.to(torch_device)
+            new_model.eval()
+
+        # check that all parameter shapes are the same
+        for param_name in model.state_dict().keys():
+            param_1 = model.state_dict()[param_name]
+            param_2 = new_model.state_dict()[param_name]
+            assert (
+                param_1.shape == param_2.shape
+            ), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
+
+        with torch.no_grad():
+            output_1 = model(**self.get_dummy_inputs())
+
+            if isinstance(output_1, dict):
+                output_1 = output_1.to_tuple()[0]
+
+            output_2 = new_model(**self.get_dummy_inputs())
+
+            if isinstance(output_2, dict):
+                output_2 = output_2.to_tuple()[0]
+
+        assert (
+            output_1.shape == output_2.shape
+        ), f"Output shape mismatch. Original: {output_1.shape}, loaded: {output_2.shape}"
+
+    def test_outputs_equivalence(self):
+        """Test that dict and tuple outputs are equivalent."""
+
+        def set_nan_tensor_to_zero(t):
+            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
+            # Track progress in https://github.com/pytorch/pytorch/issues/77764
+            device = t.device
+            if device.type == "mps":
+                t = t.to("cpu")
+            t[t != t] = 0
+            return t.to(device)
+
+        def recursive_check(tuple_object, dict_object):
+            if isinstance(tuple_object, (List, Tuple)):
+                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
+                    recursive_check(tuple_iterable_value, dict_iterable_value)
+            elif isinstance(tuple_object, Dict):
+                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
+                    recursive_check(tuple_iterable_value, dict_iterable_value)
+            elif tuple_object is None:
+                return
+            else:
+                assert torch.allclose(
+                    set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
+                ), (
+                    "Tuple and dict output are not equal. Difference:"
+                    f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+                    f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has"
+                    f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}."
+                )
+
+        model = self.model_class(**self.get_init_dict())
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            outputs_dict = model(**self.get_dummy_inputs())
+            outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
+
+        recursive_check(outputs_tuple, outputs_dict)
+
+    def test_model_config_to_json_string(self):
+        """Test model config can be serialized to JSON string."""
+        model = self.model_class(**self.get_init_dict())
+
+        json_string = model.config.to_json_string()
+        assert isinstance(json_string, str), "Config to_json_string should return a string"
+        assert len(json_string) > 0, "JSON string should not be empty"
+
+    def test_keep_in_fp32_modules(self):
+        r"""
+        A simple test to check that the modules listed in `_keep_in_fp32_modules` are kept in fp32 when the
+        model is loaded in fp16/bf16.
+        """
+        # `_keep_in_fp32_modules` may exist on the class but be None, so check its value rather than hasattr
+        if not getattr(self.model_class, "_keep_in_fp32_modules", None):
+            pytest.skip("Model does not set _keep_in_fp32_modules")
+
+        fp32_modules = self.model_class._keep_in_fp32_modules
+
+        for torch_dtype in [torch.bfloat16, torch.float16]:
+            model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, torch_dtype=torch_dtype).to(
+                torch_device
+            )
+            for name, param in model.named_parameters():
+                if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
+                    assert param.dtype == torch.float32
+                else:
+                    assert param.dtype == torch_dtype
diff --git a/tests/models/testing_utils/compile.py b/tests/models/testing_utils/compile.py
new file mode 100644
index 0000000000..2c083176c5
--- /dev/null
+++ b/tests/models/testing_utils/compile.py
@@ -0,0 +1,166 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+import tempfile
+
+import pytest
+import torch
+
+from ...testing_utils import (
+    backend_empty_cache,
+    is_torch_compile,
+    require_accelerator,
+    require_torch_version_greater,
+    torch_device,
+)
+
+
+@is_torch_compile
+@require_accelerator
+@require_torch_version_greater("2.7.1")
+class TorchCompileTesterMixin:
+    """
+    Mixin class for testing torch.compile functionality on models.
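+
+    A minimal usage sketch (`DummyTransformer2DModel` and its init/input dicts are placeholders, not part of
+    this PR; assumes the test module imports `torch` and `torch_device`):
+
+        class DummyTransformerCompileTests(TorchCompileTesterMixin):
+            model_class = DummyTransformer2DModel
+
+            def get_init_dict(self):
+                return {"in_channels": 4, "num_layers": 1}
+
+            def get_dummy_inputs(self):
+                return {"hidden_states": torch.randn(1, 4, 16, device=torch_device)}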
+ + Expected class attributes to be set by subclasses: + - model_class: The model class to test + - different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + """ + + different_shapes_for_compilation = None + + def setup_method(self): + """Setup before each test method.""" + torch.compiler.reset() + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + """Cleanup after each test method.""" + torch.compiler.reset() + gc.collect() + backend_empty_cache(torch_device) + + def test_torch_compile_recompilation_and_graph_break(self): + """Test that model compiles without graph breaks and doesn't recompile unnecessarily.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict).to(torch_device) + model.eval() + model = torch.compile(model, fullgraph=True) + + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(error_on_recompile=True), + torch.no_grad(), + ): + _ = model(**inputs_dict) + _ = model(**inputs_dict) + + def test_torch_compile_repeated_blocks(self): + """Test compilation of repeated blocks if model supports it.""" + if self.model_class._repeated_blocks is None: + pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.") + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict).to(torch_device) + model.eval() + model.compile_repeated_blocks(fullgraph=True) + + recompile_limit = 1 + if self.model_class.__name__ == "UNet2DConditionModel": + recompile_limit = 2 + + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(recompile_limit=recompile_limit), + torch.no_grad(), + ): + _ = model(**inputs_dict) + _ = model(**inputs_dict) + + def test_compile_with_group_offloading(self): + """Test that compilation works with group offloading enabled.""" + if not self.model_class._supports_group_offloading: + pytest.skip("Model does not support group offloading.") + + torch._dynamo.config.cache_size_limit = 10000 + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) + model.eval() + + group_offload_kwargs = { + "onload_device": torch_device, + "offload_device": "cpu", + "offload_type": "block_level", + "num_blocks_per_group": 1, + "use_stream": True, + "non_blocking": True, + } + model.enable_group_offload(**group_offload_kwargs) + model.compile() + + with torch.no_grad(): + _ = model(**inputs_dict) + _ = model(**inputs_dict) + + def test_compile_on_different_shapes(self): + """Test dynamic compilation on different input shapes.""" + if self.different_shapes_for_compilation is None: + pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") + torch.fx.experimental._config.use_duck_shape = False + + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + model = torch.compile(model, fullgraph=True, dynamic=True) + + for height, width in self.different_shapes_for_compilation: + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): + inputs_dict = self.get_dummy_inputs(height=height, width=width) + _ = model(**inputs_dict) + + def 
test_compile_works_with_aot(self): + """Test that model works with ahead-of-time compilation and packaging.""" + from torch._inductor.package import load_package + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict).to(torch_device) + exported_model = torch.export.export(model, args=(), kwargs=inputs_dict) + + with tempfile.TemporaryDirectory() as tmpdir: + package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2") + _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path) + assert os.path.exists(package_path), f"Package file not created at {package_path}" + loaded_binary = load_package(package_path, run_single_threaded=True) + + model.forward = loaded_binary + + with torch.no_grad(): + _ = model(**inputs_dict) + _ = model(**inputs_dict) diff --git a/tests/models/testing_utils/hub.py b/tests/models/testing_utils/hub.py new file mode 100644 index 0000000000..cbaded9fff --- /dev/null +++ b/tests/models/testing_utils/hub.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import uuid + +import pytest +import torch +from huggingface_hub import ModelCard, delete_repo +from huggingface_hub.utils import is_jinja_available + +from ...others.test_utils import TOKEN, USER, is_staging_test + + +@is_staging_test +class ModelPushToHubTesterMixin: + """ + Mixin class for testing push_to_hub functionality on models. 
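+
+    A minimal usage sketch (`DummyTransformer2DModel` and its init dict are placeholders, not part of this PR):
+
+        class DummyTransformerPushToHubTests(ModelPushToHubTesterMixin):
+            model_class = DummyTransformer2DModel
+
+            def get_init_dict(self):
+                return {"in_channels": 4, "num_layers": 1}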
+ + Expected class attributes to be set by subclasses: + - model_class: The model class to test + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + """ + + identifier = uuid.uuid4() + repo_id = f"test-model-{identifier}" + org_repo_id = f"valid_org/{repo_id}-org" + + def test_push_to_hub(self): + """Test pushing model to hub and loading it back.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.push_to_hub(self.repo_id, token=TOKEN) + + new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}") + for p1, p2 in zip(model.parameters(), new_model.parameters()): + assert torch.equal(p1, p2), "Parameters don't match after push_to_hub and from_pretrained" + + # Reset repo + delete_repo(token=TOKEN, repo_id=self.repo_id) + + # Push to hub via save_pretrained + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN) + + new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}") + for p1, p2 in zip(model.parameters(), new_model.parameters()): + assert torch.equal( + p1, p2 + ), "Parameters don't match after save_pretrained with push_to_hub and from_pretrained" + + # Reset repo + delete_repo(self.repo_id, token=TOKEN) + + def test_push_to_hub_in_organization(self): + """Test pushing model to hub in organization namespace.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.push_to_hub(self.org_repo_id, token=TOKEN) + + new_model = self.model_class.from_pretrained(self.org_repo_id) + for p1, p2 in zip(model.parameters(), new_model.parameters()): + assert torch.equal(p1, p2), "Parameters don't match after push_to_hub to org and from_pretrained" + + # Reset repo + delete_repo(token=TOKEN, repo_id=self.org_repo_id) + + # Push to hub via save_pretrained + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id) + + new_model = self.model_class.from_pretrained(self.org_repo_id) + for p1, p2 in zip(model.parameters(), new_model.parameters()): + assert torch.equal( + p1, p2 + ), "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained" + + # Reset repo + delete_repo(self.org_repo_id, token=TOKEN) + + def test_push_to_hub_library_name(self): + """Test that library_name in model card is set to 'diffusers'.""" + if not is_jinja_available(): + pytest.skip("Model card tests cannot be performed without Jinja installed.") + + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.push_to_hub(self.repo_id, token=TOKEN) + + model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data + assert ( + model_card.library_name == "diffusers" + ), f"Expected library_name 'diffusers', got {model_card.library_name}" + + # Reset repo + delete_repo(self.repo_id, token=TOKEN) diff --git a/tests/models/testing_utils/ip_adapter.py b/tests/models/testing_utils/ip_adapter.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/testing_utils/offloading.py b/tests/models/testing_utils/offloading.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/testing_utils/single_file.py b/tests/models/testing_utils/single_file.py new file mode 100644 index 
0000000000..561dc3c567
--- /dev/null
+++ b/tests/models/testing_utils/single_file.py
@@ -0,0 +1,252 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+
+import pytest
+import torch
+from huggingface_hub import hf_hub_download, snapshot_download
+
+from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+
+from ...testing_utils import (
+    backend_empty_cache,
+    is_single_file,
+    nightly,
+    require_torch_accelerator,
+    torch_device,
+)
+
+
+def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
+    """Download a single file checkpoint from the Hub to a temporary directory."""
+    path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
+    return path
+
+
+def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
+    """Download diffusers config files (excluding weights) from a repository."""
+    path = snapshot_download(
+        pretrained_model_name_or_path,
+        ignore_patterns=[
+            "**/*.ckpt",
+            "*.ckpt",
+            "**/*.bin",
+            "*.bin",
+            "**/*.pt",
+            "*.pt",
+            "**/*.safetensors",
+            "*.safetensors",
+        ],
+        allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
+        local_dir=tmpdir,
+    )
+    return path
+
+
+@nightly
+@require_torch_accelerator
+@is_single_file
class SingleFileTesterMixin:
+    """
+    Mixin class for testing single file loading for models.
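+
+    A minimal usage sketch (the model class, repo id, and checkpoint URL below are placeholders, not part of
+    this PR):
+
+        class DummyTransformerSingleFileTests(SingleFileTesterMixin):
+            model_class = DummyTransformer2DModel
+            pretrained_model_name_or_path = "org/dummy-model"
+            ckpt_path = "https://huggingface.co/org/dummy-model/blob/main/dummy.safetensors"
+            subfolder = "transformer"
+            torch_dtype = torch.float16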
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - ckpt_path: Path or Hub path to the single file checkpoint + - subfolder: (Optional) Subfolder within the repo + - torch_dtype: (Optional) torch dtype to use for testing + """ + + pretrained_model_name_or_path = None + ckpt_path = None + + def setup_method(self): + """Setup before each test method.""" + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + """Cleanup after each test method.""" + gc.collect() + backend_empty_cache(torch_device) + + def test_single_file_model_config(self): + """Test that config matches between pretrained and single file loading.""" + pretrained_kwargs = {} + single_file_kwargs = {} + + pretrained_kwargs["device"] = torch_device + single_file_kwargs["device"] = torch_device + + if hasattr(self, "subfolder") and self.subfolder: + pretrained_kwargs["subfolder"] = self.subfolder + + if hasattr(self, "torch_dtype") and self.torch_dtype: + pretrained_kwargs["torch_dtype"] = self.torch_dtype + single_file_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs) + model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs) + + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert model.config[param_name] == param_value, ( + f"{param_name} differs between pretrained loading and single file loading: " + f"pretrained={model.config[param_name]}, single_file={param_value}" + ) + + def test_single_file_model_parameters(self): + """Test that parameters match between pretrained and single file loading.""" + pretrained_kwargs = {} + single_file_kwargs = {} + + pretrained_kwargs["device"] = torch_device + single_file_kwargs["device"] = torch_device + + if hasattr(self, "subfolder") and self.subfolder: + pretrained_kwargs["subfolder"] = self.subfolder + + if hasattr(self, "torch_dtype") and self.torch_dtype: + pretrained_kwargs["torch_dtype"] = self.torch_dtype + single_file_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs) + model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs) + + state_dict = model.state_dict() + state_dict_single_file = model_single_file.state_dict() + + assert set(state_dict.keys()) == set(state_dict_single_file.keys()), ( + "Model parameters keys differ between pretrained and single file loading. " + f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. 
" + f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}" + ) + + for key in state_dict.keys(): + param = state_dict[key] + param_single_file = state_dict_single_file[key] + + assert param.shape == param_single_file.shape, ( + f"Parameter shape mismatch for {key}: " + f"pretrained {param.shape} vs single file {param_single_file.shape}" + ) + + assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), ( + f"Parameter values differ for {key}: " + f"max difference {torch.max(torch.abs(param - param_single_file)).item()}" + ) + + def test_single_file_loading_local_files_only(self): + """Test single file loading with local_files_only=True.""" + single_file_kwargs = {} + + if hasattr(self, "torch_dtype") and self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + + with tempfile.TemporaryDirectory() as tmpdir: + pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir) + + model_single_file = self.model_class.from_single_file( + local_ckpt_path, local_files_only=True, **single_file_kwargs + ) + + assert model_single_file is not None, "Failed to load model with local_files_only=True" + + def test_single_file_loading_with_diffusers_config(self): + """Test single file loading with diffusers config.""" + single_file_kwargs = {} + + if hasattr(self, "torch_dtype") and self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + + # Load with config parameter + model_single_file = self.model_class.from_single_file( + self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs + ) + + # Load pretrained for comparison + pretrained_kwargs = {} + if hasattr(self, "subfolder") and self.subfolder: + pretrained_kwargs["subfolder"] = self.subfolder + if hasattr(self, "torch_dtype") and self.torch_dtype: + pretrained_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs) + + # Compare configs + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert ( + model.config[param_name] == param_value + ), f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}" + + def test_single_file_loading_with_diffusers_config_local_files_only(self): + """Test single file loading with diffusers config and local_files_only=True.""" + single_file_kwargs = {} + + if hasattr(self, "torch_dtype") and self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + + with tempfile.TemporaryDirectory() as tmpdir: + pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir) + local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, tmpdir) + + model_single_file = self.model_class.from_single_file( + local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs + ) + + assert model_single_file is not None, "Failed to load model with config and local_files_only=True" + + def test_single_file_loading_dtype(self): + """Test single file loading with different dtypes.""" + for dtype in [torch.float32, torch.float16]: + if 
torch_device == "mps" and dtype == torch.bfloat16: + continue + + model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype) + + assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}" + + # Cleanup + del model_single_file + gc.collect() + backend_empty_cache(torch_device) + + def test_checkpoint_variant_loading(self): + """Test loading checkpoints with alternate keys/variants if provided.""" + if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths: + return + + for ckpt_path in self.alternate_ckpt_paths: + backend_empty_cache(torch_device) + + single_file_kwargs = {} + if hasattr(self, "torch_dtype") and self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs) + + assert model is not None, f"Failed to load checkpoint from {ckpt_path}" + + del model + gc.collect() + backend_empty_cache(torch_device) diff --git a/tests/models/testing_utils/training.py b/tests/models/testing_utils/training.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/transformers/test_models_transformer_flux_.py b/tests/models/transformers/test_models_transformer_flux_.py new file mode 100644 index 0000000000..a67218548c --- /dev/null +++ b/tests/models/transformers/test_models_transformer_flux_.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
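+
+# The classes below wire FluxTransformer2DModel into the shared tester mixins introduced in this PR
+# (ModelTesterMixin, TorchCompileTesterMixin, SingleFileTesterMixin, LoraHotSwappingForModelTesterMixin).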
+
+import torch
+
+from diffusers import FluxTransformer2DModel
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
+from ..testing_utils.common import ModelTesterMixin
+from ..testing_utils.compile import TorchCompileTesterMixin
+from ..testing_utils.single_file import SingleFileTesterMixin
+
+
+enable_full_determinism()
+
+
+class FluxTransformerTesterConfig:
+    model_class = FluxTransformer2DModel
+
+    def get_init_dict(self):
+        """Return Flux model initialization arguments."""
+        return {
+            "patch_size": 1,
+            "in_channels": 4,
+            "num_layers": 1,
+            "num_single_layers": 1,
+            "attention_head_dim": 16,
+            "num_attention_heads": 2,
+            "joint_attention_dim": 32,
+            "pooled_projection_dim": 32,
+            "axes_dims_rope": [4, 4, 8],
+        }
+
+    def get_dummy_inputs(self):
+        batch_size = 1
+        height = width = 4
+        num_latent_channels = 4
+        num_image_channels = 3
+        sequence_length = 24
+        embedding_dim = 32  # must match joint_attention_dim / pooled_projection_dim in get_init_dict
+
+        # place inputs on the test device so they match the model after model.to(torch_device)
+        return {
+            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
+            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
+            "pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
+            "img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
+            "txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
+            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
+        }
+
+    @property
+    def input_shape(self):
+        return (16, 4)
+
+    @property
+    def output_shape(self):
+        return (16, 4)
+
+
+class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
+    def test_deprecated_inputs_img_txt_ids_3d(self):
+        """Test that deprecated 3D img_ids and txt_ids still work."""
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output_1 = model(**inputs_dict).to_tuple()[0]
+
+        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
+        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
+        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
+
+        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
+        assert image_ids_3d.ndim == 3, "image_ids_3d should be a 3d tensor"
+
+        inputs_dict["txt_ids"] = text_ids_3d
+        inputs_dict["img_ids"] = image_ids_3d
+
+        with torch.no_grad():
+            output_2 = model(**inputs_dict).to_tuple()[0]
+
+        assert output_1.shape == output_2.shape
+        assert torch.allclose(output_1, output_2, atol=1e-5), (
+            "output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
+            "is not equal to the output with the same inputs passed as 2d tensors"
+        )
+
+
+class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
+    ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
+    alternate_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
+    pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
+    subfolder = "transformer"
+
+
+class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
+    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+
+    def get_dummy_inputs(self, height=4, width=4):
+        """Override to support dynamic height/width for compilation tests."""
+        batch_size = 1
+        num_latent_channels = 4
+        num_image_channels = 3
+        sequence_length = 24
+        embedding_dim = 32  # must match joint_attention_dim / pooled_projection_dim in get_init_dict
+
+        return {
+            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
+            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
+            "pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
+            "img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
+            "txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
+            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
+        }
+
+
+class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
+    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+
+    def get_dummy_inputs(self, height=4, width=4):
+        """Override to support dynamic height/width for LoRA hotswap tests."""
+        batch_size = 1
+        num_latent_channels = 4
+        num_image_channels = 3
+        sequence_length = 48
+        embedding_dim = 32
+
+        return {
+            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
+            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
+            "pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
+            "img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
+            "txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
+            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
+        }