From f4fa3beee7f49b80ce7a58f9c8002f43299175c9 Mon Sep 17 00:00:00 2001 From: Seokhyeon Jeong Date: Wed, 14 May 2025 14:56:12 +0900 Subject: [PATCH] [tests] Add torch.compile test for UNet2DConditionModel (#11537) Co-authored-by: Sayak Paul --- tests/models/unets/test_models_unet_2d_condition.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 94a5d641a7..24d944bbf9 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -53,7 +53,12 @@ from diffusers.utils.testing_utils import ( torch_device, ) -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ( + LoraHotSwappingForModelTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, + UNetTesterMixin, +) if is_peft_available(): @@ -351,7 +356,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): class UNet2DConditionModelTests( - ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase + ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase ): model_class = UNet2DConditionModel main_input_name = "sample"