commit c386f220ea (parent 310fdaf556)
Author: sayakpaul
Date:   2025-09-25 13:05:39 +05:30

2 changed files with 64 additions and 19 deletions

src/diffusers/models/attention_dispatch.py

@@ -80,12 +80,15 @@ if DIFFUSERS_ENABLE_HUB_KERNELS:
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import _get_fa3_from_hub, get_fa_from_hub
 
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+    fa3_interface_hub = _get_fa3_from_hub()
+    flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func
+    fa_interface_hub = get_fa_from_hub()
+    flash_attn_func_hub = fa_interface_hub.flash_attn_func
 else:
     flash_attn_3_func_hub = None
+    flash_attn_func_hub = None
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -170,6 +173,8 @@ class AttentionBackendName(str, Enum):
     # `flash-attn`
     FLASH = "flash"
     FLASH_VARLEN = "flash_varlen"
+    FLASH_HUB = "flash_hub"
+    # FLASH_VARLEN_HUB = "flash_varlen_hub"  # not supported yet.
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
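
For context, a minimal usage sketch of the new `flash_hub` backend (not part of this commit). It assumes a CUDA machine with `pip install kernels`, and relies on the existing `set_attention_backend` helper on diffusers models; the pipeline and checkpoint below are only examples.

```python
# Usage sketch for the new "flash_hub" backend; illustrative only.
import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"  # must be set before diffusers reads it

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer.set_attention_backend("flash_hub")  # the enum value added above
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
```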
@@ -400,15 +405,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )
 
-    # TODO: add support Hub variant of FA3 varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+    # TODO: add support Hub variant of FA and FA3 varlen later
+    elif backend in [AttentionBackendName.FLASH_HUB, AttentionBackendName._FLASH_3_HUB]:
         if not DIFFUSERS_ENABLE_HUB_KERNELS:
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
             )
         if not is_kernels_available():
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
 
     elif backend in [
@@ -1225,6 +1230,36 @@ def _flash_attention(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    out = flash_attn_func_hub(
+        q=query,
+        k=key,
+        v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    if return_lse:
+        out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
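
A quick sketch of how the wrapper registered above behaves with and without `return_lse` (not part of the commit; it assumes the Hub kernel has been fetched and follows the flash-attn convention of `(batch, seq_len, num_heads, head_dim)` inputs in bf16 on CUDA).

```python
# Illustrative call pattern for _flash_attention_hub; requires the Hub kernel to be loaded.
import torch

q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

out = _flash_attention_hub(q, k, v)                        # attention output only
out, lse = _flash_attention_hub(q, k, v, return_lse=True)  # also the log-sum-exp rows
```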

src/diffusers/utils/kernels_utils.py

@@ -2,22 +2,32 @@ from ..utils import get_logger
 from .import_utils import is_kernels_available
 
 
+if is_kernels_available():
+    from kernels import get_kernel
+
+
 logger = get_logger(__name__)
 
-_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_IDS = {
+    "fa3": ("kernels-community/flash-attn3", {"revision": "fake-ops-return-probs"}),
+    "fa": ("kernels-community/flash-attn", {}),
+}
 
 
-def _get_fa3_from_hub():
+def _get_from_hub(key: str):
     if not is_kernels_available():
         return None
-    else:
-        from kernels import get_kernel
-
-        try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
-        except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
-            raise
+
+    hub_id, kwargs = _DEFAULT_HUB_IDS[key]
+    try:
+        return get_kernel(hub_id, **kwargs)
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{hub_id}' from the Hub: {e}")
+        raise
+
+
+def _get_fa3_from_hub():
+    return _get_from_hub("fa3")
+
+
+def get_fa_from_hub():
+    return _get_from_hub("fa")
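
And a small sketch of using the new helper directly (assumes network access and the `kernels` package; the `flash_attn_func` attribute name mirrors how attention_dispatch.py consumes the fetched module).

```python
# Fetch the flash-attn kernel from the Hub and grab its attention function; illustrative only.
from diffusers.utils.kernels_utils import get_fa_from_hub

fa_module = get_fa_from_hub()                # loads kernels-community/flash-attn via `kernels`
flash_attn_func = fa_module.flash_attn_func  # same attribute attention_dispatch.py binds to
print(type(fa_module), callable(flash_attn_func))
```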