diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 0a2ad68123..be5d8b403b 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -80,12 +80,15 @@ if DIFFUSERS_ENABLE_HUB_KERNELS:
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import get_fa3_from_hub, get_fa_from_hub
 
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+    fa3_interface_hub = get_fa3_from_hub()
+    flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func
+    fa_interface_hub = get_fa_from_hub()
+    flash_attn_func_hub = fa_interface_hub.flash_attn_func
 else:
     flash_attn_3_func_hub = None
+    flash_attn_func_hub = None
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -170,6 +173,8 @@ class AttentionBackendName(str, Enum):
     # `flash-attn`
     FLASH = "flash"
     FLASH_VARLEN = "flash_varlen"
+    FLASH_HUB = "flash_hub"
+    # FLASH_VARLEN_HUB = "flash_varlen_hub"  # not supported yet.
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
@@ -400,15 +405,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )
 
-    # TODO: add support Hub variant of FA3 varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+    # TODO: add support for Hub variants of FA and FA3 varlen later
+    elif backend in [AttentionBackendName.FLASH_HUB, AttentionBackendName._FLASH_3_HUB]:
         if not DIFFUSERS_ENABLE_HUB_KERNELS:
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
             )
         if not is_kernels_available():
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
 
     elif backend in [
@@ -1225,6 +1230,37 @@ def _flash_attention(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    # Use the Hub-fetched kernel, not the locally installed `flash_attn_func`.
+    out = flash_attn_func_hub(
+        q=query,
+        k=key,
+        v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    if return_lse:
+        out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
index 26d6e3972f..0a1511c67f 100644
--- a/src/diffusers/utils/kernels_utils.py
+++ b/src/diffusers/utils/kernels_utils.py
@@ -2,22 +2,33 @@ from ..utils import get_logger
 from .import_utils import is_kernels_available
 
 
+if is_kernels_available():
+    from kernels import get_kernel
+
 logger = get_logger(__name__)
 
-
-_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_IDS = {
+    # TODO: temporary revision for now; remove when merged upstream into `main`.
+    "fa3": ("kernels-community/flash-attn3", {"revision": "fake-ops-return-probs"}),
+    "fa": ("kernels-community/flash-attn", {}),
+}
 
 
-def _get_fa3_from_hub():
+def _get_from_hub(key: str):
     if not is_kernels_available():
         return None
-    else:
-        from kernels import get_kernel
 
-        try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
-        except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
-            raise
+    hub_id, kwargs = _DEFAULT_HUB_IDS[key]
+    try:
+        return get_kernel(hub_id, **kwargs)
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{hub_id}' from the Hub: {e}")
+        raise
+
+
+def get_fa3_from_hub():
+    return _get_from_hub("fa3")
+
+
+def get_fa_from_hub():
+    return _get_from_hub("fa")
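
Below is a minimal, hedged usage sketch of the new `flash_hub` backend for reviewers. It assumes the `kernels` package is installed, that `DIFFUSERS_ENABLE_HUB_KERNELS` is set before `diffusers` is imported (the Hub kernel is resolved when `attention_dispatch` is first imported), and that the model exposes the `set_attention_backend` helper from `ModelMixin`; the pipeline and checkpoint are illustrative, not part of this diff.

```python
import os

# Must be set before importing diffusers: the `DIFFUSERS_ENABLE_HUB_KERNELS`
# branch in attention_dispatch runs at import time.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import FluxPipeline  # illustrative; any model routed through the dispatcher works

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Route attention through the `kernels`-fetched flash-attn function
# registered under AttentionBackendName.FLASH_HUB.
pipe.transformer.set_attention_backend("flash_hub")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
image.save("cat.png")
```

Note that, unlike the pinned `fa3` entry, `kernels-community/flash-attn` is fetched with no `revision` kwarg, i.e. from its default branch.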
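The refactored helpers can also be smoke-tested in isolation. A sketch under the same assumptions; the function names come from the `kernels_utils.py` hunk above:

```python
from diffusers.utils.kernels_utils import get_fa_from_hub

# Returns None when `kernels` isn't installed; raises (after logging) if the
# Hub fetch itself fails.
fa = get_fa_from_hub()
if fa is not None:
    # This is the object bound to `flash_attn_func_hub` at import time.
    print(fa.flash_attn_func)
```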