commit f732ff1144
parent 7a8f85b047
Author: sayakpaul
Date:   2025-12-09 15:30:33 +05:30


@@ -994,21 +994,6 @@ def _flash_attention_backward_op(
     return grad_query, grad_key, grad_value
-
-
-def _maybe_format_lse_for_context_parallel(
-    lse: Optional[torch.Tensor],
-    *,
-    seq_len: int,
-    num_heads: int,
-) -> Optional[torch.Tensor]:
-    if lse is None or lse.ndim != 3:
-        return lse
-    if lse.shape[1] == num_heads and lse.shape[2] == seq_len:
-        lse = lse.permute(0, 2, 1)
-    return lse.contiguous()
-
-
 def _flash_attention_hub_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
     query: torch.Tensor,
@@ -1063,7 +1048,7 @@ def _flash_attention_hub_forward_op(
         alibi_slopes,
         return_lse,
     )
-    lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2])
+    lse = lse.permute(0, 2, 1).contiguous()
 
     if _save_ctx:
         ctx.save_for_backward(query, key, value, out, lse, rng_state)
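For reference, not part of the commit: a minimal sketch of what the restored line does. flash-attn-style kernels return the log-sum-exp (LSE) as (batch, num_heads, seq_len) when an LSE is requested, and the permute transposes it into the (batch, seq_len, num_heads) layout the surrounding code saves for backward. Shapes below are assumed for illustration.

    import torch

    lse = torch.randn(2, 8, 128)             # (batch, num_heads, seq_len), as produced by the kernel
    lse = lse.permute(0, 2, 1).contiguous()  # -> (batch, seq_len, num_heads)
    print(lse.shape)                         # torch.Size([2, 128, 8])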
@@ -1173,7 +1158,7 @@ def _flash_attention_3_hub_forward_op(
     lse = None
     if return_lse:
         out, lse = out
-        lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2])
+        lse = lse.permute(0, 2, 1).contiguous()
 
     if _save_ctx:
         ctx.save_for_backward(query, key, value)
@@ -1315,7 +1300,7 @@ def _sage_attention_hub_forward_op(
     lse = None
     if return_lse:
         out, lse, *_ = out
-        lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2])
+        lse = lse.permute(0, 2, 1).contiguous()
 
     return (out, lse) if return_lse else out
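Taken together, the revert drops the conditional helper and restores an unconditional transpose at each call site. A minimal standalone sketch, with assumed shapes and not part of the commit, showing that the two paths agree whenever the LSE arrives in the kernel's layout:

    from typing import Optional

    import torch

    def _maybe_format_lse_for_context_parallel(
        lse: Optional[torch.Tensor], *, seq_len: int, num_heads: int
    ) -> Optional[torch.Tensor]:
        # Behavior of the removed helper: permute only when the layout
        # matches (batch, num_heads, seq_len); pass None through untouched.
        if lse is None or lse.ndim != 3:
            return lse
        if lse.shape[1] == num_heads and lse.shape[2] == seq_len:
            lse = lse.permute(0, 2, 1)
        return lse.contiguous()

    batch, num_heads, seq_len = 2, 8, 128
    lse = torch.randn(batch, num_heads, seq_len)  # kernel layout

    direct = lse.permute(0, 2, 1).contiguous()    # restored code path
    helper = _maybe_format_lse_for_context_parallel(lse, seq_len=seq_len, num_heads=num_heads)
    assert torch.equal(direct, helper)

Unlike the helper, the direct permute assumes the LSE is present and three-dimensional; at these call sites the tensor has just been produced by the kernel, so the extra guard is unnecessary.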