mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
sayakpaul committed 2025-08-26 12:02:49 +02:00
parent bc40971210
commit 87d08798de
2 changed files with 4 additions and 4 deletions

src/diffusers/models/attention_dispatch.py

@@ -540,7 +540,7 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return torch.empty_like(query), query.new_empty(lse_shape)


-@_custom_op("flash_attn_3_hub_func", mutates_args=(), device_types="cuda")
+@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_hub(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -549,7 +549,7 @@ def _wrapped_flash_attn_3_hub(
     return out, lse


-@_register_fake("flash_attn_3_hub_func")
+@_register_fake("vllm_flash_attn3::_flash_attn_forward")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)

src/diffusers/utils/kernels_utils.py

@@ -11,7 +11,7 @@ def _get_fa3_from_hub():
         from kernels import get_kernel

         try:
-            vllm_flash_attn3 = get_kernel(_DEFAULT_HUB_ID_FA3)
-            return vllm_flash_attn3
+            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
+            return flash_attn_3_hub
         except Exception as e:
             raise e
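
For reference, `kernels.get_kernel` downloads a prebuilt kernel from the Hugging Face Hub and returns a module-like object whose entry points are exposed as attributes. A usage sketch under stated assumptions: the repo id below is a guess (this hunk never shows what `_DEFAULT_HUB_ID_FA3` resolves to), and the attribute name is hypothetical:

import torch
from kernels import get_kernel  # pip install kernels

# Assumed repo id; the value of _DEFAULT_HUB_ID_FA3 is not shown in this diff.
flash_attn_3_hub = get_kernel("kernels-community/vllm-flash-attn3")

q = k = v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
# Hypothetical attribute name; check the kernel repo for the exact API.
out = flash_attn_3_hub.flash_attn_func(q, k, v)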