1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

improve test

This commit is contained in:
Aryan
2025-07-14 07:46:32 +02:00
parent ecabd2a46e
commit ff21b7fe8b
2 changed files with 19 additions and 7 deletions

View File

@@ -28,8 +28,7 @@ from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
check_qkv_fused_layers_exist,
)
@@ -171,12 +170,10 @@ class FluxPipelineFastTests(
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer), (
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
self.assertTrue(
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images

View File

@@ -37,6 +37,7 @@ from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
@@ -98,6 +99,20 @@ def check_qkv_fusion_processors_exist(model):
return all(p.startswith("Fused") for p in proc_names)
def check_qkv_fused_layers_exist(model, layer_names):
is_fused_submodules = []
for submodule in model.modules():
if not isinstance(submodule, AttentionModuleMixin):
continue
is_fused_attribute_set = submodule.fused_projections
is_fused_layer = True
for layer in layer_names:
is_fused_layer = is_fused_layer and getattr(submodule, layer, None) is not None
is_fused = is_fused_attribute_set and is_fused_layer
is_fused_submodules.append(is_fused)
return all(is_fused_submodules)
class SDFunctionTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.