diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 7c647b5c0a..d8ba27a2fc 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1494,6 +1494,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         from ..hooks.context_parallel import apply_context_parallel
         from .attention import AttentionModuleMixin
+        from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
         from .attention_processor import Attention, MochiAttention
 
         if isinstance(config, ContextParallelConfig):
@@ -1509,8 +1510,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         # Step 1: Validate attention backend supports context parallelism if enabled
         if config.context_parallel_config is not None:
-            from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
-
             for module in self.modules():
                 if not isinstance(module, attention_classes):
                     continue