diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 6357bdad0f..8d149cb7a2 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -899,6 +899,7 @@ def _templated_context_parallel_attention( @_AttentionBackendRegistry.register( AttentionBackendName.FLASH, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _flash_attention( query: torch.Tensor,