From 7a8f85b0473eb82db9114537c847a71fd8ab6f5e Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 9 Dec 2025 14:59:01 +0530
Subject: [PATCH] up

---
 src/diffusers/models/attention_dispatch.py | 35 +++++++++++++++-------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index ffda254976..433815d7ed 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -277,8 +277,8 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
         repo_id="kernels-community/flash-attn2",
         function_attr="flash_attn_func",
         revision=None,
-        wrapped_forward_attr="_wrapped_flash_attn_forward",
-        wrapped_backward_attr="_wrapped_flash_attn_backward",
+        wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
+        wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
     ),
     AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
@@ -602,27 +602,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 # ===== Helpers for downloading kernels =====
 
 
+def _resolve_kernel_attr(module, attr_path: str):
+    target = module
+    for attr in attr_path.split("."):
+        if not hasattr(target, attr):
+            raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
+        target = getattr(target, attr)
+    return target
+
+
 def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
     if backend not in _HUB_KERNELS_REGISTRY:
         return
 
     config = _HUB_KERNELS_REGISTRY[backend]
-    if config.kernel_fn is not None:
+    needs_kernel = config.kernel_fn is None
+    needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
+    needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
+
+    if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
         return
 
     try:
         from kernels import get_kernel
 
         kernel_module = get_kernel(config.repo_id, revision=config.revision)
-        kernel_func = getattr(kernel_module, config.function_attr)
 
-        # Cache the downloaded kernel function in the config object
-        config.kernel_fn = kernel_func
+        if needs_kernel:
+            config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
 
-        if config.wrapped_forward_attr is not None and config.wrapped_forward_attr is not None:
-            wrapped_forward_fn = getattr(kernel_module, config.wrapped_forward_attr)
-            wrapped_backward_fn = getattr(kernel_module, config.wrapped_backward_attr)
-            config.wrapped_forward_fn = wrapped_forward_fn
-            config.wrapped_backward_fn = wrapped_backward_fn
+        if needs_wrapped_forward:
+            config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
+
+        if needs_wrapped_backward:
+            config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
     except Exception as e:
         logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -1307,6 +1319,7 @@ def _sage_attention_hub_forward_op(
 
     return (out, lse) if return_lse else out
 
+
 # ===== Context parallel =====
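
Below is a minimal, self-contained sketch (not part of the patch itself) of the dotted-attribute resolution this change introduces: attribute paths such as "flash_attn_interface._wrapped_flash_attn_forward" are now walked one segment at a time, so wrapped forward/backward functions that live on a submodule of a downloaded kernel can be reached. resolve_attr mirrors the patch's _resolve_kernel_attr; fake_kernel_module is a hypothetical stand-in for the object returned by kernels.get_kernel and is not a real API.

from types import SimpleNamespace


def resolve_attr(module, attr_path: str):
    # Walk a dotted path ("a.b.c") one attribute at a time, mirroring _resolve_kernel_attr.
    target = module
    for attr in attr_path.split("."):
        if not hasattr(target, attr):
            raise AttributeError(f"Attribute path '{attr_path}' is not defined on the kernel module.")
        target = getattr(target, attr)
    return target


# Hypothetical stand-in for the object returned by kernels.get_kernel(); not a real API.
fake_kernel_module = SimpleNamespace(
    flash_attn_func=lambda *args, **kwargs: "top-level kernel function",
    flash_attn_interface=SimpleNamespace(
        _wrapped_flash_attn_forward=lambda *args, **kwargs: "wrapped forward",
        _wrapped_flash_attn_backward=lambda *args, **kwargs: "wrapped backward",
    ),
)

print(resolve_attr(fake_kernel_module, "flash_attn_func")())  # plain attribute still resolves
print(resolve_attr(fake_kernel_module, "flash_attn_interface._wrapped_flash_attn_forward")())  # dotted path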