Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[NPU] npu attention enable ulysses (#12919)
* npu attention enable ulysses

* clean the format

* register _native_npu_attention to _supports_context_parallel

* change npu_fusion_attention's input_layout to BSND to eliminate redundant transpose

* Update format

---------

Signed-off-by: yyt <yangyit139@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
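The layout change is the key micro-optimization in this PR: the dispatch layer already holds query/key/value as (batch, seq, heads, head_dim), so feeding npu_fusion_attention with input_layout="BSND" removes the transpose/contiguous round-trip the old BNSD path paid for. A minimal shape sketch (plain PyTorch, no NPU required; sizes are illustrative):

import torch

B, S, N, D = 2, 128, 8, 64
q = torch.randn(B, S, N, D)  # layout produced upstream: batch, seq, heads, dim

# Old path (BNSD): one extra copy going in, another coming back out.
q_bnsd = q.transpose(1, 2).contiguous()
assert q_bnsd.shape == (B, N, S, D)

# New path (BSND): the tensor is passed through unchanged;
# num_heads is read as query.size(2), matching the diff below.
assert q.size(2) == N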
@@ -1106,6 +1106,51 @@ def _sage_attention_backward_op(
     raise NotImplementedError("Backward pass is not implemented for Sage attention.")
 
 
+def _npu_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    if return_lse:
+        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+
+    out = npu_fusion_attention(
+        query,
+        key,
+        value,
+        query.size(2),  # num_heads
+        input_layout="BSND",
+        pse=None,
+        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+        pre_tockens=65536,
+        next_tockens=65536,
+        keep_prob=1.0 - dropout_p,
+        sync=False,
+        inner_precise=0,
+    )[0]
+
+    return out
+
+
+# Not implemented yet.
+def _npu_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.")
+
+
 # ===== Context parallel =====
 
 
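For context: Ulysses-style context parallelism shards the sequence across ranks and uses all-to-all collectives so that each rank ends up holding the full sequence for a subset of heads, runs ordinary attention locally (here, the fused NPU kernel via _npu_attention_forward_op), then swaps back. Since _npu_attention_backward_op only raises, this path is inference-only for now. A single-process sketch of the head/sequence exchange (illustrative shapes only, no distributed setup):

import torch

world_size, B, S, N, D = 2, 1, 8, 4, 16
# Each "rank" starts with all N heads but only S // world_size tokens.
shards = [torch.randn(B, S // world_size, N, D) for _ in range(world_size)]

# The all-to-all regroups the data: each rank keeps N // world_size heads
# but gathers the full sequence, so local attention sees every token.
h = N // world_size
gathered = [
    torch.cat([s[:, :, r * h : (r + 1) * h] for s in shards], dim=1)
    for r in range(world_size)
]
assert gathered[0].shape == (B, S, h, D)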
@@ -2131,6 +2176,7 @@ def _native_math_attention(
 @_AttentionBackendRegistry.register(
     AttentionBackendName._NATIVE_NPU,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=True,
 )
 def _native_npu_attention(
     query: torch.Tensor,
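The supports_context_parallel=True flag is what the commit message calls registering the backend to _supports_context_parallel: it lets the dispatcher refuse context-parallel configs on backends that cannot shard. A toy sketch of the decorator-registration pattern, for illustration only (this is not diffusers' actual _AttentionBackendRegistry code):

_BACKENDS: dict[str, dict] = {}

def register(name: str, supports_context_parallel: bool = False):
    # Record the backend function together with its capabilities.
    def wrap(fn):
        _BACKENDS[name] = {"fn": fn, "supports_context_parallel": supports_context_parallel}
        return fn
    return wrap

@register("_native_npu", supports_context_parallel=True)
def npu_attention(*args, **kwargs):
    ...

assert _BACKENDS["_native_npu"]["supports_context_parallel"]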
@@ -2146,22 +2192,36 @@ def _native_npu_attention(
         raise ValueError("`attn_mask` is not supported for NPU attention")
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
-    out = npu_fusion_attention(
-        query,
-        key,
-        value,
-        query.size(1),  # num_heads
-        input_layout="BNSD",
-        pse=None,
-        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
-        pre_tockens=65536,
-        next_tockens=65536,
-        keep_prob=1.0 - dropout_p,
-        sync=False,
-        inner_precise=0,
-    )[0]
-    out = out.transpose(1, 2).contiguous()
+    if _parallel_config is None:
+        out = npu_fusion_attention(
+            query,
+            key,
+            value,
+            query.size(2),  # num_heads
+            input_layout="BSND",
+            pse=None,
+            scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+            pre_tockens=65536,
+            next_tockens=65536,
+            keep_prob=1.0 - dropout_p,
+            sync=False,
+            inner_precise=0,
+        )[0]
+    else:
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            None,
+            dropout_p,
+            None,
+            scale,
+            None,
+            return_lse,
+            forward_op=_npu_attention_forward_op,
+            backward_op=_npu_attention_backward_op,
+            _parallel_config=_parallel_config,
+        )
     return out
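End to end, a user would pick the NPU backend and then turn on context parallelism. A hedged usage sketch, assuming a torchrun launch with torch_npu installed and diffusers' ContextParallelConfig / enable_parallelism API; the model id and the ulysses_degree value are placeholders, so verify the names against your diffusers version:

import torch
import torch_npu  # noqa: F401  -- registers the "npu" device and the HCCL backend
from diffusers import AutoModel, ContextParallelConfig

torch.distributed.init_process_group(backend="hccl")
rank = torch.distributed.get_rank()
torch.npu.set_device(rank)

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("npu")
transformer.set_attention_backend("_native_npu")  # AttentionBackendName._NATIVE_NPU
transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))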