From 41a6e86faf6fd1002e69a2cd813c286fe3ca591c Mon Sep 17 00:00:00 2001
From: dxqb <183307934+dxqb@users.noreply.github.com>
Date: Tue, 6 Jan 2026 18:22:12 +0100
Subject: [PATCH] Check for attention mask in backends that don't support it
 (#12892)

* check attention mask

* Apply style fixes

* bugfix

---------

Co-authored-by: github-actions[bot]
Co-authored-by: Sayak Paul
---
 src/diffusers/models/attention_dispatch.py | 47 ++++++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 310c44457c..15516ed2ed 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1420,6 +1420,7 @@ def _flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
@@ -1427,6 +1428,9 @@
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
     lse = None
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 2.")
+
     if _parallel_config is None:
         out = flash_attn_func(
             q=query,
@@ -1469,6 +1473,7 @@ def _flash_attention_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
@@ -1476,6 +1481,9 @@
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
     lse = None
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 2.")
+
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
     out = func(
         q=query,
@@ -1612,11 +1620,15 @@ def _flash_attention_3(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 3.")
+
     out, lse = _wrapped_flash_attn_3(
         q=query,
         k=key,
@@ -1636,6 +1648,7 @@ def _flash_attention_3_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     is_causal: bool = False,
     window_size: Tuple[int, int] = (-1, -1),
@@ -1646,6 +1659,8 @@
 ) -> torch.Tensor:
     if _parallel_config:
         raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 3.")
 
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
     out = func(
@@ -1785,12 +1800,16 @@ def _aiter_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for aiter attention")
+
     if not return_lse and torch.is_grad_enabled():
         # aiter requires return_lse=True by assertion when gradients are enabled.
         out, lse, *_ = aiter_flash_attn_func(
@@ -2028,6 +2047,7 @@ def _native_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
@@ -2035,6 +2055,9 @@
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for native flash attention.")
+
     lse = None
     if _parallel_config is None and not return_lse:
         query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2113,11 +2136,14 @@ def _native_npu_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for NPU attention")
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
     query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
@@ -2148,10 +2174,13 @@ def _native_xla_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for XLA attention")
     if return_lse:
         raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2175,11 +2204,14 @@ def _sage_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     lse = None
     if _parallel_config is None:
         out = sageattn(
@@ -2223,11 +2255,14 @@ def _sage_attention_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     lse = None
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
     if _parallel_config is None:
@@ -2309,11 +2344,14 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp8_cuda(
         q=query,
         k=key,
@@ -2333,11 +2371,14 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp8_cuda_sm90(
         q=query,
         k=key,
@@ -2357,11 +2398,14 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp16_cuda(
         q=query,
         k=key,
@@ -2381,11 +2425,14 @@ def _sage_qk_int8_pv_fp16_triton_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp16_triton(
         q=query,
         k=key,
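Every hunk above applies the same pattern: the backend wrapper accepts an optional attn_mask argument so all backends share one signature, then fails fast with a ValueError when a mask is passed to a kernel that cannot apply one, instead of silently ignoring it. The sketch below illustrates that guard in isolation; it is not code from the patch, and _mask_unaware_attention plus the torch SDPA fallback are placeholder names, not diffusers or flash-attn symbols.

    from typing import Optional

    import torch
    import torch.nn.functional as F


    def _mask_unaware_attention(
        query: torch.Tensor,  # (batch, seq_len, heads, head_dim), the layout the flash-attn wrappers use
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Guard pattern from the patch: reject the mask up front rather than dropping it.
        if attn_mask is not None:
            raise ValueError("`attn_mask` is not supported for this backend.")
        # Placeholder kernel: permute to (batch, heads, seq_len, head_dim) and call torch SDPA.
        q, k, v = (x.permute(0, 2, 1, 3) for x in (query, key, value))
        out = F.scaled_dot_product_attention(q, k, v)
        return out.permute(0, 2, 1, 3)

With this guard in place, a caller that relies on masking (for example, padded text conditioning) gets an immediate, descriptive error when one of these backends is selected and can switch to a mask-capable path such as plain PyTorch SDPA, rather than silently computing unmasked attention.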