diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 475824a855..67a6a29e90 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -2193,6 +2193,8 @@ class TorchCompileTesterMixin: recompile_limit = 1 if self.model_class.__name__ == "UNet2DConditionModel": recompile_limit = 2 + elif self.model_class.__name__ == "ZImageTransformer2DModel": + recompile_limit = 3 with ( torch._inductor.utils.fresh_inductor_cache(), @@ -2294,7 +2296,6 @@ class LoraHotSwappingForModelTesterMixin: backend_empty_cache(torch_device) def get_lora_config(self, lora_rank, lora_alpha, target_modules): - # from diffusers test_models_unet_2d_condition.py from peft import LoraConfig lora_config = LoraConfig(