diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 7daf32ba1e..d0a8127507 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -540,7 +540,7 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
     return torch.empty_like(query), query.new_empty(lse_shape)
 
 
-@_custom_op("flash_attn_3_hub_func", mutates_args=(), device_types="cuda")
+@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_hub(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -549,7 +549,7 @@ def _wrapped_flash_attn_3_hub(
     return out, lse
 
 
-@_register_fake("flash_attn_3_hub_func")
+@_register_fake("vllm_flash_attn3::_flash_attn_forward")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
index dc4d0fa90c..ba1c5efcbe 100644
--- a/src/diffusers/utils/kernels_utils.py
+++ b/src/diffusers/utils/kernels_utils.py
@@ -11,7 +11,7 @@ def _get_fa3_from_hub():
     from kernels import get_kernel
 
     try:
-        vllm_flash_attn3 = get_kernel(_DEFAULT_HUB_ID_FA3)
-        return vllm_flash_attn3
+        flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
+        return flash_attn_3_hub
     except Exception as e:
         raise e
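
The rename above gives the wrapped Flash Attention 3 hub op the namespaced `namespace::op_name` form that `torch.library.custom_op` and `torch.library.register_fake` expect. Below is a minimal, self-contained sketch of that registration pattern, assuming PyTorch >= 2.4; the `demo_fa3` namespace, function names, and placeholder bodies are illustrative assumptions, not the diffusers implementation.

```python
from typing import Tuple

import torch


# Hypothetical namespace/op name for illustration only.
@torch.library.custom_op("demo_fa3::flash_attn_forward", mutates_args=(), device_types="cuda")
def flash_attn_forward(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Placeholder body standing in for the real FA3 kernel call; only the output
    # shapes (attention output and per-head log-sum-exp) matter for this sketch.
    out = torch.empty_like(query)
    lse = query.new_empty(query.shape[:-1])
    return out, lse


@torch.library.register_fake("demo_fa3::flash_attn_forward")
def _(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Fake (meta) implementation: only describes output metadata so the op can be
    # traced by torch.compile without running the CUDA kernel.
    batch_size, seq_len, num_heads, head_dim = query.shape
    lse_shape = (batch_size, seq_len, num_heads)
    return torch.empty_like(query), query.new_empty(lse_shape)
```

The namespaced identifier is what registers the op in PyTorch's operator registry, which is why a bare name like `flash_attn_3_hub_func` is replaced with a `namespace::name` string in the diff.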