sayakpaul
2025-08-27 13:36:57 +02:00
parent 548f56e428
commit 6e9f81fa03
3 changed files with 35 additions and 25 deletions


@@ -17,7 +17,6 @@ import functools
 import inspect
 import math
 from enum import Enum
-from functools import lru_cache
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import torch
@@ -145,6 +144,9 @@ _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
 _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
 _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
+
+flash_attn_3_hub_func = None
+__fa3_hub_loaded = False


 class AttentionBackendName(str, Enum):
     # EAGER = "eager"
@@ -210,20 +212,20 @@ class _AttentionBackendRegistry:
         return list(cls._backends.keys())


-@lru_cache(maxsize=None)
-def _load_fa3_hub():
+def _ensure_fa3_hub_loaded():
+    global __fa3_hub_loaded
+    if __fa3_hub_loaded:
+        return
     from ..utils.kernels_utils import _get_fa3_from_hub

-    fa3_hub = _get_fa3_from_hub()  # won't re-download if already present
-    if fa3_hub is None:
+    fa3_hub_module = _get_fa3_from_hub()  # doesn't retrigger download if already available.
+    if fa3_hub_module is None:
         raise RuntimeError(
             "Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
         )
-    return fa3_hub
-
-
-def flash_attn_3_hub_func(*args, **kwargs):
-    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
+    global flash_attn_3_hub_func
+    flash_attn_3_hub_func = fa3_hub_module.flash_attn_func
+    __fa3_hub_loaded = True


 @contextlib.contextmanager
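The change above replaces an @lru_cache-decorated loader that returned the Hub module with an explicit load-once helper that binds flash_attn_3_hub_func at module scope, so callers hit the resolved function directly instead of going through a cached wrapper on every call. A minimal, self-contained sketch of that pattern follows; _ensure_heavy_loaded, _heavy_func, and the math.sqrt stand-in are hypothetical illustrations, not diffusers APIs.

import importlib

_heavy_func = None
_heavy_loaded = False


def _ensure_heavy_loaded():
    # Load-once guard: after the first successful call this becomes a cheap no-op.
    global _heavy_func, _heavy_loaded
    if _heavy_loaded:
        return
    module = importlib.import_module("math")  # stands in for the expensive Hub fetch
    _heavy_func = module.sqrt  # bind the callable at module scope, like flash_attn_3_hub_func
    _heavy_loaded = True


_ensure_heavy_loaded()
assert _heavy_func(16.0) == 4.0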
@@ -540,20 +542,20 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
     return torch.empty_like(query), query.new_empty(lse_shape)


-@_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
-def _wrapped_flash_attn_3_hub(
-    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    out, lse = flash_attn_3_hub_func(query, key, value)
-    lse = lse.permute(0, 2, 1)
-    return out, lse
+# @_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
+# def _wrapped_flash_attn_3_hub(
+#     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+# ) -> Tuple[torch.Tensor, torch.Tensor]:
+#     out, lse = flash_attn_3_hub_func(query, key, value)
+#     lse = lse.permute(0, 2, 1)
+#     return out, lse


-@_register_fake("vllm_flash_attn3::flash_attn")
-def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    batch_size, seq_len, num_heads, head_dim = query.shape
-    lse_shape = (batch_size, seq_len, num_heads)
-    return torch.empty_like(query), query.new_empty(lse_shape)
+# @_register_fake("vllm_flash_attn3::flash_attn")
+# def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+#     batch_size, seq_len, num_heads, head_dim = query.shape
+#     lse_shape = (batch_size, seq_len, num_heads)
+#     return torch.empty_like(query), query.new_empty(lse_shape)


 # ===== Attention backends =====
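Here the locally registered vllm_flash_attn3::flash_attn custom op and its fake (meta) implementation are commented out rather than deleted; together with the switch to the "fake-ops" kernel revision further down, this suggests the Hub kernel now ships those registrations itself. For orientation only, a generic sketch of the same custom-op-plus-fake pattern using stock torch.library (PyTorch 2.4+), not the diffusers _custom_op/_register_fake wrappers; the op name "mylib::toy_attn" and the SDPA body are made-up stand-ins.

import torch


@torch.library.custom_op("mylib::toy_attn", mutates_args=())
def toy_attn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Eager implementation; a real backend would call the fused attention kernel here.
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)


@toy_attn.register_fake
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Fake/meta implementation: shapes and dtypes only, so torch.compile can trace
    # through the op without running the kernel.
    return torch.empty_like(query)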


@@ -595,7 +595,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 attention as backend.
         """
         from .attention import AttentionModuleMixin
-        from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
+        from .attention_dispatch import (
+            AttentionBackendName,
+            _check_attention_backend_requirements,
+            _ensure_fa3_hub_loaded,
+        )

         # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
@@ -608,6 +612,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
         backend = AttentionBackendName(backend)
         _check_attention_backend_requirements(backend)
+        # TODO: clean this once it gets exhausted.
+        if "_flash_3_hub" in backend:
+            # We ensure it's preloaded to reduce overhead and also to avoid compilation errors.
+            _ensure_fa3_hub_loaded()

         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
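With that, preloading happens as part of set_attention_backend rather than at the first attention call. A rough usage sketch, assuming a CUDA machine with the kernels package installed; the FLUX checkpoint is only an illustration of a model exposing this mixin.

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

# Validates the backend name, checks its requirements, and (for the Hub-backed FA3
# backend) calls _ensure_fa3_hub_loaded() so the kernel is resolved once up front.
transformer.set_attention_backend("_flash_3_hub")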


@@ -5,7 +5,7 @@ from .import_utils import is_kernels_available
 logger = get_logger(__name__)

-_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"
+_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"


 def _get_fa3_from_hub():
@@ -15,7 +15,7 @@ def _get_fa3_from_hub():
     from kernels import get_kernel

     try:
-        flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
+        flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops")
         return flash_attn_3_hub
     except Exception as e:
         logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
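get_kernel pulls a prebuilt kernel from the Hub (cached after the first download) and exposes its functions as attributes of the returned module, which is what _ensure_fa3_hub_loaded binds to flash_attn_3_hub_func. A hedged sketch of calling the kernel directly, outside diffusers; the (batch, seq_len, num_heads, head_dim) layout and the (out, lse) return follow the wrapper and fake registration above, but the exact flash_attn_func signature may differ.

import torch
from kernels import get_kernel

# Requires a CUDA GPU and the `kernels` package; re-uses the locally cached wheel if present.
fa3 = get_kernel("kernels-community/flash-attn3", revision="fake-ops")

q = torch.randn(1, 1024, 8, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out, lse = fa3.flash_attn_func(q, k, v)  # out: attention output, lse: log-sum-exp per head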