1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

fix more tests

This commit is contained in:
Aryan
2025-07-15 08:30:26 +02:00
parent bc64f12c98
commit a0b276da53
2 changed files with 6 additions and 12 deletions

View File

@@ -18,6 +18,7 @@ from typing import Any, Callable, Optional, Tuple, Union
import torch
from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook
@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
config.spatial_attention_block_skip_range = 2
for name, submodule in module.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES):
if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
# PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.

View File

@@ -8,12 +8,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import floats_tensor, torch_device
from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
class ChromaImg2ImgPipelineFastTests(
@@ -129,12 +124,10 @@ class ChromaImg2ImgPipelineFastTests(
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer), (
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
self.assertTrue(
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images