mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
refacotr
This commit is contained in:
@@ -314,10 +314,8 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
|
||||
):
|
||||
seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
|
||||
seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
|
||||
cu_seqlens_k = torch.cumsum(seqlens_q, dim=0, dtype=torch.int32)
|
||||
cu_seqlens_q = torch.cumsum(seqlens_k, dim=0, dtype=torch.int32)
|
||||
cu_seqlens_q = torch.nn.functional.pad(cu_seqlens_q, (1, 0))
|
||||
cu_seqlens_k = torch.nn.functional.pad(cu_seqlens_k, (1, 0))
|
||||
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
|
||||
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
|
||||
max_seqlen_q = seqlens_q.max().item()
|
||||
max_seqlen_k = seqlens_k.max().item()
|
||||
return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
|
||||
|
||||
Reference in New Issue
Block a user