diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index fd761c2bbc..1e540d6202 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -314,10 +314,8 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.cumsum(seqlens_q, dim=0, dtype=torch.int32)
-    cu_seqlens_q = torch.cumsum(seqlens_k, dim=0, dtype=torch.int32)
-    cu_seqlens_q = torch.nn.functional.pad(cu_seqlens_q, (1, 0))
-    cu_seqlens_k = torch.nn.functional.pad(cu_seqlens_k, (1, 0))
+    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
     return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
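
For illustration, here is a minimal standalone sketch of the fixed computation (not part of the patch; the values batch_size=2, seq_len_q=3, seq_len_kv=5 are made up). It shows why deriving cu_seqlens_k from seqlens_q and vice versa, as the old code did, is wrong whenever seq_len_q differs from seq_len_kv, e.g. in cross-attention:

    # Illustrative sketch only, with assumed shapes; mirrors the "+" lines above.
    import torch

    batch_size, seq_len_q, seq_len_kv = 2, 3, 5
    device = "cpu"

    # Every sample in the batch has the same (unmasked) length.
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)

    # Fixed: each cumulative-length tensor is built from its own per-sample lengths,
    # with a leading 0 prepended via pad, as varlen flash-attn/Sage kernels expect.
    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))

    print(cu_seqlens_q.tolist())  # [0, 3, 6]
    print(cu_seqlens_k.tolist())  # [0, 5, 10]

With the old, swapped assignment the two tensors would come out exchanged ([0, 5, 10] for queries and [0, 3, 6] for keys), so the kernels would receive key/value offsets built from query lengths. Folding pad and cumsum into one expression per tensor also removes the chance of reintroducing the swap.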