From 87d08798ded03cc6f526287ee2f211a8a9839e77 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 26 Aug 2025 12:02:49 +0200
Subject: [PATCH] up

---
 src/diffusers/models/attention_dispatch.py | 4 ++--
 src/diffusers/utils/kernels_utils.py       | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 7daf32ba1e..d0a8127507 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -540,7 +540,7 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
     return torch.empty_like(query), query.new_empty(lse_shape)
 
 
-@_custom_op("flash_attn_3_hub_func", mutates_args=(), device_types="cuda")
+@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_hub(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -549,7 +549,7 @@ def _wrapped_flash_attn_3_hub(
     return out, lse
 
 
-@_register_fake("flash_attn_3_hub_func")
+@_register_fake("vllm_flash_attn3::_flash_attn_forward")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
index dc4d0fa90c..ba1c5efcbe 100644
--- a/src/diffusers/utils/kernels_utils.py
+++ b/src/diffusers/utils/kernels_utils.py
@@ -11,7 +11,7 @@ def _get_fa3_from_hub():
     from kernels import get_kernel
 
     try:
-        vllm_flash_attn3 = get_kernel(_DEFAULT_HUB_ID_FA3)
-        return vllm_flash_attn3
+        flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
+        return flash_attn_3_hub
     except Exception as e:
         raise e
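
Side note on the rename (not part of the patch): torch.library.custom_op expects op names in
"namespace::op_name" form, and registering under "vllm_flash_attn3::_flash_attn_forward" makes
the op reachable as torch.ops.vllm_flash_attn3._flash_attn_forward. The sketch below mirrors
that registration pattern only; the "vllm_flash_attn3_demo" namespace, the CPU device type, and
the placeholder attention math are assumptions for illustration, not diffusers code.

# Minimal sketch (not from the patch): registering and calling a namespaced custom op
# the same way attention_dispatch.py does. Names and math here are illustrative stand-ins.
from typing import Tuple

import torch


@torch.library.custom_op("vllm_flash_attn3_demo::_flash_attn_forward", mutates_args=(), device_types="cpu")
def _flash_attn_forward_demo(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for the real FA3 kernel: plain SDPA plus a dummy log-sum-exp.
    # Inputs are (batch, seq_len, num_heads, head_dim), as in the diff above.
    out = (
        torch.nn.functional.scaled_dot_product_attention(
            query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
        )
        .transpose(1, 2)
        .contiguous()
    )
    lse = query.new_zeros(query.shape[0], query.shape[1], query.shape[2])
    return out, lse


@torch.library.register_fake("vllm_flash_attn3_demo::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Shape-only implementation used during tracing/compilation, mirroring the
    # _register_fake branch touched by the patch.
    batch_size, seq_len, num_heads, head_dim = query.shape
    return torch.empty_like(query), query.new_empty((batch_size, seq_len, num_heads))


if __name__ == "__main__":
    q = k = v = torch.randn(1, 16, 4, 8)
    # The namespaced name is what makes the op addressable through torch.ops.
    out, lse = torch.ops.vllm_flash_attn3_demo._flash_attn_forward(q, k, v)
    print(out.shape, lse.shape)  # torch.Size([1, 16, 4, 8]) torch.Size([1, 16, 4])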