mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
up
This commit is contained in:
@@ -540,7 +540,7 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
|
||||
return torch.empty_like(query), query.new_empty(lse_shape)
|
||||
|
||||
|
||||
@_custom_op("flash_attn_3_hub_func", mutates_args=(), device_types="cuda")
|
||||
@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
|
||||
def _wrapped_flash_attn_3_hub(
|
||||
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -549,7 +549,7 @@ def _wrapped_flash_attn_3_hub(
|
||||
return out, lse
|
||||
|
||||
|
||||
@_register_fake("flash_attn_3_hub_func")
|
||||
@_register_fake("vllm_flash_attn3::_flash_attn_forward")
|
||||
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, seq_len, num_heads, head_dim = query.shape
|
||||
lse_shape = (batch_size, seq_len, num_heads)
|
||||
|
||||
@@ -11,7 +11,7 @@ def _get_fa3_from_hub():
|
||||
from kernels import get_kernel
|
||||
|
||||
try:
|
||||
vllm_flash_attn3 = get_kernel(_DEFAULT_HUB_ID_FA3)
|
||||
return vllm_flash_attn3
|
||||
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
|
||||
return flash_attn_3_hub
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
Reference in New Issue
Block a user