From c1415207145b1417e0839fd39862b7eeb38bdaea Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 09:49:17 +0200 Subject: [PATCH] fight more tests --- .../models/transformers/transformer_flux.py | 2 +- .../test_controlnet_flux_img2img.py | 14 ++++---------- tests/pipelines/flux/test_pipeline_flux_control.py | 14 ++++---------- .../flux/test_pipeline_flux_control_inpaint.py | 14 ++++---------- 4 files changed, 13 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 7898bdb1f0..706438569f 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -242,7 +242,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module): key = apply_rotary_emb(key, image_rotary_emb) hidden_states = torch.nn.functional.scaled_dot_product_attention( - query, key, value, dropout_p=0.0, is_causal=False + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py index 8d63619c40..ab4cf32734 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py @@ -16,11 +16,7 @@ from diffusers.utils.testing_utils import ( ) from diffusers.utils.torch_utils import randn_tensor -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): @@ -170,12 +166,10 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi original_image_slice = image[0, -3:, -3:, -1] pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py index d8d0774e1e..42283da6fd 100644 --- a/tests/pipelines/flux/test_pipeline_flux_control.py +++ b/tests/pipelines/flux/test_pipeline_flux_control.py @@ -8,11 +8,7 @@ from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPToken from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel from diffusers.utils.testing_utils import torch_device -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin): @@ -140,12 +136,10 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py index a2f7c91710..0abd08e373 100644 --- a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py +++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py @@ -15,11 +15,7 @@ from diffusers.utils.testing_utils import ( torch_device, ) -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): @@ -134,12 +130,10 @@ class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images