From ff21b7fe8b0e3d2c0bbd0341c8273a1f4bb62a7c Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Jul 2025 07:46:32 +0200 Subject: [PATCH] improve test --- tests/pipelines/flux/test_pipeline_flux.py | 11 ++++------- tests/pipelines/test_pipelines_common.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 0df0e028ff..4541521c89 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -28,8 +28,7 @@ from ..test_pipelines_common import ( FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, + check_qkv_fused_layers_exist, ) @@ -171,12 +170,10 @@ class FluxPipelineFastTests( # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 13c25ccaa4..387eb6a614 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -37,6 +37,7 @@ from diffusers.hooks.first_block_cache import FirstBlockCacheConfig from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin +from diffusers.models.attention import AttentionModuleMixin from diffusers.models.attention_processor import AttnProcessor from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel @@ -98,6 +99,20 @@ def check_qkv_fusion_processors_exist(model): return all(p.startswith("Fused") for p in proc_names) +def check_qkv_fused_layers_exist(model, layer_names): + is_fused_submodules = [] + for submodule in model.modules(): + if not isinstance(submodule, AttentionModuleMixin): + continue + is_fused_attribute_set = submodule.fused_projections + is_fused_layer = True + for layer in layer_names: + is_fused_layer = is_fused_layer and getattr(submodule, layer, None) is not None + is_fused = is_fused_attribute_set and is_fused_layer + is_fused_submodules.append(is_fused) + return all(is_fused_submodules) + + class SDFunctionTesterMixin: """ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.