mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
up
This commit is contained in:
@@ -17,7 +17,6 @@ import functools
|
||||
import inspect
|
||||
import math
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -145,6 +144,9 @@ _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
|
||||
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
|
||||
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
|
||||
|
||||
flash_attn_3_hub_func = None
|
||||
__fa3_hub_loaded = False
|
||||
|
||||
|
||||
class AttentionBackendName(str, Enum):
|
||||
# EAGER = "eager"
|
||||
@@ -210,20 +212,20 @@ class _AttentionBackendRegistry:
|
||||
return list(cls._backends.keys())
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def _load_fa3_hub():
|
||||
def _ensure_fa3_hub_loaded():
|
||||
global __fa3_hub_loaded
|
||||
if __fa3_hub_loaded:
|
||||
return
|
||||
from ..utils.kernels_utils import _get_fa3_from_hub
|
||||
|
||||
fa3_hub = _get_fa3_from_hub() # won't re-download if already present
|
||||
if fa3_hub is None:
|
||||
fa3_hub_module = _get_fa3_from_hub() # doesn't retrigger download if already available.
|
||||
if fa3_hub_module is None:
|
||||
raise RuntimeError(
|
||||
"Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
|
||||
)
|
||||
return fa3_hub
|
||||
|
||||
|
||||
def flash_attn_3_hub_func(*args, **kwargs):
|
||||
return _load_fa3_hub().flash_attn_func(*args, **kwargs)
|
||||
global flash_attn_3_hub_func
|
||||
flash_attn_3_hub_func = fa3_hub_module.flash_attn_func
|
||||
__fa3_hub_loaded = True
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -540,20 +542,20 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
|
||||
return torch.empty_like(query), query.new_empty(lse_shape)
|
||||
|
||||
|
||||
@_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
|
||||
def _wrapped_flash_attn_3_hub(
|
||||
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
out, lse = flash_attn_3_hub_func(query, key, value)
|
||||
lse = lse.permute(0, 2, 1)
|
||||
return out, lse
|
||||
# @_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
|
||||
# def _wrapped_flash_attn_3_hub(
|
||||
# query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
||||
# ) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# out, lse = flash_attn_3_hub_func(query, key, value)
|
||||
# lse = lse.permute(0, 2, 1)
|
||||
# return out, lse
|
||||
|
||||
|
||||
@_register_fake("vllm_flash_attn3::flash_attn")
|
||||
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, seq_len, num_heads, head_dim = query.shape
|
||||
lse_shape = (batch_size, seq_len, num_heads)
|
||||
return torch.empty_like(query), query.new_empty(lse_shape)
|
||||
# @_register_fake("vllm_flash_attn3::flash_attn")
|
||||
# def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# batch_size, seq_len, num_heads, head_dim = query.shape
|
||||
# lse_shape = (batch_size, seq_len, num_heads)
|
||||
# return torch.empty_like(query), query.new_empty(lse_shape)
|
||||
|
||||
|
||||
# ===== Attention backends =====
|
||||
|
||||
@@ -595,7 +595,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
attention as backend.
|
||||
"""
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
|
||||
from .attention_dispatch import (
|
||||
AttentionBackendName,
|
||||
_check_attention_backend_requirements,
|
||||
_ensure_fa3_hub_loaded,
|
||||
)
|
||||
|
||||
# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
@@ -608,6 +612,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
|
||||
backend = AttentionBackendName(backend)
|
||||
_check_attention_backend_requirements(backend)
|
||||
# TODO: clean this once it gets exhausted.
|
||||
if "_flash_3_hub" in backend:
|
||||
# We ensure it's preloaded to reduce overhead and also to avoid compilation errors.
|
||||
_ensure_fa3_hub_loaded()
|
||||
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
for module in self.modules():
|
||||
|
||||
@@ -5,7 +5,7 @@ from .import_utils import is_kernels_available
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"
|
||||
_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
|
||||
|
||||
|
||||
def _get_fa3_from_hub():
|
||||
@@ -15,7 +15,7 @@ def _get_fa3_from_hub():
|
||||
from kernels import get_kernel
|
||||
|
||||
try:
|
||||
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
|
||||
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops")
|
||||
return flash_attn_3_hub
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
|
||||
|
||||
Reference in New Issue
Block a user