mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[core] support sage attention + FA2 through kernels (#12439)
* up * support automatic dispatch. * disable compile support for now./ * up * flash too. * document. * up * up * up * up
This commit is contained in:
@@ -139,12 +139,14 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
|
||||
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
|
||||
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
|
||||
| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
|
||||
| `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |
|
||||
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
|
||||
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
|
||||
| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
|
||||
|
||||
@@ -18,7 +18,7 @@ import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -160,16 +160,13 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
# - CP with sage attention, flex, xformers, other missing backends
|
||||
# - Add support for normal and CP training with backends that don't support it yet
|
||||
|
||||
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
|
||||
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
|
||||
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
|
||||
|
||||
|
||||
class AttentionBackendName(str, Enum):
|
||||
# EAGER = "eager"
|
||||
|
||||
# `flash-attn`
|
||||
FLASH = "flash"
|
||||
FLASH_HUB = "flash_hub"
|
||||
FLASH_VARLEN = "flash_varlen"
|
||||
_FLASH_3 = "_flash_3"
|
||||
_FLASH_VARLEN_3 = "_flash_varlen_3"
|
||||
@@ -191,6 +188,7 @@ class AttentionBackendName(str, Enum):
|
||||
|
||||
# `sageattention`
|
||||
SAGE = "sage"
|
||||
SAGE_HUB = "sage_hub"
|
||||
SAGE_VARLEN = "sage_varlen"
|
||||
_SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
|
||||
_SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
|
||||
@@ -264,7 +262,13 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
||||
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
|
||||
)
|
||||
),
|
||||
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
|
||||
),
|
||||
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -420,8 +424,8 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
|
||||
)
|
||||
|
||||
# TODO: add support Hub variant of FA3 varlen later
|
||||
elif backend in [AttentionBackendName._FLASH_3_HUB]:
|
||||
# TODO: add support Hub variant of varlen later
|
||||
elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.FLASH_HUB, AttentionBackendName.SAGE_HUB]:
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
@@ -1350,6 +1354,38 @@ def _flash_attention(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_VARLEN,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
@@ -1431,6 +1467,7 @@ def _flash_attention_3(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_3_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -1444,6 +1481,9 @@ def _flash_attention_3_hub(
|
||||
return_attn_probs: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if _parallel_config:
|
||||
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
@@ -1938,6 +1978,38 @@ def _sage_attention(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _sage_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
|
||||
if _parallel_config is None:
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_VARLEN,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
|
||||
@@ -34,7 +34,10 @@ from diffusers.utils import is_torch_version # noqa: E402
|
||||
|
||||
# fmt: off
|
||||
FORWARD_CASES = [
|
||||
("flash_hub", None),
|
||||
(
|
||||
"flash_hub",
|
||||
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
|
||||
),
|
||||
(
|
||||
"_flash_3_hub",
|
||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
|
||||
@@ -54,7 +57,11 @@ FORWARD_CASES = [
|
||||
]
|
||||
|
||||
COMPILE_CASES = [
|
||||
("flash_hub", None, True),
|
||||
(
|
||||
"flash_hub",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||
True
|
||||
),
|
||||
(
|
||||
"_flash_3_hub",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||
|
||||
Reference in New Issue
Block a user