1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

refactor; support flash attention 2 with cp

This commit is contained in:
Aryan
2025-07-16 21:31:53 +02:00
parent 79736265c5
commit f859fdf7ba

View File

@@ -571,8 +571,8 @@ class _cudnn_attention(torch.autograd.Function):
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
):
@@ -653,8 +653,8 @@ class _flash_attention_2(torch.autograd.Function):
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
):
@@ -753,8 +753,8 @@ class TemplatedRingAttention(torch.autograd.Function):
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
dropout_p: float,
scale: Optional[float],
is_causal: bool,
scale: Optional[float],
enable_gqa: bool,
return_lse: bool,
op: torch.autograd.Function,
@@ -778,7 +778,7 @@ class TemplatedRingAttention(torch.autograd.Function):
value = kv[key.numel() :].reshape_as(value)
next_rank = (next_rank + 1) % world_size
out, lse = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, True)
out, lse = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True)
if parallel_config.convert_to_fp32:
out = out.to(torch.float32)
@@ -813,8 +813,8 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
dropout_p: float,
scale: Optional[float],
is_causal: bool,
scale: Optional[float],
enable_gqa: bool,
return_lse: bool,
op: torch.autograd.Function,
@@ -833,7 +833,7 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
out = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
if return_lse:
out, lse, *_ = out
@@ -883,14 +883,14 @@ def _templated_context_parallel_attention(
# TODO: add support for unified attention with ring/ulysses degree both being > 1
if parallel_config.ring_degree > 1:
return TemplatedRingAttention.apply(
query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op
)
elif parallel_config.ulysses_degree > 1:
return TemplatedUlyssesAttention.apply(
query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op
)
else:
return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
return op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
# ===== Attention backends =====
@@ -905,20 +905,33 @@ def _flash_attention(
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
) -> torch.Tensor:
out = flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
return out
parallel_config = _AttentionBackendRegistry._parallel_config
lse = None
if parallel_config is None:
out = flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query, key, value, None, dropout_p, is_causal, scale, False, return_lse, op=_flash_attention_2
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(