diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 5475858dc0..7c647b5c0a 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1487,8 +1487,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning." ) - if not torch.distributed.is_initialized(): - raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.") + if not torch.distributed.is_available() and not torch.distributed.is_initialized(): + raise RuntimeError( + "torch.distributed must be available and initialized before calling `enable_parallelism`." + ) from ..hooks.context_parallel import apply_context_parallel from .attention import AttentionModuleMixin