mirror of https://github.com/huggingface/diffusers.git

commit a0177ebfec
parent 827fc1599a
Author: sayakpaul
Date:   2025-08-25 18:53:02 +02:00


@@ -355,7 +355,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             )
     elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
-        if not _CAN_USE_FLASH_ATTN_3:
+        if not _CAN_USE_FLASH_ATTN_3 and (flash_attn_3_func is None and flash_attn_3_varlen_func is None):
             raise RuntimeError(
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )
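The tightened guard above only raises when the capability flag is false and neither FA3 entry point was importable. Below is a minimal standalone sketch of that pattern; the `fa3_kernels` module and the `check_fa3_available` helper are hypothetical stand-ins for diffusers' lazy imports, not the library's actual API.

try:
    # Hypothetical module name; the real optional import lives elsewhere in diffusers.
    from fa3_kernels import flash_attn_3_func, flash_attn_3_varlen_func
except ImportError:
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None

# Stand-in for the module-level flag (the real code also performs version checks).
_CAN_USE_FLASH_ATTN_3 = flash_attn_3_func is not None


def check_fa3_available() -> None:
    # Mirrors the guard in the hunk above: raise only if the flag is unset
    # *and* neither function object was imported.
    if not _CAN_USE_FLASH_ATTN_3 and (flash_attn_3_func is None and flash_attn_3_varlen_func is None):
        raise RuntimeError(
            "Flash Attention 3 is not usable: missing package or version too old. "
            "Please build the FA3 beta release from source."
        )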
@@ -645,24 +645,34 @@ def _flash_attention_3(
     deterministic: bool = False,
     return_attn_probs: bool = False,
 ) -> torch.Tensor:
-    out, lse, *_ = flash_attn_3_func(
-        q=query,
-        k=key,
-        v=value,
-        softmax_scale=scale,
-        causal=is_causal,
-        qv=None,
-        q_descale=None,
-        k_descale=None,
-        v_descale=None,
-        window_size=window_size,
-        attention_chunk=0,
-        softcap=softcap,
-        num_splits=1,
-        pack_gqa=None,
-        deterministic=deterministic,
-        sm_margin=0,
-    )
+    sig = inspect.signature(flash_attn_3_func)
+    accepted = set(sig.parameters)
+    params = {
+        "q": query,
+        "k": key,
+        "v": value,
+        "softmax_scale": scale,
+        "causal": is_causal,
+        "qv": None,
+        "q_descale": None,
+        "k_descale": None,
+        "v_descale": None,
+        "window_size": window_size,
+        "attention_chunk": 0,
+        "softcap": softcap,
+        "num_splits": 1,
+        "pack_gqa": None,
+        "deterministic": deterministic,
+        "sm_margin": 0,
+    }
+    kwargs = {}
+    for name, value in params.items():
+        if name not in accepted:
+            logger.debug(f"{name} is not accepted by the `flash_attn_3_func` method, so it will be discarded.")
+        else:
+            kwargs[name] = value
+    out, lse, *_ = flash_attn_3_func(**kwargs)
     return (out, lse) if return_attn_probs else out
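For context, the new body filters keyword arguments through `inspect.signature` so that parameters missing from an older FA3 build are dropped instead of triggering a `TypeError`. Below is a self-contained sketch of that technique; `fa3_like` and `call_with_supported_kwargs` are illustrative names, not part of diffusers or flash-attn.

import inspect
import logging

logger = logging.getLogger(__name__)


def fa3_like(q, k, v, softmax_scale=None, causal=False):
    # Placeholder for an FA3-style kernel whose keyword surface changes between releases.
    return q, None  # pretend this is (out, lse)


def call_with_supported_kwargs(func, **params):
    # Keep only the keyword arguments that the callee actually declares.
    accepted = set(inspect.signature(func).parameters)
    kwargs = {}
    for name, value in params.items():
        if name not in accepted:
            logger.debug(f"{name} is not accepted by `{func.__name__}`, so it will be discarded.")
        else:
            kwargs[name] = value
    return func(**kwargs)


# `sm_margin` is silently dropped because `fa3_like` does not declare it.
out, lse = call_with_supported_kwargs(
    fa3_like, q="Q", k="K", v="V", softmax_scale=0.125, causal=False, sm_margin=0
)

One caveat of this filtering approach: if the callee accepts `**kwargs`, `inspect.signature` lists only the variadic parameter itself, so arguments the callee would in fact accept could still be discarded.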