mirror of https://github.com/huggingface/diffusers.git

up
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -17,6 +17,7 @@ import functools
 import inspect
 import math
 from enum import Enum
+from functools import lru_cache
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
@@ -39,8 +40,6 @@ from ..utils import (
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
 
 
-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
 _REQUIRED_FLASH_VERSION = "2.6.3"
 _REQUIRED_SAGE_VERSION = "2.1.1"
 _REQUIRED_FLEX_VERSION = "2.5.0"
@@ -70,20 +69,6 @@ else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
-if is_kernels_available():
-    from ..utils.kernels_utils import _get_fa3_from_hub
-
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    if flash_attn_interface_hub is not None:
-        flash_attn_3_hub_func = flash_attn_interface_hub.flash_attn_func
-        flash_attn_3_varlen_hub_func = flash_attn_interface_hub.flash_attn_varlen_func
-    else:
-        flash_attn_3_hub_func = None
-        flash_attn_3_varlen_hub_func = None
-else:
-    flash_attn_3_hub_func = None
-    flash_attn_3_varlen_hub_func = None
-
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -148,6 +133,7 @@ else:
     _custom_op = custom_op_no_op
     _register_fake = register_fake_no_op
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
@@ -169,7 +155,7 @@ class AttentionBackendName(str, Enum):
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
-    _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.
+    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.
 
     # PyTorch native
     FLEX = "flex"
@@ -224,6 +210,22 @@ class _AttentionBackendRegistry:
         return list(cls._backends.keys())
 
 
+@lru_cache(maxsize=None)
+def _load_fa3_hub():
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    fa3_hub = _get_fa3_from_hub()  # won't re-download if already present
+    if fa3_hub is None:
+        raise RuntimeError(
+            "Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
+        )
+    return fa3_hub
+
+
+def flash_attn_3_hub_func(*args, **kwargs):
+    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
+
+
 @contextlib.contextmanager
 def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
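Note: the hunk above replaces the import-time Hub fetch (deleted in an earlier hunk) with a lazily evaluated, cached loader, so the kernel is only fetched the first time the FA3 Hub backend is actually called. Below is a minimal standalone sketch of the same pattern; `_fetch_kernel_module` and `attn_func` are made-up stand-ins for `_get_fa3_from_hub` and the real wrapper, not diffusers code.

import types
from functools import lru_cache


def _fetch_kernel_module():
    # Stand-in for an expensive fetch (e.g. downloading a compiled kernel from the Hub).
    print("fetching kernel ...")
    return types.SimpleNamespace(attn=lambda q, k, v: q)


@lru_cache(maxsize=None)
def _load_backend():
    # lru_cache guarantees the fetch runs at most once per process.
    backend = _fetch_kernel_module()
    if backend is None:
        raise RuntimeError("Kernel could not be loaded for this platform.")
    return backend


def attn_func(*args, **kwargs):
    # The fetch cost is paid on first use, not when this module is imported.
    return _load_backend().attn(*args, **kwargs)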
@@ -374,12 +376,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             raise RuntimeError(
                 f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
-        if flash_attn_3_hub_func is None:
-            raise RuntimeError(
-                "`flash_attn_3_hub_func` wasn't available. Please double if `kernels` was able to successfully pull the FA3 kernel from kernels-community/vllm-flash-attn3."
-            )
-    elif backend in [AttentionBackendName._FLASH_VARLEN_3_HUB]:
-        raise NotImplementedError
 
     elif backend in [
         AttentionBackendName.SAGE,
@@ -544,7 +540,7 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
     return torch.empty_like(query), query.new_empty(lse_shape)
 
 
-@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+@_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_hub(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -553,7 +549,7 @@ def _wrapped_flash_attn_3_hub(
     return out, lse
 
 
-@_register_fake("vllm_flash_attn3::_flash_attn_forward")
+@_register_fake("vllm_flash_attn3::flash_attn")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)
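Note: the two hunks above only rename the registered custom op from vllm_flash_attn3::_flash_attn_forward to vllm_flash_attn3::flash_attn; the real/fake kernel pair itself is unchanged. For context, here is a minimal sketch of registering such a pair with torch.library (the module-level decorators used in the diff, assuming PyTorch >= 2.4). The "mylib::toy_attn" namespace, op name, and bodies are illustrative, not the diffusers implementation.

from typing import Tuple

import torch


@torch.library.custom_op("mylib::toy_attn", mutates_args=(), device_types="cpu")
def toy_attn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Toy "attention" kernel: an output shaped like query plus a dummy log-sum-exp tensor.
    batch_size, seq_len, num_heads, _ = query.shape
    lse = query.new_zeros(batch_size, seq_len, num_heads)
    return torch.zeros_like(query), lse


@torch.library.register_fake("mylib::toy_attn")
def _(query, key, value):
    # Fake (meta) kernel: returns correctly shaped empty tensors so torch.compile can trace the op.
    batch_size, seq_len, num_heads, _ = query.shape
    return torch.empty_like(query), query.new_empty((batch_size, seq_len, num_heads))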
--- a/src/diffusers/utils/kernels_utils.py
+++ b/src/diffusers/utils/kernels_utils.py
@@ -1,6 +1,10 @@
+from ..utils import get_logger
 from .import_utils import is_kernels_available
 
 
+logger = get_logger(__name__)
+
+
 _DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"
 
 
@@ -13,5 +17,6 @@ def _get_fa3_from_hub():
     try:
         flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
         return flash_attn_3_hub
-    except Exception:
-        return None
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+        raise
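Note: the kernels_utils.py change above replaces silent failure (returning None on any exception) with logging plus re-raising, so callers see the actual error rather than a None they have to interpret. A generic sketch of the same pattern follows; fetch_resource and its failing _fetch helper are hypothetical, not diffusers code.

import logging

logger = logging.getLogger(__name__)


def fetch_resource(resource_id: str):
    def _fetch(rid):
        # Hypothetical fetch that stands in for get_kernel(...) and may fail.
        raise ConnectionError(f"could not reach host for {rid!r}")

    try:
        return _fetch(resource_id)
    except Exception as e:
        # Log with context, then re-raise so the caller gets the original traceback
        # instead of a silent None.
        logger.error(f"An error occurred while fetching '{resource_id}': {e}")
        raise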