diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py
index 1be5e14362..a6c250b50c 100644
--- a/src/diffusers/hooks/faster_cache.py
+++ b/src/diffusers/hooks/faster_cache.py
@@ -18,6 +18,7 @@ from typing import Any, Callable, List, Optional, Tuple
 
 import torch
 
+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..utils import logging
@@ -567,7 +568,7 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
     _apply_faster_cache_on_denoiser(module, config)
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             continue
         if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
             _apply_faster_cache_on_attention_class(name, submodule, config)
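
The change broadens the module-matching filter: instead of requiring submodules to be instances of the legacy attention classes, the hook now also accepts any module that opts in via `AttentionModuleMixin`, by unpacking `_ATTENTION_CLASSES` into a new tuple alongside the mixin. Below is a minimal, self-contained sketch of this matching behavior; the `Attention`, `AttentionModuleMixin`, and identifier definitions are toy stand-ins for the real diffusers classes, kept only to make the snippet runnable.

```python
import re

import torch


# Toy stand-ins for the real diffusers classes (see
# src/diffusers/models/attention.py and attention_processor.py).
class Attention(torch.nn.Module):
    """Legacy-style attention class matched by _ATTENTION_CLASSES."""


class AttentionModuleMixin:
    """Marker mixin; new attention modules inherit this instead of Attention."""


class NewStyleAttention(torch.nn.Module, AttentionModuleMixin):
    """An attention module that opts in via the mixin, not by subclassing Attention."""


_ATTENTION_CLASSES = (Attention,)
_TRANSFORMER_BLOCK_IDENTIFIERS = (r"^transformer_blocks",)


def matches(name: str, submodule: torch.nn.Module) -> bool:
    # Unpacking _ATTENTION_CLASSES into a new tuple lets one isinstance call
    # accept both the legacy attention classes and any mixin-based module.
    if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
        return False
    # As in apply_faster_cache, the module's name must also match one of the
    # transformer-block regex identifiers.
    return any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS)


print(matches("transformer_blocks.0.attn", Attention()))          # True (legacy path)
print(matches("transformer_blocks.0.attn", NewStyleAttention()))  # True (new mixin path)
print(matches("transformer_blocks.0.mlp", torch.nn.Linear(4, 4))) # False (not attention)
```

Using `(*_ATTENTION_CLASSES, AttentionModuleMixin)` keeps the existing behavior intact while extending coverage, so denoisers whose attention layers are built on the mixin no longer slip through the FasterCache hook unhooked.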