1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
sayakpaul
2025-08-26 16:54:48 +02:00
parent 2bb3796569
commit 4e69d42287
2 changed files with 28 additions and 27 deletions

View File

@@ -17,6 +17,7 @@ import functools
import inspect
import math
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch
@@ -39,8 +40,6 @@ from ..utils import (
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
logger = get_logger(__name__) # pylint: disable=invalid-name
_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
@@ -70,20 +69,6 @@ else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if is_kernels_available():
from ..utils.kernels_utils import _get_fa3_from_hub
flash_attn_interface_hub = _get_fa3_from_hub()
if flash_attn_interface_hub is not None:
flash_attn_3_hub_func = flash_attn_interface_hub.flash_attn_func
flash_attn_3_varlen_hub_func = flash_attn_interface_hub.flash_attn_varlen_func
else:
flash_attn_3_hub_func = None
flash_attn_3_varlen_hub_func = None
else:
flash_attn_3_hub_func = None
flash_attn_3_varlen_hub_func = None
if _CAN_USE_SAGE_ATTN:
from sageattention import (
@@ -148,6 +133,7 @@ else:
_custom_op = custom_op_no_op
_register_fake = register_fake_no_op
logger = get_logger(__name__) # pylint: disable=invalid-name
# TODO(aryan): Add support for the following:
# - Sage Attention++
@@ -169,7 +155,7 @@ class AttentionBackendName(str, Enum):
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
_FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
# _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
# PyTorch native
FLEX = "flex"
@@ -224,6 +210,22 @@ class _AttentionBackendRegistry:
return list(cls._backends.keys())
@lru_cache(maxsize=None)
def _load_fa3_hub():
from ..utils.kernels_utils import _get_fa3_from_hub
fa3_hub = _get_fa3_from_hub() # won't re-download if already present
if fa3_hub is None:
raise RuntimeError(
"Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
)
return fa3_hub
def flash_attn_3_hub_func(*args, **kwargs):
return _load_fa3_hub().flash_attn_func(*args, **kwargs)
@contextlib.contextmanager
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
"""
@@ -374,12 +376,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
if flash_attn_3_hub_func is None:
raise RuntimeError(
"`flash_attn_3_hub_func` wasn't available. Please double if `kernels` was able to successfully pull the FA3 kernel from kernels-community/vllm-flash-attn3."
)
elif backend in [AttentionBackendName._FLASH_VARLEN_3_HUB]:
raise NotImplementedError
elif backend in [
AttentionBackendName.SAGE,
@@ -544,7 +540,7 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
return torch.empty_like(query), query.new_empty(lse_shape)
@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
@_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_hub(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -553,7 +549,7 @@ def _wrapped_flash_attn_3_hub(
return out, lse
@_register_fake("vllm_flash_attn3::_flash_attn_forward")
@_register_fake("vllm_flash_attn3::flash_attn")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len, num_heads, head_dim = query.shape
lse_shape = (batch_size, seq_len, num_heads)

View File

@@ -1,6 +1,10 @@
from ..utils import get_logger
from .import_utils import is_kernels_available
logger = get_logger(__name__)
_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"
@@ -13,5 +17,6 @@ def _get_fa3_from_hub():
try:
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
return flash_attn_3_hub
except Exception:
return None
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
raise