From ebf891a254de807bd88aba8a09792c82df74f459 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 16 Jan 2026 21:29:42 +0530
Subject: [PATCH] [core] gracefully error out when attn-backend x cp combo isn't supported. (#12832)

* gracefully error out when attn-backend x cp combo isn't supported.

* Revert "gracefully error out when attn-backend x cp combo isn't supported."

This reverts commit c8abb5d7c01ca6a7c0bf82c27c91a326155a5e43.

* gracefully error out when attn-backend x cp combo isn't supported.

* up

* address PR feedback.

* up

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: Dhruv Nair

* dot.

---------

Co-authored-by: Dhruv Nair
---
 src/diffusers/models/attention_dispatch.py |  9 ++++++--
 src/diffusers/models/modeling_utils.py     | 25 ++++++++++++++++++++--
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index f086c2d425..61c478b03c 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -235,6 +235,10 @@ class _AttentionBackendRegistry:
     def get_active_backend(cls):
         return cls._active_backend, cls._backends[cls._active_backend]
 
+    @classmethod
+    def set_active_backend(cls, backend: str):
+        cls._active_backend = backend
+
     @classmethod
     def list_backends(cls):
         return list(cls._backends.keys())
@@ -294,12 +298,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
     _maybe_download_kernel_for_backend(backend)
 
     old_backend = _AttentionBackendRegistry._active_backend
-    _AttentionBackendRegistry._active_backend = backend
+    _AttentionBackendRegistry.set_active_backend(backend)
 
     try:
         yield
     finally:
-        _AttentionBackendRegistry._active_backend = old_backend
+        _AttentionBackendRegistry.set_active_backend(old_backend)
 
 
 def dispatch_attention_fn(
@@ -348,6 +352,7 @@ def dispatch_attention_fn(
             check(**kwargs)
 
     kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
+
     return backend_fn(**kwargs)
 
 
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 0ccd4c480e..63e50af617 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -599,6 +599,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         from .attention import AttentionModuleMixin
         from .attention_dispatch import (
             AttentionBackendName,
+            _AttentionBackendRegistry,
             _check_attention_backend_requirements,
             _maybe_download_kernel_for_backend,
         )
@@ -607,6 +608,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         from .attention_processor import Attention, MochiAttention
 
         logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
+        parallel_config_set = False
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if getattr(processor, "_parallel_config", None) is not None:
+                parallel_config_set = True
+                break
 
         backend = backend.lower()
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
@@ -614,10 +625,17 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
 
         backend = AttentionBackendName(backend)
+        if parallel_config_set and not _AttentionBackendRegistry._is_context_parallel_available(backend):
+            compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
+            raise ValueError(
+                f"Context parallelism is enabled but current attention backend '{backend.value}' "
+                f"does not support context parallelism. "
+                f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()`."
+            )
+
         _check_attention_backend_requirements(backend)
         _maybe_download_kernel_for_backend(backend)
 
-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue
@@ -626,6 +644,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 continue
             processor._attention_backend = backend
 
+        # Important to set the active backend so that it propagates gracefully throughout.
+        _AttentionBackendRegistry.set_active_backend(backend)
+
     def reset_attention_backend(self) -> None:
         """
         Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
@@ -1538,7 +1559,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
                     f"is using backend '{attention_backend.value}' which does not support context parallelism. "
                     f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
-                    f"calling `enable_parallelism()`."
+                    f"calling `model.enable_parallelism()`."
                 )
 
             # All modules use the same attention processor and backend. We don't need to
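
Note (not part of the patch): a minimal, illustrative repro of the behavior this change guards against. The model class, checkpoint id, parallelism degree, and the choice of "sage" as a context-parallel-incompatible backend are assumptions for the sketch; only `set_attention_backend`, `enable_parallelism`, and `ContextParallelConfig` are taken from the diffusers API referenced in the error messages above.

    # Sketch only; launch with e.g. `torchrun --nproc-per-node=2 repro.py`.
    import torch
    import torch.distributed as dist

    from diffusers import ContextParallelConfig, QwenImageTransformer2DModel  # placeholder model choice

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    transformer = QwenImageTransformer2DModel.from_pretrained(
        "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16
    ).cuda()

    # Enable context parallelism first.
    transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=dist.get_world_size()))

    # Then request an attention backend. If the backend does not support context parallelism
    # ("sage" is assumed to be such a backend here), `set_attention_backend` now raises a
    # ValueError listing the compatible backends instead of failing later during attention dispatch.
    transformer.set_attention_backend("sage")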