From 66ce9ccb03b930b5922f548d6a0b919e2dd7aa52 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 22 Jul 2025 01:12:07 +0200
Subject: [PATCH] refactor

---
 src/diffusers/models/attention_dispatch.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index fd761c2bbc..1e540d6202 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -314,10 +314,8 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.cumsum(seqlens_q, dim=0, dtype=torch.int32)
-    cu_seqlens_q = torch.cumsum(seqlens_k, dim=0, dtype=torch.int32)
-    cu_seqlens_q = torch.nn.functional.pad(cu_seqlens_q, (1, 0))
-    cu_seqlens_k = torch.nn.functional.pad(cu_seqlens_k, (1, 0))
+    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
     return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
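
For context, here is a minimal standalone sketch of the pattern the two new lines implement: varlen flash-attn / sage kernels take cumulative sequence lengths, i.e. an exclusive prefix sum of the per-sample lengths with a leading zero, built here with torch.cumsum followed by a left pad. The names and values below are illustrative only and are not part of the patch or the diffusers API.

    # Minimal sketch, assuming equal-length samples per batch as in the patched helper.
    import torch

    batch_size, seq_len_q, seq_len_kv = 3, 4, 6
    device = "cpu"

    # One entry per batch element, all samples share the same length here.
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)

    # cumsum gives each sample's end offset; padding a leading 0 turns it into
    # the [0, len_0, len_0 + len_1, ...] boundaries the varlen kernels expect.
    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))

    print(cu_seqlens_q.tolist())  # [0, 4, 8, 12]
    print(cu_seqlens_k.tolist())  # [0, 6, 12, 18]

Besides folding the cumsum and pad into one expression per tensor, note that the removed lines computed cu_seqlens_k from seqlens_q and cu_seqlens_q from seqlens_k; the replacement derives each cumulative tensor from its matching lengths tensor.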