1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
Files
diffusers/src/diffusers/models/attention_dispatch.py
2025-07-21 23:56:25 +02:00

1118 lines
37 KiB
Python

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import inspect
import math
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch
from ..utils import (
get_logger,
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
is_torch_version,
is_torch_xla_available,
is_torch_xla_version,
is_xformers_available,
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
# Module-level logger obtained through the project's `get_logger` helper.
logger = get_logger(__name__)  # pylint: disable=invalid-name

# Optional third-party kernels. Each section imports the real symbols when the availability/version probe
# succeeds and otherwise binds the same names to `None`, so the names always exist at module level.

# `flash-attn` (FlashAttention 2); a warning is logged when it is missing or older than 2.6.3.
if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
    from flash_attn import flash_attn_func, flash_attn_varlen_func
else:
    logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
    flash_attn_func = None
    flash_attn_varlen_func = None

# FlashAttention 3 (`flash_attn_interface`); no warning is emitted when it is missing.
if is_flash_attn_3_available():
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
else:
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None

# `sageattention`; a warning is logged when it is missing or older than 2.1.1.
if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
    from sageattention import (
        sageattn,
        sageattn_qk_int8_pv_fp8_cuda,
        sageattn_qk_int8_pv_fp8_cuda_sm90,
        sageattn_qk_int8_pv_fp16_cuda,
        sageattn_qk_int8_pv_fp16_triton,
        sageattn_varlen,
    )
else:
    logger.warning(
        "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
    )
    sageattn = None
    sageattn_qk_int8_pv_fp16_cuda = None
    sageattn_qk_int8_pv_fp16_triton = None
    sageattn_qk_int8_pv_fp8_cuda = None
    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
    sageattn_varlen = None

# NOTE: unlike the other optional dependencies there is no `else` fallback here, so the name `flex_attention`
# is left undefined when torch is older than 2.5.0.
if is_torch_version(">=", "2.5.0"):
    # We cannot import the flex_attention function from the package directly because it is expected (from the
    # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
    # compiled function.
    import torch.nn.attention.flex_attention as flex_attention

# Ascend NPU fused attention kernel.
if is_torch_npu_available():
    from torch_npu import npu_fusion_attention
else:
    npu_fusion_attention = None

# torch_xla flash attention custom kernel (only for torch_xla newer than 2.2).
if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
    from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
else:
    xla_flash_attention = None

# `xformers` ops; a warning is logged when it is missing or older than 0.0.29.
if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
    import xformers.ops as xops
else:
    logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
    xops = None
# TODO(aryan): Add support for the following:
# - Sage Attention++
# - block sparse, radial and other attention methods
# - CP with sage attention, flex, xformers, other missing backends
# - Add support for normal and CP training with backends that don't support it yet
# Literal type aliases used to annotate the string-valued keyword arguments (`pv_accum_dtype`, `qk_quant_gran`,
# `quantization_backend`) of the `_sage_*` attention backends defined in this module.
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
class AttentionBackendName(str, Enum):
    """
    Identifiers of the attention backends that can be registered with `_AttentionBackendRegistry`.

    Members are `str` subclasses, so each one compares equal to its string value. Names with a leading underscore
    are, by the usual Python convention, non-public variants.
    """

    # EAGER = "eager"

    # `flash-attn`
    FLASH = "flash"
    FLASH_VARLEN = "flash_varlen"
    _FLASH_3 = "_flash_3"
    _FLASH_VARLEN_3 = "_flash_varlen_3"

    # PyTorch native
    FLEX = "flex"
    NATIVE = "native"
    _NATIVE_CUDNN = "_native_cudnn"
    _NATIVE_EFFICIENT = "_native_efficient"
    _NATIVE_FLASH = "_native_flash"
    _NATIVE_MATH = "_native_math"
    _NATIVE_NPU = "_native_npu"
    _NATIVE_XLA = "_native_xla"

    # `sageattention`
    SAGE = "sage"
    SAGE_VARLEN = "sage_varlen"
    _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
    _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
    _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
    _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
    # TODO: let's not add support for Sparge Attention now because it requires tuning per model
    # We can look into supporting something "autotune"-ing in the future
    # SPARGE = "sparge"

    # `xformers`
    XFORMERS = "xformers"
class _AttentionBackendRegistry:
    """
    Class-level registry mapping each `AttentionBackendName` to its implementation.

    All state lives on the class itself; the class is never instantiated in this module.
    """

    # backend name -> callable implementing the backend
    _backends = {}
    # backend name -> list of check callables run before dispatch when checks are enabled
    _constraints = {}
    # backend name -> set of parameter names accepted by the backend callable
    _supported_arg_names = {}
    # Backend used when the caller does not pick one explicitly; initialised from DIFFUSERS_ATTN_BACKEND.
    _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
    # Whether `dispatch_attention_fn` runs the registered constraints; initialised from DIFFUSERS_ATTN_CHECKS.
    _checks_enabled = DIFFUSERS_ATTN_CHECKS

    @classmethod
    def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
        """
        Return a decorator that records the decorated function as the implementation of `backend`.

        The decorator stores the function, its constraints (an empty list when none are given) and the names of
        its parameters, then returns the function unchanged.
        """
        logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")

        def decorator(func):
            cls._backends[backend] = func
            cls._constraints[backend] = constraints if constraints else []
            # Iterating the parameters mapping yields the parameter names.
            cls._supported_arg_names[backend] = set(inspect.signature(func).parameters)
            return func

        return decorator

    @classmethod
    def get_active_backend(cls):
        """Return `(name, callable)` for the currently active backend."""
        active = cls._active_backend
        return active, cls._backends[active]

    @classmethod
    def list_backends(cls):
        """Return the names of all registered backends as a list."""
        return list(cls._backends)
@contextlib.contextmanager
def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
    """
    Temporarily make `backend` the active attention backend.

    The previously active backend is restored when the `with` block exits, including when it exits through an
    exception.

    Raises:
        ValueError: If `backend` has not been registered.
    """
    registry = _AttentionBackendRegistry
    if backend not in registry._backends:
        raise ValueError(f"Backend {backend} is not registered.")

    # Remember the current backend and switch to the requested one in a single step.
    previous, registry._active_backend = registry._active_backend, backend
    try:
        yield
    finally:
        registry._active_backend = previous
def dispatch_attention_fn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    *,
    backend: Optional[AttentionBackendName] = None,
) -> torch.Tensor:
    """
    Run attention through a registered backend.

    The standard arguments and any extra `attention_kwargs` are merged (extras win on key collisions), optionally
    validated by the backend's constraints, filtered down to the parameter names the backend accepts, and then
    forwarded to the backend callable.

    Args:
        query, key, value: Attention inputs forwarded to the backend.
        attn_mask, dropout_p, is_causal, scale, enable_gqa: Standard attention options; those the selected
            backend does not accept are dropped.
        attention_kwargs: Extra backend-specific keyword arguments.
        backend: Backend to use. When `None`, the registry's active backend is used (the default, or the one set
            through the `attention_backend` context manager).

    Returns:
        Whatever the selected backend callable returns.
    """
    registry = _AttentionBackendRegistry

    if backend is not None:
        backend_name = AttentionBackendName(backend)
        backend_fn = registry._backends.get(backend_name)
    else:
        # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND
        # environment variable), or the one selected through the `attention_backend` context manager.
        backend_name, backend_fn = registry.get_active_backend()

    call_kwargs = dict(
        query=query,
        key=key,
        value=value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        enable_gqa=enable_gqa,
    )
    # Backend-specific extras override the standard arguments on collision.
    call_kwargs.update(attention_kwargs or {})

    if registry._checks_enabled:
        removed_kwargs = set(call_kwargs) - set(registry._supported_arg_names[backend_name])
        if removed_kwargs:
            logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.")
        # Constraints receive the unfiltered arguments.
        for check in registry._constraints.get(backend_name):
            check(**call_kwargs)

    accepted = registry._supported_arg_names[backend_name]
    filtered = {name: arg for name, arg in call_kwargs.items() if name in accepted}
    return backend_fn(**filtered)
# ===== Checks =====
# A list of very simple functions to catch common errors quickly when debugging.
def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None:
if attn_mask is not None and is_causal:
raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")
def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
if query.device != key.device or query.device != value.device:
raise ValueError("Query, key, and value must be on the same device.")
if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError("Query, key, and value must have the same dtype.")
def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    """
    Run `_check_device` and additionally require the query to live on a CUDA device.

    Raises:
        ValueError: If `_check_device` fails or the query's device type is not "cuda".
    """
    _check_device(query, key, value)
    device_type = query.device.type
    if device_type != "cuda":
        raise ValueError("Query, key, and value must be on a CUDA device.")
def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
    """
    Build a check that requires CUDA tensors on a device with compute capability of at least `major.minor`.

    The returned callable has the same `(query, key, value, **kwargs)` signature as the other checks and raises
    `ValueError` on failure.
    """
    minimum_capability = (major, minor)

    def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
        _check_device_cuda(query, key, value)
        # Tuples compare lexicographically, so (major, minor) pairs order as expected.
        capability = torch.cuda.get_device_capability(query.device)
        if capability < minimum_capability:
            raise ValueError(
                f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
            )

    return check_device_cuda
def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
if query.dtype != key.dtype:
raise ValueError("Query and key must have the same dtype.")
if query.dtype != value.dtype:
raise ValueError("Query and value must have the same dtype.")
def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    """
    Require query, key and value to share a dtype that is either bfloat16 or float16.

    Raises:
        ValueError: If the dtypes differ (via `_check_qkv_dtype_match`) or are not half precision.
    """
    _check_qkv_dtype_match(query, key, value)
    half_precision_dtypes = (torch.bfloat16, torch.float16)
    if query.dtype in half_precision_dtypes:
        return
    raise ValueError("Query, key, and value must be either bfloat16 or float16.")
def _check_shape(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> None:
if query.shape[-1] != key.shape[-1]:
raise ValueError("Query and key must have the same last dimension.")
if query.shape[-2] != value.shape[-2]:
raise ValueError("Query and value must have the same second to last dimension.")
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
raise ValueError("Attention mask must match the key's second to last dimension.")
# ===== Helper functions =====
@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
batch_size: int,
seq_len_q: int,
seq_len_kv: int,
device: Optional[torch.device] = None,
):
seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
cu_seqlens_k = torch.cumsum(seqlens_q, dim=0, dtype=torch.int32)
cu_seqlens_q = torch.cumsum(seqlens_k, dim=0, dtype=torch.int32)
cu_seqlens_q = torch.nn.functional.pad(cu_seqlens_q, (1, 0))
cu_seqlens_k = torch.nn.functional.pad(cu_seqlens_k, (1, 0))
max_seqlen_q = seqlens_q.max().item()
max_seqlen_k = seqlens_k.max().item()
return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
def _prepare_for_flash_attn_or_sage_varlen_with_mask(
batch_size: int,
seq_len_q: int,
attn_mask: torch.Tensor,
device: Optional[torch.device] = None,
):
seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
max_seqlen_q = seqlens_q.max().item()
max_seqlen_k = seqlens_k.max().item()
return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
def _prepare_for_flash_attn_or_sage_varlen(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    attn_mask: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
    """
    Compute cumulative sequence offsets and maximum sequence lengths for varlen attention kernels.

    Delegates to the cached mask-free helper when `attn_mask` is `None`, and to the mask-aware helper otherwise
    (in which case `seq_len_kv` is unused because key lengths come from the mask).

    Returns:
        `((cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k))`.
    """
    if attn_mask is None:
        return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
    return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)
def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
"""
Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
FlashAttention/Sage varlen.
Supports 1D to 4D shapes and common broadcasting patterns.
"""
if attn_mask.dtype != torch.bool:
raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")
if attn_mask.ndim == 1:
# [seq_len_k] -> broadcast across batch
attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
elif attn_mask.ndim == 2:
# [batch_size, seq_len_k]. Maybe broadcast across batch
if attn_mask.size(0) not in [1, batch_size]:
raise ValueError(
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
)
attn_mask = attn_mask.expand(batch_size, seq_len_k)
elif attn_mask.ndim == 3:
# [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension
# We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen.
if attn_mask.size(0) not in [1, batch_size]:
raise ValueError(
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
)
attn_mask = attn_mask.any(dim=1)
attn_mask = attn_mask.expand(batch_size, seq_len_k)
elif attn_mask.ndim == 4:
# [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions
if attn_mask.size(0) not in [1, batch_size]:
raise ValueError(
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
)
attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K]
attn_mask = attn_mask.any(dim=(1, 2)) # [B, K]
else:
raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}")
if attn_mask.shape != (batch_size, seq_len_k):
raise ValueError(
f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})"
)
return attn_mask
def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
return q_idx >= kv_idx
# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
# TODO: library.custom_op and register_fake probably need version guards?
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    FlashAttention 3 forward pass registered as the torch custom op `flash_attn_3::_flash_attn_forward`.

    Calls `flash_attn_3_func` with default options and returns `(out, lse)`.
    """
    out, lse = flash_attn_3_func(query, key, value)
    # Swap the last two axes of the 3-D log-sum-exp tensor; the fake implementation registered for this op
    # declares it as (batch_size, seq_len, num_heads).
    lse = lse.permute(0, 2, 1)
    return out, lse
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Fake (shape-only) implementation of `flash_attn_3::_flash_attn_forward`.

    Returns uninitialised tensors: an output shaped like `query` and a log-sum-exp tensor of shape
    `(batch_size, seq_len, num_heads)` taken from the first three dimensions of `query`.
    """
    # Unpacking also requires `query` to be 4-dimensional.
    batch_size, seq_len, num_heads, head_dim = query.shape
    lse_shape = (batch_size, seq_len, num_heads)
    return torch.empty_like(query), query.new_empty(lse_shape)
# ===== Attention backends =====
@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_attn_probs: bool = False,
) -> torch.Tensor:
    """
    `flash` backend: forward the inputs to `flash_attn_func` from the `flash-attn` package.

    `scale` is passed as `softmax_scale` and `is_causal` as `causal`; all other options are forwarded under their
    own names. Returns whatever `flash_attn_func` returns.
    """
    flash_kwargs = {
        "dropout_p": dropout_p,
        "softmax_scale": scale,
        "causal": is_causal,
        "window_size": window_size,
        "softcap": softcap,
        "alibi_slopes": alibi_slopes,
        "deterministic": deterministic,
        "return_attn_probs": return_attn_probs,
    }
    return flash_attn_func(q=query, k=key, v=value, **flash_kwargs)
@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH_VARLEN,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_attn_probs: bool = False,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    `flash_varlen` backend: run `flash_attn_varlen_func` on batched 4-D inputs.

    Query, key and value are unpacked as `[batch, seq_len, ...]` 4-D tensors. If any of `cu_seqlens_q`,
    `cu_seqlens_k`, `max_seqlen_q` or `max_seqlen_k` is missing, all four are recomputed from the tensor shapes
    and, when given, from `attn_mask` (normalised to `[batch, seq_len_kv]`). In that mask-derived case the key and
    value tensors are packed down to the positions the mask marks as valid, so that they line up with
    `cu_seqlens_k`. When the caller supplies all four values, `attn_mask` is not used and the inputs are simply
    flattened over the batch and sequence axes.

    Returns:
        The kernel output with its leading (token) axis unflattened back to `(batch, -1)`.
    """
    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    packed_by_mask = False
    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
        )
        # The key offsets now count only the mask's valid positions.
        packed_by_mask = attn_mask is not None

    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
    cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)

    query = query.flatten(0, 1)
    if packed_by_mask:
        # Boolean indexing keeps the valid tokens in batch-major order, which is the packed layout that
        # `cu_seqlens_k` describes. Flattening the padded tensors instead would leave padding tokens in place
        # and make every offset after the first sequence point at the wrong tokens.
        valid_keys = attn_mask.to(key.device)
        key, value = key[valid_keys], value[valid_keys]
    else:
        key, value = key.flatten(0, 1), value.flatten(0, 1)

    out = flash_attn_varlen_func(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
    )
    out = out.unflatten(0, (batch_size, -1))
    return out
@_AttentionBackendRegistry.register(
    AttentionBackendName._FLASH_3,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    deterministic: bool = False,
    return_attn_probs: bool = False,
) -> torch.Tensor:
    """
    `_flash_3` backend: forward the inputs to FlashAttention 3's `flash_attn_func`.

    Options not exposed by this wrapper are passed with fixed values (no descale tensors, `attention_chunk=0`,
    `num_splits=1`, `pack_gqa=None`, `sm_margin=0`).

    Returns:
        The attention output, or `(out, lse)` when `return_attn_probs` is true.
    """
    # Only the first two returned values are used; any further ones are discarded.
    out, lse, *_ = flash_attn_3_func(
        q=query,
        k=key,
        v=value,
        softmax_scale=scale,
        causal=is_causal,
        qv=None,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        window_size=window_size,
        attention_chunk=0,
        softcap=softcap,
        num_splits=1,
        pack_gqa=None,
        deterministic=deterministic,
        sm_margin=0,
    )
    if return_attn_probs:
        return (out, lse)
    return out
@_AttentionBackendRegistry.register(
    AttentionBackendName._FLASH_VARLEN_3,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_varlen_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    deterministic: bool = False,
    return_attn_probs: bool = False,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    `_flash_varlen_3` backend: run FlashAttention 3's `flash_attn_varlen_func` on batched 4-D inputs.

    Query, key and value are unpacked as `[batch, seq_len, ...]` 4-D tensors. If any of `cu_seqlens_q`,
    `cu_seqlens_k`, `max_seqlen_q` or `max_seqlen_k` is missing, all four are recomputed from the tensor shapes
    and, when given, from `attn_mask` (normalised to `[batch, seq_len_kv]`). In that mask-derived case the key and
    value tensors are packed down to the positions the mask marks as valid, so that they line up with
    `cu_seqlens_k`. When the caller supplies all four values, `attn_mask` is not used and the inputs are simply
    flattened over the batch and sequence axes.

    Returns:
        The output with its leading (token) axis unflattened back to `(batch, -1)`, or `(out, lse)` when
        `return_attn_probs` is true (`lse` is returned as produced by the kernel).
    """
    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    packed_by_mask = False
    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
        )
        # The key offsets now count only the mask's valid positions.
        packed_by_mask = attn_mask is not None

    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
    cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)

    query = query.flatten(0, 1)
    if packed_by_mask:
        # Boolean indexing keeps the valid tokens in batch-major order, which is the packed layout that
        # `cu_seqlens_k` describes. Flattening the padded tensors instead would leave padding tokens in place
        # and make every offset after the first sequence point at the wrong tokens.
        valid_keys = attn_mask.to(key.device)
        key, value = key[valid_keys], value[valid_keys]
    else:
        key, value = key.flatten(0, 1), value.flatten(0, 1)

    # Only the first two returned values are used; any further ones are discarded.
    out, lse, *_ = flash_attn_3_varlen_func(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        seqused_q=None,
        seqused_k=None,
        softmax_scale=scale,
        causal=is_causal,
        qv=None,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        window_size=window_size,
        softcap=softcap,
        num_splits=1,
        pack_gqa=None,
        deterministic=deterministic,
        sm_margin=0,
    )
    out = out.unflatten(0, (batch_size, -1))
    return (out, lse) if return_attn_probs else out
@_AttentionBackendRegistry.register(
    AttentionBackendName.FLEX,
    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
)
def _native_flex_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    `flex` backend: run `torch.nn.attention.flex_attention.flex_attention`.

    Inputs are unpacked as `[batch, seq_len, num_heads, head_dim]` and permuted to
    `[batch, num_heads, seq_len, head_dim]` for the kernel; the output is permuted back.

    Masking is resolved in this order:
      1. a ready-made `BlockMask` is used as is;
      2. `is_causal=True` builds a causal block mask (also when no `attn_mask` is given);
      3. no mask and not causal: unmasked attention;
      4. a bool tensor mask becomes a block mask; any other tensor mask is added to the scores via `score_mod`.

    Returns:
        The attention output, or `(out, lse)` when `return_lse` is true (`lse` is returned as produced by
        `flex_attention`).

    Raises:
        ValueError: If `attn_mask` is neither `None`, a `BlockMask`, nor a tensor.
    """
    # TODO: should we LRU cache the block mask creation?
    score_mod = None
    block_mask = None
    batch_size, seq_len_q, num_heads, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if isinstance(attn_mask, flex_attention.BlockMask):
        block_mask = attn_mask
    elif is_causal:
        # This must be tested before the `attn_mask is None` case, otherwise `is_causal=True` without a mask
        # would silently run unmasked attention.
        block_mask = flex_attention.create_block_mask(
            _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
        )
    elif attn_mask is None:
        pass
    elif torch.is_tensor(attn_mask):
        if attn_mask.ndim == 2:
            # NOTE(review): this places the mask's second axis on the query axis before broadcasting over keys;
            # confirm this is the intended layout for 2D masks.
            attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)

        attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)

        if attn_mask.dtype == torch.bool:
            # TODO: this probably does not work but verify!
            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
                return attn_mask[batch_idx, head_idx, q_idx, kv_idx]

            block_mask = flex_attention.create_block_mask(
                mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
            )
        else:

            def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
                return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
    else:
        raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")

    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    result = flex_attention.flex_attention(
        query=query,
        key=key,
        value=value,
        score_mod=score_mod,
        block_mask=block_mask,
        scale=scale,
        enable_gqa=enable_gqa,
        return_lse=return_lse,
        kernel_options=kernel_options,
    )
    if return_lse:
        # With `return_lse=True` the kernel returns an `(output, lse)` pair, so only the output is permuted.
        out, lse = result
        return out.permute(0, 2, 1, 3), lse
    return result.permute(0, 2, 1, 3)
@_AttentionBackendRegistry.register(
    AttentionBackendName.NATIVE,
    constraints=[_check_device, _check_shape],
)
def _native_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """
    `native` backend: `torch.nn.functional.scaled_dot_product_attention` with PyTorch's default kernel selection.

    Axes 1 and 2 of the 4-D inputs are swapped before the call and swapped back on the output.
    """
    swap_axes_1_2 = (0, 2, 1, 3)
    q = query.permute(*swap_axes_1_2)
    k = key.permute(*swap_axes_1_2)
    v = value.permute(*swap_axes_1_2)
    attn_out = torch.nn.functional.scaled_dot_product_attention(
        query=q,
        key=k,
        value=v,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        enable_gqa=enable_gqa,
    )
    return attn_out.permute(*swap_axes_1_2)
@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_CUDNN,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _native_cudnn_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """
    `_native_cudnn` backend: scaled dot-product attention restricted to the cuDNN SDPA kernel.

    Axes 1 and 2 of the 4-D inputs are swapped before the call and swapped back on the output.
    """
    swap_axes_1_2 = (0, 2, 1, 3)
    q = query.permute(*swap_axes_1_2)
    k = key.permute(*swap_axes_1_2)
    v = value.permute(*swap_axes_1_2)
    cudnn_only = torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
    with cudnn_only:
        attn_out = torch.nn.functional.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
    return attn_out.permute(*swap_axes_1_2)
@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_EFFICIENT,
    constraints=[_check_device, _check_shape],
)
def _native_efficient_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """
    `_native_efficient` backend: scaled dot-product attention restricted to the memory-efficient SDPA kernel.

    Axes 1 and 2 of the 4-D inputs are swapped before the call and swapped back on the output.
    """
    swap_axes_1_2 = (0, 2, 1, 3)
    q = query.permute(*swap_axes_1_2)
    k = key.permute(*swap_axes_1_2)
    v = value.permute(*swap_axes_1_2)
    efficient_only = torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION)
    with efficient_only:
        attn_out = torch.nn.functional.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
    return attn_out.permute(*swap_axes_1_2)
@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_FLASH,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _native_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """
    `_native_flash` backend: scaled dot-product attention restricted to PyTorch's flash SDPA kernel.

    This backend takes no `attn_mask` parameter and always passes `attn_mask=None`. Axes 1 and 2 of the 4-D
    inputs are swapped before the call and swapped back on the output.
    """
    swap_axes_1_2 = (0, 2, 1, 3)
    q = query.permute(*swap_axes_1_2)
    k = key.permute(*swap_axes_1_2)
    v = value.permute(*swap_axes_1_2)
    flash_only = torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION)
    with flash_only:
        attn_out = torch.nn.functional.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            attn_mask=None,  # not supported
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
    return attn_out.permute(*swap_axes_1_2)
@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_MATH,
    constraints=[_check_device, _check_shape],
)
def _native_math_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """
    `_native_math` backend: scaled dot-product attention restricted to the math SDPA kernel.

    Axes 1 and 2 of the 4-D inputs are swapped before the call and swapped back on the output.
    """
    swap_axes_1_2 = (0, 2, 1, 3)
    q = query.permute(*swap_axes_1_2)
    k = key.permute(*swap_axes_1_2)
    v = value.permute(*swap_axes_1_2)
    math_only = torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH)
    with math_only:
        attn_out = torch.nn.functional.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
    return attn_out.permute(*swap_axes_1_2)
@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_NPU,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _native_npu_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """
    `_native_npu` backend: call `npu_fusion_attention` from `torch_npu` with the "BSND" input layout.

    When `scale` is `None`, `1 / sqrt(head_dim)` is used, with `head_dim` taken from the query's last dimension.
    Dropout is expressed to the kernel as `keep_prob = 1 - dropout_p`. Only the first element of the kernel's
    result is returned.
    """
    softmax_scale = scale if scale is not None else 1.0 / math.sqrt(query.shape[-1])
    num_heads = query.size(2)
    results = npu_fusion_attention(
        query,
        key,
        value,
        num_heads,
        input_layout="BSND",
        pse=None,
        scale=softmax_scale,
        pre_tockens=65536,
        next_tokens=65536,
        keep_prob=1.0 - dropout_p,
        sync=False,
        inner_precise=0,
    )
    return results[0]
# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_XLA,
    constraints=[_check_device, _check_shape],
)
def _native_xla_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
) -> torch.Tensor:
    """
    `_native_xla` backend: call torch_xla's `flash_attention` custom kernel.

    Axes 1 and 2 of the 4-D inputs are swapped before the call and swapped back on the output. The query is
    divided by `sqrt(head_dim)` here because no scale argument is passed to the kernel.
    """
    swap_axes_1_2 = (0, 2, 1, 3)
    q = query.permute(*swap_axes_1_2)
    k = key.permute(*swap_axes_1_2)
    v = value.permute(*swap_axes_1_2)
    head_dim = q.shape[-1]
    q = q / math.sqrt(head_dim)
    attn_out = xla_flash_attention(
        q=q,
        k=k,
        v=v,
        causal=is_causal,
    )
    return attn_out.permute(*swap_axes_1_2)
@_AttentionBackendRegistry.register(
    AttentionBackendName.SAGE,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _sage_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    """
    `sage` backend: forward the inputs to `sageattn` using the "NHD" tensor layout.

    `scale` is passed as `sm_scale`. Returns whatever `sageattn` returns.
    """
    sage_kwargs = {
        "tensor_layout": "NHD",
        "is_causal": is_causal,
        "sm_scale": scale,
        "return_lse": return_lse,
    }
    return sageattn(q=query, k=key, v=value, **sage_kwargs)
@_AttentionBackendRegistry.register(
    AttentionBackendName.SAGE_VARLEN,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _sage_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    smooth_k: bool = True,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    `sage_varlen` backend: run `sageattn_varlen` on batched 4-D inputs.

    Query, key and value are unpacked as `[batch, seq_len, ...]` 4-D tensors. If any of `cu_seqlens_q`,
    `cu_seqlens_k`, `max_seqlen_q` or `max_seqlen_k` is missing, all four are recomputed from the tensor shapes
    and, when given, from `attn_mask` (normalised to `[batch, seq_len_kv]`). In that mask-derived case the key and
    value tensors are packed down to the positions the mask marks as valid, so that they line up with
    `cu_seqlens_k`. When the caller supplies all four values, `attn_mask` is not used and the inputs are simply
    flattened over the batch and sequence axes.

    Returns:
        The kernel output with its leading (token) axis unflattened back to `(batch, -1)`.
    """
    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    packed_by_mask = False
    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
        )
        # The key offsets now count only the mask's valid positions.
        packed_by_mask = attn_mask is not None

    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
    cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)

    query = query.flatten(0, 1)
    if packed_by_mask:
        # Boolean indexing keeps the valid tokens in batch-major order, which is the packed layout that
        # `cu_seqlens_k` describes. Flattening the padded tensors instead would leave padding tokens in place
        # and make every offset after the first sequence point at the wrong tokens.
        valid_keys = attn_mask.to(key.device)
        key, value = key[valid_keys], value[valid_keys]
    else:
        key, value = key.flatten(0, 1), value.flatten(0, 1)

    out = sageattn_varlen(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        is_causal=is_causal,
        sm_scale=scale,
        smooth_k=smooth_k,
    )
    out = out.unflatten(0, (batch_size, -1))
    return out
@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_qk_int8_pv_fp8_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    """
    `_sage_qk_int8_pv_fp8_cuda` backend: forward the inputs to `sageattn_qk_int8_pv_fp8_cuda` ("NHD" layout).

    The registered constraints require a CUDA device with compute capability of at least 9.0. `scale` is passed
    as `sm_scale`. Returns whatever the kernel returns.
    """
    sage_kwargs = {
        "tensor_layout": "NHD",
        "is_causal": is_causal,
        "qk_quant_gran": qk_quant_gran,
        "sm_scale": scale,
        "pv_accum_dtype": pv_accum_dtype,
        "smooth_k": smooth_k,
        "smooth_v": smooth_v,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp8_cuda(q=query, k=key, v=value, **sage_kwargs)
@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
    smooth_k: bool = True,
    return_lse: bool = False,
) -> torch.Tensor:
    """
    `_sage_qk_int8_pv_fp8_cuda_sm90` backend: forward the inputs to `sageattn_qk_int8_pv_fp8_cuda_sm90`
    ("NHD" layout).

    The registered constraints require a CUDA device with compute capability of at least 9.0. `scale` is passed
    as `sm_scale`. Returns whatever the kernel returns.
    """
    sage_kwargs = {
        "tensor_layout": "NHD",
        "is_causal": is_causal,
        "qk_quant_gran": qk_quant_gran,
        "sm_scale": scale,
        "pv_accum_dtype": pv_accum_dtype,
        "smooth_k": smooth_k,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp8_cuda_sm90(q=query, k=key, v=value, **sage_kwargs)
@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
)
def _sage_qk_int8_pv_fp16_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    """
    `_sage_qk_int8_pv_fp16_cuda` backend: forward the inputs to `sageattn_qk_int8_pv_fp16_cuda` ("NHD" layout).

    The registered constraints require a CUDA device with compute capability of at least 8.0. `scale` is passed
    as `sm_scale`. Returns whatever the kernel returns.
    """
    sage_kwargs = {
        "tensor_layout": "NHD",
        "is_causal": is_causal,
        "qk_quant_gran": qk_quant_gran,
        "sm_scale": scale,
        "pv_accum_dtype": pv_accum_dtype,
        "smooth_k": smooth_k,
        "smooth_v": smooth_v,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp16_cuda(q=query, k=key, v=value, **sage_kwargs)
@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
)
def _sage_qk_int8_pv_fp16_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
    smooth_k: bool = True,
    return_lse: bool = False,
) -> torch.Tensor:
    """
    `_sage_qk_int8_pv_fp16_triton` backend: forward the inputs to `sageattn_qk_int8_pv_fp16_triton`
    ("NHD" layout).

    The registered constraints require a CUDA device with compute capability of at least 8.0. `scale` is passed
    as `sm_scale`. Returns whatever the kernel returns.
    """
    sage_kwargs = {
        "tensor_layout": "NHD",
        "quantization_backend": quantization_backend,
        "is_causal": is_causal,
        "sm_scale": scale,
        "smooth_k": smooth_k,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp16_triton(q=query, k=key, v=value, **sage_kwargs)
@_AttentionBackendRegistry.register(
    AttentionBackendName.XFORMERS,
    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
)
def _xformers_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """
    `xformers` backend: run `xformers.ops.memory_efficient_attention`.

    Inputs are unpacked as `[batch, seq_len, num_heads, head_dim]`. With `is_causal`, a `LowerTriangularMask` is
    used and any `attn_mask` is ignored; otherwise a 2D or 4D `attn_mask` is expanded to
    `[batch, num_heads_q, seq_len_q, seq_len_kv]` and cast to the query's dtype. With `enable_gqa`, the head axis
    is split into `(num_heads_kv, group)` for the call and merged again on the output.

    Raises:
        ValueError: If the mask is neither 2D nor 4D, or if `enable_gqa` is set and the number of query heads is
            not a multiple of the number of key/value heads.
    """
    batch_size, seq_len_q, num_heads_q, _ = query.shape
    _, seq_len_kv, num_heads_kv, _ = key.shape

    attn_bias = attn_mask
    if is_causal:
        attn_bias = xops.LowerTriangularMask()
    elif attn_bias is not None:
        mask_rank = attn_bias.ndim
        if mask_rank not in (2, 4):
            raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
        if mask_rank == 2:
            attn_bias = attn_bias.view(attn_bias.size(0), 1, attn_bias.size(1), 1)
        attn_bias = attn_bias.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)

    if not enable_gqa:
        return xops.memory_efficient_attention(query, key, value, attn_bias, dropout_p, scale)

    if num_heads_q % num_heads_kv != 0:
        raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
    group_size = num_heads_q // num_heads_kv
    # Split the head axis into (num_heads_kv, group); key/value get a broadcast group axis.
    grouped_query = query.unflatten(2, (num_heads_kv, -1))
    grouped_key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, group_size, -1)
    grouped_value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, group_size, -1)

    grouped_out = xops.memory_efficient_attention(
        grouped_query, grouped_key, grouped_value, attn_bias, dropout_p, scale
    )
    # Merge (num_heads_kv, group) back into a single head axis.
    return grouped_out.flatten(2, 3)