
support CP in native flash attention (#12829)

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: Wang, Yi
Date: 2025-12-12 15:48:39 +08:00
Committed by: sayakpaul
Parent: a7c7a270f6
Commit: 218b17040f


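For context (not part of this diff): a rough sketch of the workflow this commit enables, i.e. running the PyTorch-native flash attention backend under context parallelism. The `ContextParallelConfig`, `enable_parallelism`, and `set_attention_backend` names below are assumed from recent diffusers releases, and the `"_native_flash"` backend string is inferred from `AttentionBackendName._NATIVE_FLASH`; treat this as an illustration rather than the canonical API.

# Sketch only. Launch with: torchrun --nproc-per-node=2 cp_native_flash_sketch.py
import torch
import torch.distributed as dist
from diffusers import ContextParallelConfig, FluxPipeline  # assumed imports

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(f"cuda:{rank}")
# Route attention through the PyTorch-native flash kernel that this commit teaches to run under CP.
pipe.transformer.set_attention_backend("_native_flash")
# Shard the attention sequence across the two ranks.
pipe.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
if rank == 0:
    image.save("output.png")
dist.destroy_process_group()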
@@ -868,6 +868,97 @@ def _cudnn_attention_backward_op(
     return grad_query, grad_key, grad_value
 
 
+# https://github.com/pytorch/pytorch/blob/e33fa0ece36a93dbc8ff19b0251b8d99f8ae8668/aten/src/ATen/native/native_functions.yaml#L15135
+# forward declaration:
+# aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
+def _native_flash_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    if enable_gqa:
+        raise ValueError("`enable_gqa` is not yet supported for native flash attention.")
+
+    tensors_to_save = ()
+
+    query = query.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+    tensors_to_save += (query, key, value)
+
+    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
+        torch.ops.aten._scaled_dot_product_flash_attention(
+            query=query,
+            key=key,
+            value=value,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            return_debug_mask=False,
+            scale=scale,
+        )
+    )
+
+    tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
+    if _save_ctx:
+        ctx.save_for_backward(*tensors_to_save)
+        ctx.dropout_p = dropout_p
+        ctx.is_causal = is_causal
+        ctx.scale = scale
+        ctx.max_q = max_q
+        ctx.max_k = max_k
+
+    out = out.transpose(1, 2).contiguous()
+    if lse is not None:
+        lse = lse.transpose(1, 2).contiguous()
+    return (out, lse) if return_lse else out
+
+
+# https://github.com/pytorch/pytorch/blob/e33fa0ece36a93dbc8ff19b0251b8d99f8ae8668/aten/src/ATen/native/native_functions.yaml#L15153
+# backward declaration:
+# aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+def _native_flash_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
+
+    grad_out = grad_out.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+
+    grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward(
+        grad_out,
+        query,
+        key,
+        value,
+        out,
+        logsumexp=lse,
+        philox_seed=philox_seed,
+        philox_offset=philox_offset,
+        cum_seq_q=cum_seq_q,
+        cum_seq_k=cum_seq_k,
+        max_q=ctx.max_q,
+        max_k=ctx.max_k,
+        dropout_p=ctx.dropout_p,
+        is_causal=ctx.is_causal,
+        scale=ctx.scale,
+    )
+    grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value))
+
+    return grad_query, grad_key, grad_value
+
+
 # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
 def _flash_attention_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
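A note on the layout handling in the two new ops above (illustration only, not part of the diff): diffusers' attention dispatcher passes tensors as (batch, seq, heads, head_dim), while `torch.ops.aten._scaled_dot_product_flash_attention` expects (batch, heads, seq, head_dim) and also returns the per-head logsumexp that the context-parallel machinery needs in order to merge partial results; hence the transpose(1, 2) calls on the way in and out. A minimal standalone check, assuming a CUDA device with bf16 flash attention support:

import torch

# Toy shapes in the dispatcher's (batch, seq, heads, head_dim) layout.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# The aten flash op wants (batch, heads, seq, head_dim), hence the transposes in the ops above.
q_t, k_t, v_t = (x.transpose(1, 2).contiguous() for x in (q, k, v))

out, lse, *_rest = torch.ops.aten._scaled_dot_product_flash_attention(q_t, k_t, v_t)
print(out.shape)  # torch.Size([2, 8, 128, 64]); transposed back to (batch, seq, heads, head_dim) before returning
print(lse.shape)  # per-(batch, head, query) log-sum-exp, kept so ring-attention partials can be combined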
@@ -1931,6 +2022,7 @@ def _native_efficient_attention(
 @_AttentionBackendRegistry.register(
     AttentionBackendName._NATIVE_FLASH,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=True,
 )
 def _native_flash_attention(
     query: torch.Tensor,
@@ -1943,22 +2035,40 @@ def _native_flash_attention(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    if return_lse:
-        raise ValueError("Native flash attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
-        out = torch.nn.functional.scaled_dot_product_attention(
-            query=query,
-            key=key,
-            value=value,
-            attn_mask=None,  # not supported
-            dropout_p=dropout_p,
-            is_causal=is_causal,
-            scale=scale,
-            enable_gqa=enable_gqa,
+    lse = None
+    if _parallel_config is None and not return_lse:
+        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+            out = torch.nn.functional.scaled_dot_product_attention(
+                query=query,
+                key=key,
+                value=value,
+                attn_mask=None,  # not supported
+                dropout_p=dropout_p,
+                is_causal=is_causal,
+                scale=scale,
+                enable_gqa=enable_gqa,
+            )
+        out = out.permute(0, 2, 1, 3)
+    else:
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            None,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op=_native_flash_attention_forward_op,
+            backward_op=_native_flash_attention_backward_op,
+            _parallel_config=_parallel_config,
         )
-    out = out.permute(0, 2, 1, 3)
-    return out
+        if return_lse:
+            out, lse = out
+
+    return (out, lse) if return_lse else out
 
 
 @_AttentionBackendRegistry.register(
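Why the forward op returns the logsumexp at all: under context parallelism each rank attends over only its slice of the keys/values, and the per-query logsumexp is exactly what lets those partial outputs be merged into the full-attention result. A small self-contained check of that merge rule (plain PyTorch, single head, not from this diff):

import torch

torch.manual_seed(0)
q = torch.randn(4, 16)                       # 4 queries, head_dim 16
k, v = torch.randn(64, 16), torch.randn(64, 16)
scale = q.shape[-1] ** -0.5

def partial_attn(q, k, v):
    scores = (q @ k.T) * scale
    lse = torch.logsumexp(scores, dim=-1)    # per-query log-sum-exp over this key block
    return torch.softmax(scores, dim=-1) @ v, lse

# Two key/value blocks, as two context-parallel ranks would see them.
out1, lse1 = partial_attn(q, k[:32], v[:32])
out2, lse2 = partial_attn(q, k[32:], v[32:])

# Merge rule: rescale each partial output by exp(lse_i - lse_total) and sum.
lse = torch.logaddexp(lse1, lse2)
merged = torch.exp(lse1 - lse)[:, None] * out1 + torch.exp(lse2 - lse)[:, None] * out2

full, _ = partial_attn(q, k, v)
print(torch.allclose(merged, full, atol=1e-5))  # True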