diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 433815d7ed..96920c8631 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -994,21 +994,6 @@ def _flash_attention_backward_op(
     return grad_query, grad_key, grad_value
 
 
-def _maybe_format_lse_for_context_parallel(
-    lse: Optional[torch.Tensor],
-    *,
-    seq_len: int,
-    num_heads: int,
-) -> Optional[torch.Tensor]:
-    if lse is None or lse.ndim != 3:
-        return lse
-
-    if lse.shape[1] == num_heads and lse.shape[2] == seq_len:
-        lse = lse.permute(0, 2, 1)
-
-    return lse.contiguous()
-
-
 def _flash_attention_hub_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
     query: torch.Tensor,
@@ -1063,7 +1048,7 @@ def _flash_attention_hub_forward_op(
         alibi_slopes,
         return_lse,
     )
-    lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2])
+    lse = lse.permute(0, 2, 1).contiguous()
 
     if _save_ctx:
         ctx.save_for_backward(query, key, value, out, lse, rng_state)
@@ -1173,7 +1158,7 @@ def _flash_attention_3_hub_forward_op(
     lse = None
     if return_lse:
         out, lse = out
-        lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2])
+        lse = lse.permute(0, 2, 1).contiguous()
 
     if _save_ctx:
         ctx.save_for_backward(query, key, value)
@@ -1315,7 +1300,7 @@ def _sage_attention_hub_forward_op(
     lse = None
     if return_lse:
         out, lse, *_ = out
-        lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2])
+        lse = lse.permute(0, 2, 1).contiguous()
 
     return (out, lse) if return_lse else out
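As a minimal sketch of what these hunks change (not part of the patch; shapes and variable names here are illustrative), assuming the kernels return LSE as `(batch, num_heads, seq_len)` — which is what the removed helper's shape check implies — and that downstream code expects `(batch, seq_len, num_heads)`:

```python
import torch

# Hypothetical example of the LSE layout handling; shapes are illustrative only.
batch, num_heads, seq_len = 2, 8, 128

# Assumption: the attention kernel returns LSE as (batch, num_heads, seq_len).
lse_from_kernel = torch.randn(batch, num_heads, seq_len)

# The removed _maybe_format_lse_for_context_parallel permuted only when the
# shape matched (num_heads, seq_len); the patched forward ops now permute
# unconditionally and make the result contiguous.
lse = lse_from_kernel.permute(0, 2, 1).contiguous()
assert lse.shape == (batch, seq_len, num_heads)
```

The unconditional permute relies on the kernels consistently returning the `(batch, num_heads, seq_len)` layout; the old helper skipped the permute when the tensor was not 3-D or already in the target layout.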