diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e56f53150c..09eabea516 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -355,7 +355,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             )
 
     elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
-        if not _CAN_USE_FLASH_ATTN_3:
+        if not _CAN_USE_FLASH_ATTN_3 and (flash_attn_3_func is None and flash_attn_3_varlen_func is None):
             raise RuntimeError(
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )
@@ -645,24 +645,34 @@ def _flash_attention_3(
     deterministic: bool = False,
     return_attn_probs: bool = False,
 ) -> torch.Tensor:
-    out, lse, *_ = flash_attn_3_func(
-        q=query,
-        k=key,
-        v=value,
-        softmax_scale=scale,
-        causal=is_causal,
-        qv=None,
-        q_descale=None,
-        k_descale=None,
-        v_descale=None,
-        window_size=window_size,
-        attention_chunk=0,
-        softcap=softcap,
-        num_splits=1,
-        pack_gqa=None,
-        deterministic=deterministic,
-        sm_margin=0,
-    )
+    sig = inspect.signature(flash_attn_3_func)
+    accepted = set(sig.parameters)
+    params = {
+        "q": query,
+        "k": key,
+        "v": value,
+        "softmax_scale": scale,
+        "causal": is_causal,
+        "qv": None,
+        "q_descale": None,
+        "k_descale": None,
+        "v_descale": None,
+        "window_size": window_size,
+        "attention_chunk": 0,
+        "softcap": softcap,
+        "num_splits": 1,
+        "pack_gqa": None,
+        "deterministic": deterministic,
+        "sm_margin": 0,
+    }
+    kwargs = {}
+    for name, value in params.items():
+        if name not in accepted:
+            logger.debug(f"{name} is not accepted by the `flash_attn_3_func` method, so it will be discarded.")
+        else:
+            kwargs[name] = value
+
+    out, lse, *_ = flash_attn_3_func(**kwargs)
     return (out, lse) if return_attn_probs else out
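
The new call path probes `flash_attn_3_func` with `inspect.signature` and silently drops any keyword the installed flash-attn build does not accept, so the wrapper keeps working across FA3 API revisions instead of failing with a `TypeError`. Below is a minimal standalone sketch of that pattern, assuming only the standard library; the helper name `_filter_supported_kwargs` and the toy `attn_v1` function are hypothetical illustrations, not part of diffusers or flash-attn.

# Illustrative sketch only: `_filter_supported_kwargs` and `attn_v1` are made-up names,
# not APIs from diffusers or flash-attn.
import inspect
import logging

logger = logging.getLogger(__name__)


def _filter_supported_kwargs(func, params):
    # Keep only the keyword arguments that `func` actually declares in its signature.
    accepted = set(inspect.signature(func).parameters)
    kwargs = {}
    for name, value in params.items():
        if name not in accepted:
            logger.debug("`%s` is not accepted by `%s`, so it will be discarded.", name, func.__name__)
        else:
            kwargs[name] = value
    return kwargs


# An "older" attention API that does not know about newer tuning knobs such as `sm_margin`.
def attn_v1(q, k, v, softmax_scale=None, causal=False):
    return q, k, v, softmax_scale, causal


out = attn_v1(
    **_filter_supported_kwargs(
        attn_v1,
        {"q": 1, "k": 2, "v": 3, "softmax_scale": 0.125, "causal": True, "sm_margin": 0},
    )
)
print(out)  # (1, 2, 3, 0.125, True) -- `sm_margin` was dropped instead of raising a TypeError

The trade-off of this design is that a misspelled or deprecated argument is only reported at DEBUG level rather than raising, which is why the diff keeps the stricter import-time check in `_check_attention_backend_requirements` as a first line of defense.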