mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
up
This commit is contained in:
@@ -355,7 +355,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
)
|
||||
|
||||
elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
|
||||
if not _CAN_USE_FLASH_ATTN_3:
|
||||
if not _CAN_USE_FLASH_ATTN_3 and (flash_attn_3_func is None and flash_attn_3_varlen_func is None):
|
||||
raise RuntimeError(
|
||||
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
|
||||
)
|
||||
@@ -645,24 +645,34 @@ def _flash_attention_3(
|
||||
deterministic: bool = False,
|
||||
return_attn_probs: bool = False,
|
||||
) -> torch.Tensor:
|
||||
out, lse, *_ = flash_attn_3_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
attention_chunk=0,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
)
|
||||
sig = inspect.signature(flash_attn_3_func)
|
||||
accepted = set(sig.parameters)
|
||||
params = {
|
||||
"q": query,
|
||||
"k": key,
|
||||
"v": value,
|
||||
"softmax_scale": scale,
|
||||
"causal": is_causal,
|
||||
"qv": None,
|
||||
"q_descale": None,
|
||||
"k_descale": None,
|
||||
"v_descale": None,
|
||||
"window_size": window_size,
|
||||
"attention_chunk": 0,
|
||||
"softcap": softcap,
|
||||
"num_splits": 1,
|
||||
"pack_gqa": None,
|
||||
"deterministic": deterministic,
|
||||
"sm_margin": 0,
|
||||
}
|
||||
kwargs = {}
|
||||
for name, value in params.items():
|
||||
if name not in accepted:
|
||||
logger.debug(f"{name} is not accepted by the `flash_attn_3_func` method, so it will be discarded.")
|
||||
else:
|
||||
kwargs[name] = value
|
||||
|
||||
out, lse, *_ = flash_attn_3_func(**kwargs)
|
||||
return (out, lse) if return_attn_probs else out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user