From f859fdf7ba742e66c1e4f809cfebda59cd94aa64 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 16 Jul 2025 21:31:53 +0200
Subject: [PATCH] refactor; support flash attention 2 with cp

---
 src/diffusers/models/attention_dispatch.py | 53 ++++++++++++++--------
 1 file changed, 33 insertions(+), 20 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 07a9c09b5e..6357bdad0f 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -571,8 +571,8 @@ class _cudnn_attention(torch.autograd.Function):
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
         dropout_p: float = 0.0,
-        scale: Optional[float] = None,
         is_causal: bool = False,
+        scale: Optional[float] = None,
         enable_gqa: bool = False,
         return_lse: bool = False,
     ):
@@ -653,8 +653,8 @@ class _flash_attention_2(torch.autograd.Function):
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
         dropout_p: float = 0.0,
-        scale: Optional[float] = None,
         is_causal: bool = False,
+        scale: Optional[float] = None,
         enable_gqa: bool = False,
         return_lse: bool = False,
     ):
@@ -753,8 +753,8 @@ class TemplatedRingAttention(torch.autograd.Function):
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor],
         dropout_p: float,
-        scale: Optional[float],
         is_causal: bool,
+        scale: Optional[float],
         enable_gqa: bool,
         return_lse: bool,
         op: torch.autograd.Function,
@@ -778,7 +778,7 @@ class TemplatedRingAttention(torch.autograd.Function):
                 value = kv[key.numel() :].reshape_as(value)
                 next_rank = (next_rank + 1) % world_size

-            out, lse = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, True)
+            out, lse = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True)

             if parallel_config.convert_to_fp32:
                 out = out.to(torch.float32)
@@ -813,8 +813,8 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor],
         dropout_p: float,
-        scale: Optional[float],
         is_causal: bool,
+        scale: Optional[float],
         enable_gqa: bool,
         return_lse: bool,
         op: torch.autograd.Function,
@@ -833,7 +833,7 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
         query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))

-        out = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
+        out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)

         if return_lse:
             out, lse, *_ = out
@@ -883,14 +883,14 @@ def _templated_context_parallel_attention(
     # TODO: add support for unified attention with ring/ulysses degree both being > 1
     if parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
-            query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
+            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op
         )
     elif parallel_config.ulysses_degree > 1:
         return TemplatedUlyssesAttention.apply(
-            query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
+            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op
         )
     else:
-        return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
+        return op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)


 # ===== Attention backends =====
@@ -905,20 +905,33 @@ def _flash_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     dropout_p: float = 0.0,
-    scale: Optional[float] = None,
     is_causal: bool = False,
+    scale: Optional[float] = None,
     return_lse: bool = False,
 ) -> torch.Tensor:
-    out = flash_attn_func(
-        q=query,
-        k=key,
-        v=value,
-        dropout_p=dropout_p,
-        softmax_scale=scale,
-        causal=is_causal,
-        return_attn_probs=return_lse,
-    )
-    return out
+    parallel_config = _AttentionBackendRegistry._parallel_config
+
+    lse = None
+    if parallel_config is None:
+        out = flash_attn_func(
+            q=query,
+            k=key,
+            v=value,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            return_attn_probs=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+    else:
+        out = _templated_context_parallel_attention(
+            query, key, value, None, dropout_p, is_causal, scale, False, return_lse, op=_flash_attention_2
+        )
+        if return_lse:
+            out, lse = out
+
+    return (out, lse) if return_lse else out


 @_AttentionBackendRegistry.register(
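
Reviewer note on the new `return_lse` handling in `_flash_attention`: with `return_attn_probs=True`, flash-attn 2.x returns a tuple whose first two elements are the attention output and the per-query log-sum-exp (LSE), which is why the single-GPU branch unpacks `out, lse, *_` and the function now normalizes both branches to `(out, lse)`. Below is a minimal standalone sketch of the flash-attn contract the non-parallel branch relies on; it is not part of the patch, and the shapes, dtype, and toy sizes are illustrative assumptions (flash-attn 2.x and a CUDA device are required).

```python
# Standalone sketch, not part of the patch: shows the (out, lse, ...) tuple that
# flash_attn_func returns when return_attn_probs=True, which the patched
# _flash_attention unpacks as `out, lse, *_`.
# Assumptions: flash-attn 2.x installed, CUDA device available; the tensor
# shapes and bfloat16 dtype below are illustrative, not taken from the PR.
import torch
from flash_attn import flash_attn_func

batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out, lse, *_ = flash_attn_func(
    q=q,
    k=k,
    v=v,
    dropout_p=0.0,
    softmax_scale=None,  # None -> flash-attn defaults to 1/sqrt(head_dim)
    causal=False,
    return_attn_probs=True,
)
print(out.shape)  # (batch, seq_len, num_heads, head_dim), same as q
print(lse.shape)  # per-head, per-query log-sum-exp, typically (batch, num_heads, seq_len)
```

In the context-parallel branch, this same LSE is what `TemplatedRingAttention` uses to merge the partial outputs from each rank (optionally in fp32 when `parallel_config.convert_to_fp32` is set), which is why the ring path always requests it via the trailing `True` passed to `op.apply`, and why `_flash_attention` now forwards `op=_flash_attention_2` to `_templated_context_parallel_attention` instead of calling `flash_attn_func` directly.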