diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py index 58cbd7ada5..60f6de6fd3 100644 --- a/tests/single_file/single_file_testing_utils.py +++ b/tests/single_file/single_file_testing_utils.py @@ -119,6 +119,24 @@ class SingleFileModelTesterMixin: f"max difference {torch.max(torch.abs(param - param_single_file)).item()}" ) + def test_checkpoint_altered_keys_loading(self): + # Test loading with checkpoints that have altered keys + if not hasattr(self, "alternate_keys_ckpt_paths") or not self.alternate_keys_ckpt_paths: + return + + for ckpt_path in self.alternate_keys_ckpt_paths: + backend_empty_cache(torch_device) + + single_file_kwargs = {} + if hasattr(self, "torch_dtype") and self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs) + + del model + gc.collect() + backend_empty_cache(torch_device) + class SDSingleFileTesterMixin: single_file_kwargs = {} diff --git a/tests/single_file/test_lumina2_transformer.py b/tests/single_file/test_lumina2_transformer.py index 2984776df5..67de1107ba 100644 --- a/tests/single_file/test_lumina2_transformer.py +++ b/tests/single_file/test_lumina2_transformer.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc import unittest from diffusers import ( @@ -21,9 +20,7 @@ from diffusers import ( ) from ..testing_utils import ( - backend_empty_cache, enable_full_determinism, - torch_device, ) from .single_file_testing_utils import SingleFileModelTesterMixin @@ -40,12 +37,3 @@ class Lumina2Transformer2DModelSingleFileTests(SingleFileModelTesterMixin, unitt repo_id = "Alpha-VLLM/Lumina-Image-2.0" subfolder = "transformer" - - def test_checkpoint_loading(self): - for ckpt_path in self.alternate_keys_ckpt_paths: - backend_empty_cache(torch_device) - model = self.model_class.from_single_file(ckpt_path) - - del model - gc.collect() - backend_empty_cache(torch_device) diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py index f36a79660d..457b9fa9cd 100644 --- a/tests/single_file/test_model_flux_transformer_single_file.py +++ b/tests/single_file/test_model_flux_transformer_single_file.py @@ -37,18 +37,8 @@ class FluxTransformer2DModelSingleFileTests(SingleFileModelTesterMixin, unittest alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"] repo_id = "black-forest-labs/FLUX.1-dev" - subfolder = "transformer" - def test_checkpoint_loading(self): - for ckpt_path in self.alternate_keys_ckpt_paths: - backend_empty_cache(torch_device) - model = self.model_class.from_single_file(ckpt_path) - - del model - gc.collect() - backend_empty_cache(torch_device) - def test_device_map_cuda(self): backend_empty_cache(torch_device) model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda") diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py index b0577e1d61..85296ec1da 100644 --- a/tests/single_file/test_sana_transformer.py +++ b/tests/single_file/test_sana_transformer.py @@ -1,4 +1,3 @@ -import gc import unittest from diffusers import ( @@ -6,9 +5,7 @@ from diffusers import ( ) from ..testing_utils import ( - backend_empty_cache, enable_full_determinism, - torch_device, ) from .single_file_testing_utils import SingleFileModelTesterMixin @@ -27,12 +24,3 @@ class SanaTransformer2DModelSingleFileTests(SingleFileModelTesterMixin, unittest repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers" subfolder = "transformer" - - def test_checkpoint_loading(self): - for ckpt_path in self.alternate_keys_ckpt_paths: - backend_empty_cache(torch_device) - model = self.model_class.from_single_file(ckpt_path) - - del model - gc.collect() - backend_empty_cache(torch_device)