From 4e69d42287dfef93de01160230d74a036f09fbef Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 26 Aug 2025 16:54:48 +0200
Subject: [PATCH] Lazily load FA3 Hub kernels and raise on fetch failure

Load the FlashAttention-3 kernel from kernels-community/vllm-flash-attn3
lazily through an `lru_cache`d helper instead of at import time, register
the custom op under `vllm_flash_attn3::flash_attn`, and make
`_get_fa3_from_hub` log and re-raise download failures instead of silently
returning `None`.
---
 src/diffusers/models/attention_dispatch.py | 46 ++++++++++------------
 src/diffusers/utils/kernels_utils.py       |  9 ++++-
 2 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 10dbc4e0ee..4db70f4626 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -17,6 +17,7 @@ import functools
 import inspect
 import math
 from enum import Enum
+from functools import lru_cache
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
@@ -39,8 +40,6 @@ from ..utils import (
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
 
 
-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
 _REQUIRED_FLASH_VERSION = "2.6.3"
 _REQUIRED_SAGE_VERSION = "2.1.1"
 _REQUIRED_FLEX_VERSION = "2.5.0"
@@ -70,20 +69,6 @@ else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
-if is_kernels_available():
-    from ..utils.kernels_utils import _get_fa3_from_hub
-
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    if flash_attn_interface_hub is not None:
-        flash_attn_3_hub_func = flash_attn_interface_hub.flash_attn_func
-        flash_attn_3_varlen_hub_func = flash_attn_interface_hub.flash_attn_varlen_func
-    else:
-        flash_attn_3_hub_func = None
-        flash_attn_3_varlen_hub_func = None
-else:
-    flash_attn_3_hub_func = None
-    flash_attn_3_varlen_hub_func = None
-
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -148,6 +133,7 @@ else:
     _custom_op = custom_op_no_op
     _register_fake = register_fake_no_op
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
@@ -169,7 +155,7 @@ class AttentionBackendName(str, Enum):
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
-    _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.
+    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.
 
     # PyTorch native
     FLEX = "flex"
@@ -224,6 +210,22 @@ class _AttentionBackendRegistry:
         return list(cls._backends.keys())
 
 
+@lru_cache(maxsize=None)
+def _load_fa3_hub():
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    fa3_hub = _get_fa3_from_hub()  # won't re-download if already present
+    if fa3_hub is None:
+        raise RuntimeError(
+            "Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
+        )
+    return fa3_hub
+
+
+def flash_attn_3_hub_func(*args, **kwargs):
+    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
+
+
 @contextlib.contextmanager
 def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
@@ -374,12 +376,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             raise RuntimeError(
                 f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
-        if flash_attn_3_hub_func is None:
-            raise RuntimeError(
-                "`flash_attn_3_hub_func` wasn't available. Please double if `kernels` was able to successfully pull the FA3 kernel from kernels-community/vllm-flash-attn3."
-            )
-    elif backend in [AttentionBackendName._FLASH_VARLEN_3_HUB]:
-        raise NotImplementedError
 
     elif backend in [
         AttentionBackendName.SAGE,
@@ -544,7 +540,7 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
     return torch.empty_like(query), query.new_empty(lse_shape)
 
 
-@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+@_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_hub(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -553,7 +549,7 @@ def _wrapped_flash_attn_3_hub(
     return out, lse
 
 
-@_register_fake("vllm_flash_attn3::_flash_attn_forward")
+@_register_fake("vllm_flash_attn3::flash_attn")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
index 346fc40c60..dddc9ede21 100644
--- a/src/diffusers/utils/kernels_utils.py
+++ b/src/diffusers/utils/kernels_utils.py
@@ -1,6 +1,10 @@
+from ..utils import get_logger
 from .import_utils import is_kernels_available
 
 
+logger = get_logger(__name__)
+
+
 _DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"
 
 
@@ -13,5 +17,6 @@ def _get_fa3_from_hub():
         try:
             flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
             return flash_attn_3_hub
-        except Exception:
-            return None
+        except Exception as e:
+            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+            raise
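
Usage sketch (not part of the patch; illustrative only). Assuming a CUDA machine
with the `kernels` package installed and network access to the
kernels-community/vllm-flash-attn3 repository, the Hub-backed FA3 path added
above could be exercised roughly as follows. The tensor layout
(batch_size, seq_len, num_heads, head_dim) mirrors the registered fake op, and
the sizes/dtype below are arbitrary placeholders:

    import torch

    from diffusers.models.attention_dispatch import (
        AttentionBackendName,
        attention_backend,
        flash_attn_3_hub_func,
    )

    # Dummy inputs in (batch_size, seq_len, num_heads, head_dim) layout;
    # FA3 kernels expect half-precision inputs on CUDA.
    query = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
    key = torch.randn_like(query)
    value = torch.randn_like(query)

    # The first call triggers _load_fa3_hub(): the kernel is fetched from the
    # Hub (or reused from the local cache) and memoized via lru_cache, so
    # subsequent calls skip the loading step. Mirroring the wrapped custom op,
    # the hub kernel returns the attention output and the log-sum-exp tensor.
    out, lse = flash_attn_3_hub_func(query, key, value)

    # Alternatively, scope the dispatcher to the Hub backend; attention calls
    # routed through diffusers' attention dispatch will then use it.
    with attention_backend(AttentionBackendName._FLASH_3_HUB):
        ...  # run a model / pipeline here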