From effe4b9784d9093a2fb70a4063888fc4e4655ce9 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 27 Jun 2024 01:54:27 +0530 Subject: [PATCH] Update xformers SD3 test (#8712) update --- tests/models/test_modeling_common.py | 4 ++++ tests/pipelines/controlnet_sd3/test_controlnet_sd3.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index f0e01c40a3..ac356d4c52 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -382,6 +382,10 @@ class ModelTesterMixin: # If not has `set_attn_processor`, skip test return + if not hasattr(model, "set_default_attn_processor"): + # If not has `set_attn_processor`, skip test + return + model.set_default_attn_processor() assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values()) with torch.no_grad(): diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index 824c1de1b9..af5b38fafa 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -190,6 +190,10 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 ), f"Expected: {expected_slice}, got: {image_slice.flatten()}" + @unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention") + def test_xformers_attention_forwardGenerator_pass(self): + pass + @slow @require_torch_gpu