Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
refactor; support flash attention 2 with cp
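This change reorders the is_causal and scale parameters across the attention ops and the templated ring/Ulysses wrappers so that every op.apply(...) call uses a single positional argument order. A minimal reference sketch, assuming the intent is to line up with the keyword order of torch.nn.functional.scaled_dot_product_attention; this is illustrative, not code from this commit:

# Reference point only: PyTorch's SDPA keyword order, which the reordered
# positional arguments in the hunks below follow (is_causal before scale).
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 64)   # (batch, heads, seq, head_dim)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,  # is_causal comes before scale ...
    scale=None,       # ... which is the order the refactor standardizes on
)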
@@ -571,8 +571,8 @@ class _cudnn_attention(torch.autograd.Function):
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
         dropout_p: float = 0.0,
-        scale: Optional[float] = None,
         is_causal: bool = False,
+        scale: Optional[float] = None,
         enable_gqa: bool = False,
         return_lse: bool = False,
     ):
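For context, _cudnn_attention appears to route through PyTorch's cuDNN scaled-dot-product-attention backend. A minimal sketch of selecting that backend via the public torch.nn.attention API, assuming PyTorch 2.5+ (which exposes SDPBackend.CUDNN_ATTENTION) and a CUDA device; illustrative only, not the class above:

# Illustrative only (assumes PyTorch >= 2.5 with SDPBackend.CUDNN_ATTENTION and a CUDA GPU).
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

if torch.cuda.is_available():
    q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))
    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False, scale=None)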
@@ -653,8 +653,8 @@ class _flash_attention_2(torch.autograd.Function):
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
         dropout_p: float = 0.0,
-        scale: Optional[float] = None,
         is_causal: bool = False,
+        scale: Optional[float] = None,
         enable_gqa: bool = False,
         return_lse: bool = False,
     ):
@@ -753,8 +753,8 @@ class TemplatedRingAttention(torch.autograd.Function):
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor],
         dropout_p: float,
-        scale: Optional[float],
         is_causal: bool,
+        scale: Optional[float],
         enable_gqa: bool,
         return_lse: bool,
         op: torch.autograd.Function,
@@ -778,7 +778,7 @@ class TemplatedRingAttention(torch.autograd.Function):
             value = kv[key.numel() :].reshape_as(value)
             next_rank = (next_rank + 1) % world_size
 
-            out, lse = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, True)
+            out, lse = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True)
 
             if parallel_config.convert_to_fp32:
                 out = out.to(torch.float32)
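In the ring loop above, each step attends against one rank's key/value shard, and the per-step results are combined using their log-sum-exp (LSE) values, which is why the inner call always passes return_lse=True. A numerically stable merge of two partial (out, lse) pairs, as a sketch of the general technique rather than the exact code in TemplatedRingAttention; the (B, H, S, D) layout is an assumption:

import torch

def merge_partial_attention(out_a, lse_a, out_b, lse_b):
    # Combine attention outputs computed over two disjoint KV shards.
    # out_*: (B, H, S, D); lse_*: (B, H, S), the log-sum-exp of each shard's logits.
    lse = torch.logaddexp(lse_a, lse_b)          # combined normalizer, in log space
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # relative weight of shard A
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)   # relative weight of shard B
    return w_a * out_a + w_b * out_b, lse

B, H, S, D = 2, 8, 16, 64
out_a, out_b = torch.randn(B, H, S, D), torch.randn(B, H, S, D)
lse_a, lse_b = torch.randn(B, H, S), torch.randn(B, H, S)
merged, merged_lse = merge_partial_attention(out_a, lse_a, out_b, lse_b)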
@@ -813,8 +813,8 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor],
         dropout_p: float,
-        scale: Optional[float],
         is_causal: bool,
+        scale: Optional[float],
         enable_gqa: bool,
         return_lse: bool,
         op: torch.autograd.Function,
@@ -833,7 +833,7 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
         query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
 
-        out = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
+        out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
         if return_lse:
             out, lse, *_ = out
 
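Ulysses-style attention trades a sequence shard for a head shard through all-to-all: before the exchange a rank holds every head for its slice of the sequence, afterwards it holds every sequence position for its slice of the heads, so an ordinary single-device attention op can run over the full sequence. A single-process sketch of the shape bookkeeping; the (B, S, H, D) layout and values are illustrative assumptions, not the exact tensors in TemplatedUlyssesAttention:

import torch

world_size, rank = 4, 1                 # illustrative values
B, S, H, D = 2, 32, 8, 64
S_local, H_local = S // world_size, H // world_size

x_global = torch.randn(B, S, H, D)      # the logical, unsharded hidden states

# Before the all-to-all: this rank holds a sequence shard with every head.
x_seq_shard = x_global[:, rank * S_local:(rank + 1) * S_local]      # (B, S_local, H, D)

# After the all-to-all: this rank holds the full sequence but only its heads.
x_head_shard = x_global[:, :, rank * H_local:(rank + 1) * H_local]  # (B, S, H_local, D)

assert x_seq_shard.shape == (B, S_local, H, D)
assert x_head_shard.shape == (B, S, H_local, D)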
@@ -883,14 +883,14 @@ def _templated_context_parallel_attention(
     # TODO: add support for unified attention with ring/ulysses degree both being > 1
     if parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
-            query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
+            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op
         )
     elif parallel_config.ulysses_degree > 1:
         return TemplatedUlyssesAttention.apply(
-            query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
+            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op
         )
     else:
-        return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
+        return op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
 
 
 # ===== Attention backends =====
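The enable_gqa flag threaded through these calls corresponds to grouped-query attention, where several query heads share one key/value head. A minimal sketch of the usual KV-head expansion, independent of this commit (recent PyTorch can also do this internally via enable_gqa=True in scaled_dot_product_attention):

import torch
import torch.nn.functional as F

B, S, D = 2, 16, 64
q_heads, kv_heads = 8, 2                # 4 query heads share each KV head

q = torch.randn(B, q_heads, S, D)
k = torch.randn(B, kv_heads, S, D)
v = torch.randn(B, kv_heads, S, D)

# Grouped-query attention: repeat each KV head so it lines up with its query group.
k_exp = k.repeat_interleave(q_heads // kv_heads, dim=1)
v_exp = v.repeat_interleave(q_heads // kv_heads, dim=1)
out = F.scaled_dot_product_attention(q, k_exp, v_exp, dropout_p=0.0, is_causal=False)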
@@ -905,20 +905,33 @@ def _flash_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     dropout_p: float = 0.0,
-    scale: Optional[float] = None,
     is_causal: bool = False,
+    scale: Optional[float] = None,
     return_lse: bool = False,
 ) -> torch.Tensor:
-    out = flash_attn_func(
-        q=query,
-        k=key,
-        v=value,
-        dropout_p=dropout_p,
-        softmax_scale=scale,
-        causal=is_causal,
-        return_attn_probs=return_lse,
-    )
-    return out
+    parallel_config = _AttentionBackendRegistry._parallel_config
+
+    lse = None
+    if parallel_config is None:
+        out = flash_attn_func(
+            q=query,
+            k=key,
+            v=value,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            return_attn_probs=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+    else:
+        out = _templated_context_parallel_attention(
+            query, key, value, None, dropout_p, is_causal, scale, False, return_lse, op=_flash_attention_2
+        )
+        if return_lse:
+            out, lse = out
+
+    return (out, lse) if return_lse else out
 
 
 @_AttentionBackendRegistry.register(
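The rewritten _flash_attention either calls flash_attn_func directly when no parallel config is set, or routes through _templated_context_parallel_attention with _flash_attention_2 as the op. A minimal sketch of the direct branch, assuming the flash-attn 2 package is installed and a supported CUDA GPU is available; flash-attn expects (batch, seqlen, heads, head_dim) inputs:

# Illustrative usage of the flash-attn package (assumption: flash-attn 2 + CUDA GPU);
# mirrors the "parallel_config is None" branch above.
import torch
from flash_attn import flash_attn_func

if torch.cuda.is_available():
    q, k, v = (torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))
    ret = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                          return_attn_probs=True)
    out, lse, *_ = ret   # extra return values (including the LSE) come back with return_attn_probs=True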