diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py
index bbdd1c3f68..1c87871941 100644
--- a/src/diffusers/hooks/pyramid_attention_broadcast.py
+++ b/src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -18,6 +18,7 @@ from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
 
+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
 from .hooks import HookRegistry, ModelHook
@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
         config.spatial_attention_block_skip_range = 2
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
             # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
             # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
diff --git a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
index 02b20527b2..d518e1b7b8 100644
--- a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
+++ b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
@@ -8,12 +8,7 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
 from diffusers.utils.testing_utils import floats_tensor, torch_device
 
-from ..test_pipelines_common import (
-    FluxIPAdapterTesterMixin,
-    PipelineTesterMixin,
-    check_qkv_fusion_matches_attn_procs_length,
-    check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
 
 
 class ChromaImg2ImgPipelineFastTests(
@@ -129,12 +124,10 @@ class ChromaImg2ImgPipelineFastTests(
         # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
         # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        self.assertTrue(
+            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
         )
-        assert check_qkv_fusion_matches_attn_procs_length(
-            pipe.transformer, pipe.transformer.original_attn_processors
-        ), "Something wrong with the attention processors concerning the fused QKV projections."
 
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
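
For context (not part of the patch): the function touched in the first hunk, `apply_pyramid_attention_broadcast`, is the public entry point for enabling PAB, and with this change it now also picks up attention submodules that inherit from `AttentionModuleMixin`. Below is a minimal usage sketch based on the diffusers PAB documentation; the CogVideoX pipeline and the specific config values are illustrative assumptions, not something this diff introduces.

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

# Illustrative pipeline choice; any transformer-based pipeline with supported attention modules works similarly.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,              # reuse cached attention states every other block
    spatial_attention_timestep_skip_range=(100, 800),  # only skip within this denoising-timestep window
    current_timestep_callback=lambda: pipe.current_timestep,
)

# With this patch, submodules that inherit from AttentionModuleMixin are also matched
# by the isinstance check inside apply_pyramid_attention_broadcast.
apply_pyramid_attention_broadcast(pipe.transformer, config)
```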