commit 66ce9ccb03
parent bb443f99dc
Author: Aryan
Date: 2025-07-22 01:12:07 +02:00

@@ -314,10 +314,8 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.cumsum(seqlens_q, dim=0, dtype=torch.int32)
-    cu_seqlens_q = torch.cumsum(seqlens_k, dim=0, dtype=torch.int32)
-    cu_seqlens_q = torch.nn.functional.pad(cu_seqlens_q, (1, 0))
-    cu_seqlens_k = torch.nn.functional.pad(cu_seqlens_k, (1, 0))
+    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
     return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
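
For reference, here is a minimal standalone sketch (not part of the commit) of what the corrected lines compute: each per-sequence length tensor is prefix-summed from its own lengths (the old code swapped q and k) and left-padded with a zero, so entry i gives the packed start offset of sequence i, as flash-attn / sage varlen kernels expect. The batch size and sequence lengths below are illustrative values, not taken from the source.

import torch

batch_size, seq_len_q, seq_len_kv = 3, 4, 6
device = "cpu"

# Uniform lengths per batch element, as in the masked-free varlen path.
seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)

# Prefix-sum each tensor, then left-pad with a zero so cu_seqlens[i] is the
# start offset of sequence i and cu_seqlens[-1] is the total token count.
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))

print(cu_seqlens_q)  # tensor([0, 4, 8, 12], dtype=torch.int32)
print(cu_seqlens_k)  # tensor([0, 6, 12, 18], dtype=torch.int32)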