From 0cda91d467636458bf77beb69cfa6ab62ceab324 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 15 Jul 2025 07:51:58 +0200
Subject: [PATCH] fix chroma qkv fusion test

---
 tests/pipelines/chroma/test_pipeline_chroma.py | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py
index fc5749f96c..5121a2b52d 100644
--- a/tests/pipelines/chroma/test_pipeline_chroma.py
+++ b/tests/pipelines/chroma/test_pipeline_chroma.py
@@ -7,12 +7,7 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
 from diffusers.utils.testing_utils import torch_device
 
-from ..test_pipelines_common import (
-    FluxIPAdapterTesterMixin,
-    PipelineTesterMixin,
-    check_qkv_fusion_matches_attn_procs_length,
-    check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
 
 
 class ChromaPipelineFastTests(
@@ -126,12 +121,10 @@ class ChromaPipelineFastTests(
         # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
         # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        self.assertTrue(
+            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
         )
-        assert check_qkv_fusion_matches_attn_procs_length(
-            pipe.transformer, pipe.transformer.original_attn_processors
-        ), "Something wrong with the attention processors concerning the fused QKV projections."
 
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
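
Note (not part of the patch): the replacement helper `check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"])` asserts that the fused QKV projection layers are actually present on the transformer's attention modules after `fuse_qkv_projections()` runs, instead of inspecting attention processors. The snippet below is a minimal sketch of what such a check could look like; the real helper lives in `tests/pipelines/test_pipelines_common.py`, and the attribute-scanning heuristic and `_sketch` name here are assumptions for illustration, not the diffusers implementation.

    # Illustrative sketch only: walk the model's modules, treat anything with a
    # `to_q` projection as an attention block, and confirm the fused layer
    # (e.g. `to_qkv`) exists on every one of them.
    import torch.nn as nn


    def check_qkv_fused_layers_exist_sketch(model: nn.Module, layer_names: list[str]) -> bool:
        found_attention_block = False
        for module in model.modules():
            # Heuristic (assumption): attention blocks expose a `to_q` projection.
            if hasattr(module, "to_q"):
                found_attention_block = True
                if not all(getattr(module, name, None) is not None for name in layer_names):
                    return False
        return found_attention_block

With the fused pipeline from the test, the sketch would be called as `check_qkv_fused_layers_exist_sketch(pipe.transformer, ["to_qkv"])`, mirroring the call the patch introduces.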