From 093cd3f040ee4f44908df8e1b441954f3f25c214 Mon Sep 17 00:00:00 2001
From: YiYi Xu
Date: Tue, 11 Nov 2025 19:16:13 -1000
Subject: [PATCH] fix dispatch_attention_fn check (#12636)

* fix

* fix
---
 src/diffusers/models/attention_dispatch.py | 16 +++++++++++-----
 src/diffusers/utils/constants.py           |  2 +-
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 6ecf97701f..92a4a6a599 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -383,12 +383,18 @@ def _check_shape(
     attn_mask: Optional[torch.Tensor] = None,
     **kwargs,
 ) -> None:
+    # Expected shapes:
+    #   query: (batch_size, seq_len_q, num_heads, head_dim)
+    #   key: (batch_size, seq_len_kv, num_heads, head_dim)
+    #   value: (batch_size, seq_len_kv, num_heads, head_dim)
+    #   attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv)
+    #     or (batch_size, num_heads, seq_len_q, seq_len_kv)
     if query.shape[-1] != key.shape[-1]:
-        raise ValueError("Query and key must have the same last dimension.")
-    if query.shape[-2] != value.shape[-2]:
-        raise ValueError("Query and value must have the same second to last dimension.")
-    if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
-        raise ValueError("Attention mask must match the key's second to last dimension.")
+        raise ValueError("Query and key must have the same head dimension.")
+    if key.shape[-3] != value.shape[-3]:
+        raise ValueError("Key and value must have the same sequence length.")
+    if attn_mask is not None and attn_mask.shape[-1] != key.shape[-3]:
+        raise ValueError("Attention mask must match the key's sequence length.")
 
 
 # ===== Helper functions =====
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 42a53e1810..a18f28606b 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -42,7 +42,7 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
 DIFFUSERS_REQUEST_TIMEOUT = 60
 DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
-DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
+DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_VARS_TRUE_VALUES
 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
 DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
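
The corrected checks in attention_dispatch.py assume the (batch_size, seq_len, num_heads, head_dim) tensor layout documented in the new comments. Below is a minimal standalone sketch of why the old indices were wrong in that layout; the tensor sizes are illustrative, not taken from the dispatcher:

import torch

batch, heads, head_dim = 2, 8, 64
seq_q, seq_kv = 16, 32

# Tensors in the (batch_size, seq_len, num_heads, head_dim) layout the checks expect.
query = torch.randn(batch, seq_q, heads, head_dim)
key = torch.randn(batch, seq_kv, heads, head_dim)
value = torch.randn(batch, seq_kv, heads, head_dim)
attn_mask = torch.ones(batch, seq_q, seq_kv, dtype=torch.bool)

assert query.shape[-1] == key.shape[-1]      # head_dim matches
assert key.shape[-3] == value.shape[-3]      # seq_len_kv lives at dim -3, not -2
assert attn_mask.shape[-1] == key.shape[-3]  # mask's last dim is seq_len_kv

# Under the old check, attn_mask.shape[-1] (32) was compared against
# key.shape[-2], which in this layout is num_heads (8), so this perfectly
# valid mask would have raised a ValueError.
assert key.shape[-2] == heads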
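
The constants.py change makes the env-var parsing case-insensitive: ENV_VARS_TRUE_VALUES holds uppercase strings, so without .upper() a setting such as DIFFUSERS_ATTN_CHECKS=true never enabled the checks. A small before/after sketch (the set below mirrors diffusers.utils.constants.ENV_VARS_TRUE_VALUES; the value assigned to the variable is illustrative):

import os

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}

os.environ["DIFFUSERS_ATTN_CHECKS"] = "true"  # lowercase, as a user might set it

# Before the fix: lowercase "true" is not in the uppercase set, so checks stayed off.
print(os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES)  # False

# After the fix: normalizing with .upper() accepts any casing of TRUE/YES/ON/1.
print(os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_VARS_TRUE_VALUES)  # True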