From f732ff114403de85658d0f699d9d4b3bd0a32510 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 9 Dec 2025 15:30:33 +0530
Subject: [PATCH] up

---
 src/diffusers/models/attention_dispatch.py | 21 +++------------------
 1 file changed, 3 insertions(+), 18 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 433815d7ed..96920c8631 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -994,21 +994,6 @@ def _flash_attention_backward_op(
     return grad_query, grad_key, grad_value
 
 
-def _maybe_format_lse_for_context_parallel(
-    lse: Optional[torch.Tensor],
-    *,
-    seq_len: int,
-    num_heads: int,
-) -> Optional[torch.Tensor]:
-    if lse is None or lse.ndim != 3:
-        return lse
-
-    if lse.shape[1] == num_heads and lse.shape[2] == seq_len:
-        lse = lse.permute(0, 2, 1)
-
-    return lse.contiguous()
-
-
 def _flash_attention_hub_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
     query: torch.Tensor,
@@ -1063,7 +1048,7 @@ def _flash_attention_hub_forward_op(
             alibi_slopes,
             return_lse,
         )
-        lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2])
+        lse = lse.permute(0, 2, 1).contiguous()
 
     if _save_ctx:
         ctx.save_for_backward(query, key, value, out, lse, rng_state)
@@ -1173,7 +1158,7 @@ def _flash_attention_3_hub_forward_op(
     lse = None
     if return_lse:
         out, lse = out
-        lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2])
+        lse = lse.permute(0, 2, 1).contiguous()
 
     if _save_ctx:
         ctx.save_for_backward(query, key, value)
@@ -1315,7 +1300,7 @@ def _sage_attention_hub_forward_op(
     lse = None
     if return_lse:
         out, lse, *_ = out
-        lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2])
+        lse = lse.permute(0, 2, 1).contiguous()
 
     return (out, lse) if return_lse else out
 
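
For reference, a minimal sketch of the LSE layout handling this patch inlines.
The removed helper permuted the log-sum-exp tensor only when it matched a
(batch, num_heads, seq_len) layout, as its shape checks show; the simplified
call sites permute unconditionally, and at least the two visible inside
`if return_lse:` blocks are only reached when lse is a real tensor. The sizes
below are illustrative assumptions, not values taken from the library.

    import torch

    # Illustrative sizes (assumed, not from the patch).
    batch, seq_len, num_heads = 2, 128, 8

    # Layout the removed helper's shape check tested for:
    # (batch, num_heads, seq_len), as returned by flash-attention-style kernels.
    lse = torch.randn(batch, num_heads, seq_len)

    # Old path: permute only when the layout matched, then force contiguity.
    old_style = lse
    if old_style.ndim == 3 and old_style.shape[1] == num_heads and old_style.shape[2] == seq_len:
        old_style = old_style.permute(0, 2, 1)
    old_style = old_style.contiguous()

    # New path (this patch): unconditional permute + contiguous at each call site.
    new_style = lse.permute(0, 2, 1).contiguous()

    # For this layout the two paths agree.
    assert old_style.shape == new_style.shape == (batch, seq_len, num_heads)
    assert torch.equal(old_style, new_style)

Inlining the transform drops a per-call conditional and makes the expected
(batch, seq_len, num_heads) layout explicit at each backend's call site,
rather than hiding it behind a shape-sniffing helper.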