mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
up
This commit is contained in:
@@ -994,21 +994,6 @@ def _flash_attention_backward_op(
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _maybe_format_lse_for_context_parallel(
|
||||
lse: Optional[torch.Tensor],
|
||||
*,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if lse is None or lse.ndim != 3:
|
||||
return lse
|
||||
|
||||
if lse.shape[1] == num_heads and lse.shape[2] == seq_len:
|
||||
lse = lse.permute(0, 2, 1)
|
||||
|
||||
return lse.contiguous()
|
||||
|
||||
|
||||
def _flash_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
@@ -1063,7 +1048,7 @@ def _flash_attention_hub_forward_op(
|
||||
alibi_slopes,
|
||||
return_lse,
|
||||
)
|
||||
lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2])
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value, out, lse, rng_state)
|
||||
@@ -1173,7 +1158,7 @@ def _flash_attention_3_hub_forward_op(
|
||||
lse = None
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2])
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value)
|
||||
@@ -1315,7 +1300,7 @@ def _sage_attention_hub_forward_op(
|
||||
lse = None
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2])
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
Reference in New Issue
Block a user