From 26475082cb9db5da659378c978496999d338b960 Mon Sep 17 00:00:00 2001
From: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Date: Thu, 16 Oct 2025 09:19:30 -0700
Subject: [PATCH] [docs] Attention checks (#12486)

* checks

* feedback

---------

Co-authored-by: Sayak Paul
---
 .../en/optimization/attention_backends.md | 40 ++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/docs/source/en/optimization/attention_backends.md b/docs/source/en/optimization/attention_backends.md
index e603878a63..8be2c06030 100644
--- a/docs/source/en/optimization/attention_backends.md
+++ b/docs/source/en/optimization/attention_backends.md
@@ -81,6 +81,46 @@ with attention_backend("_flash_3_hub"):
 > [!TIP]
 > Most attention backends support `torch.compile` without graph breaks and can be used to further speed up inference.
 
+## Checks
+
+The attention dispatcher includes debugging checks that catch common input errors before they reach an attention backend.
+
+1. Device checks verify that query, key, and value tensors live on the same device.
+2. Data type checks confirm tensors have matching dtypes and use either bfloat16 or float16.
+3. Shape checks validate tensor dimensions and prevent mixing attention masks with causal flags.
+
+The checks add overhead to every attention operation, so they're disabled by default. Enable them by setting the `DIFFUSERS_ATTN_CHECKS` environment variable.
+
+```bash
+export DIFFUSERS_ATTN_CHECKS=yes
+```
+
+The checks now run before every attention operation.
+
+```py
+import torch
+from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn
+
+query = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
+key = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
+value = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
+
+try:
+    with attention_backend("flash"):
+        output = dispatch_attention_fn(query, key, value)
+    print("✓ Flash Attention works with checks enabled")
+except Exception as e:
+    print(f"✗ Flash Attention failed: {e}")
+```
+
+You can also enable the checks by configuring the registry directly.
+
+```py
+from diffusers.models.attention_dispatch import _AttentionBackendRegistry
+
+_AttentionBackendRegistry._checks_enabled = True
+```
+
 ## Available backends
 
 Refer to the table below for a complete list of available attention backends and their variants.
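
Below is a minimal sketch, outside the patch itself, of what the device and data type checks listed in the new section are intended to catch. It assumes `attention_backend` and `dispatch_attention_fn` are importable from `diffusers.models.attention_dispatch` (the module the new docs already import `_AttentionBackendRegistry` from) and that a failed check raises an exception; the exact exception type and message are not specified in the docs.

```py
import torch

from diffusers.models.attention_dispatch import (
    _AttentionBackendRegistry,
    attention_backend,
    dispatch_attention_fn,
)

# Turn the checks on programmatically, equivalent to DIFFUSERS_ATTN_CHECKS=yes.
_AttentionBackendRegistry._checks_enabled = True

# The query dtype is float32 on purpose: it neither matches key/value nor is
# bfloat16/float16, so the data type check should reject these inputs.
query = torch.randn(1, 10, 8, 64, dtype=torch.float32, device="cuda")
key = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
value = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")

try:
    with attention_backend("flash"):
        dispatch_attention_fn(query, key, value)
except Exception as e:  # assumed: a failed check raises instead of silently passing
    print(f"✗ Check rejected the inputs: {e}")
```

Toggling `_AttentionBackendRegistry._checks_enabled` mirrors the registry snippet in the new section, so the checks can be switched on from a test or notebook without setting the environment variable.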