mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
improve test
This commit is contained in:
@@ -28,8 +28,7 @@ from ..test_pipelines_common import (
|
||||
FluxIPAdapterTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
check_qkv_fusion_processors_exist,
|
||||
check_qkv_fused_layers_exist,
|
||||
)
|
||||
|
||||
|
||||
@@ -171,12 +170,10 @@ class FluxPipelineFastTests(
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
self.assertTrue(
|
||||
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
|
||||
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
|
||||
)
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
|
||||
@@ -37,6 +37,7 @@ from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
|
||||
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
|
||||
from diffusers.models.attention import AttentionModuleMixin
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
|
||||
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
|
||||
@@ -98,6 +99,20 @@ def check_qkv_fusion_processors_exist(model):
|
||||
return all(p.startswith("Fused") for p in proc_names)
|
||||
|
||||
|
||||
def check_qkv_fused_layers_exist(model, layer_names):
|
||||
is_fused_submodules = []
|
||||
for submodule in model.modules():
|
||||
if not isinstance(submodule, AttentionModuleMixin):
|
||||
continue
|
||||
is_fused_attribute_set = submodule.fused_projections
|
||||
is_fused_layer = True
|
||||
for layer in layer_names:
|
||||
is_fused_layer = is_fused_layer and getattr(submodule, layer, None) is not None
|
||||
is_fused = is_fused_attribute_set and is_fused_layer
|
||||
is_fused_submodules.append(is_fused)
|
||||
return all(is_fused_submodules)
|
||||
|
||||
|
||||
class SDFunctionTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
|
||||
|
||||
Reference in New Issue
Block a user