Mirror of https://github.com/huggingface/diffusers.git
@@ -80,12 +80,15 @@ if DIFFUSERS_ENABLE_HUB_KERNELS:
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import _get_fa3_from_hub, get_fa_from_hub
 
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+    fa3_interface_hub = _get_fa3_from_hub()
+    flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func
+    fa_interface_hub = get_fa_from_hub()
+    flash_attn_func_hub = fa_interface_hub.flash_attn_func
 else:
     flash_attn_3_func_hub = None
+    flash_attn_func_hub = None
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
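
The hunk above wires the Hub kernels into what appears to be diffusers' attention dispatch module: when DIFFUSERS_ENABLE_HUB_KERNELS is set, both the FA3 kernel and the regular flash-attn kernel are fetched from the Hub at import time. As a minimal sketch (not code from the commit), and assuming the `kernels` package is installed and the Hub is reachable, `flash_attn_func_hub` resolves to roughly this:

# Illustrative sketch only; mirrors the module-level wiring added above.
from kernels import get_kernel

fa_interface = get_kernel("kernels-community/flash-attn")  # same repo id that get_fa_from_hub() uses (see the kernels_utils hunk below)
flash_attn_func_hub = fa_interface.flash_attn_func         # the callable the new FLASH_HUB backend dispatches to
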
@@ -170,6 +173,8 @@ class AttentionBackendName(str, Enum):
     # `flash-attn`
     FLASH = "flash"
     FLASH_VARLEN = "flash_varlen"
+    FLASH_HUB = "flash_hub"
+    # FLASH_VARLEN_HUB = "flash_varlen_hub"  # not supported yet.
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"

@@ -400,15 +405,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
             f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
         )
 
-    # TODO: add support Hub variant of FA3 varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+    # TODO: add support Hub variant of FA and FA3 varlen later
+    elif backend in [AttentionBackendName.FLASH_HUB, AttentionBackendName._FLASH_3_HUB]:
         if not DIFFUSERS_ENABLE_HUB_KERNELS:
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
             )
         if not is_kernels_available():
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
 
     elif backend in [
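
This check makes the new FLASH_HUB backend strictly opt-in: the env var and the `kernels` package are both required. A usage sketch follows; it assumes a diffusers build whose models expose `set_attention_backend` and uses Flux purely as an example, while the backend string "flash_hub" and the env var value come from this diff:

# Hypothetical usage sketch; `set_attention_backend` and the Flux checkpoint are assumptions, not part of this commit.
import os
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"  # must be set before diffusers is imported

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.set_attention_backend("flash_hub")  # raises RuntimeError if the env var or `kernels` is missing
image = pipe("a photo of a cat", num_inference_steps=28).images[0]
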
@@ -1225,6 +1230,36 @@ def _flash_attention(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    out = flash_attn_func_hub(
+        q=query,
+        k=key,
+        v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    if return_lse:
+        out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
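
The wrapper registered above follows flash-attn's calling convention: query, key, and value are (batch, seq_len, num_heads, head_dim) tensors in bf16 or fp16, and return_attn_probs=True makes the kernel return extra items, which the wrapper normalizes back to (out, lse). A self-contained sketch of that convention against the Hub kernel, assuming a CUDA device (shapes are illustrative only, not from the commit):

# Sketch of the return convention the wrapper above normalizes.
import torch
from kernels import get_kernel

fa = get_kernel("kernels-community/flash-attn")
q = k = v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)  # (batch, seq, heads, head_dim)

out = fa.flash_attn_func(q=q, k=k, v=v, causal=True)                       # plain call: just the attention output
out, lse, *_ = fa.flash_attn_func(q=q, k=k, v=v, return_attn_probs=True)   # output plus log-sum-exp; extras dropped
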
@@ -2,22 +2,32 @@ from ..utils import get_logger
 from .import_utils import is_kernels_available
 
 
+if is_kernels_available():
+    from kernels import get_kernel
+
 logger = get_logger(__name__)
 
 
-_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_IDS = {
+    "fa3": ("kernels-community/flash-attn3", {"revision": "fake-ops-return-probs"}),
+    "fa": ("kernels-community/flash-attn", {}),
+}
 
 
-def _get_fa3_from_hub():
+def _get_from_hub(key: str):
     if not is_kernels_available():
         return None
-    else:
-        from kernels import get_kernel
-
-        try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
-        except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
-            raise
+    hub_id, kwargs = _DEFAULT_HUB_IDS[key]
+    try:
+        return get_kernel(hub_id, **kwargs)
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{hub_id}' from the Hub: {e}")
+        raise
+
+
+def get_fa3_from_hub():
+    return _get_from_hub("fa3")
+
+
+def get_fa_from_hub():
+    return _get_from_hub("fa")
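
The refactor above folds the FA3-only helper into a generic _get_from_hub(key) driven by the _DEFAULT_HUB_IDS table, with a thin wrapper per kernel. A short sketch of how the wrappers behave (the module path is inferred from the relative import in the first hunk):

# Sketch; the import path is an inference from `from ..utils.kernels_utils import ...` above.
from diffusers.utils.kernels_utils import get_fa_from_hub

fa = get_fa_from_hub()  # None when `kernels` isn't installed; logs and re-raises if the Hub fetch fails
if fa is not None:
    print(fa.flash_attn_func)  # the callable bound to flash_attn_func_hub in the dispatch module
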