1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

fix chroma qkv fusion test

This commit is contained in:
Aryan
2025-07-15 07:51:58 +02:00
parent 17b678fc6f
commit 0cda91d467

View File

@@ -7,12 +7,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import torch_device
from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
class ChromaPipelineFastTests(
@@ -126,12 +121,10 @@ class ChromaPipelineFastTests(
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer), (
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
self.assertTrue(
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images