From 428399b5906f03445c5522517c230f2921ac81c2 Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 7 Oct 2025 15:54:04 +0530 Subject: [PATCH] update --- src/diffusers/models/_modeling_parallel.py | 60 ++++++++++++++-------- src/diffusers/models/attention_dispatch.py | 6 +-- src/diffusers/models/modeling_utils.py | 35 +++++-------- 3 files changed, 52 insertions(+), 49 deletions(-) diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 2a1d2cc6ce..f48b4c4969 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -79,29 +79,47 @@ class ContextParallelConfig: if self.ulysses_degree is None: self.ulysses_degree = 1 + if self.ring_degree == 1 and self.ulysses_degree == 1: + raise ValueError( + "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference" + ) + if self.ring_degree < 1 or self.ulysses_degree < 1: + raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") + if self.ring_degree > 1 and self.ulysses_degree > 1: + raise ValueError( + "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." + ) + if self.rotate_method != "allgather": + raise NotImplementedError( + f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." + ) + + @property + def mesh_shape(self) -> Tuple[int, int]: + """Shape of the device mesh (ring_degree, ulysses_degree).""" + return (self.ring_degree, self.ulysses_degree) + + @property + def mesh_dim_names(self) -> Tuple[str, str]: + """Dimension names for the device mesh.""" + return ("ring", "ulysses") + def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh): self._rank = rank self._world_size = world_size self._device = device self._mesh = mesh - if self.ring_degree is None: - self.ring_degree = 1 - if self.ulysses_degree is None: - self.ulysses_degree = 1 - if self.rotate_method != "allgather": - raise NotImplementedError( - f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." + + if self.ulysses_degree * self.ring_degree > world_size: + raise ValueError( + f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})." ) - if self._flattened_mesh is None: - self._flattened_mesh = self._mesh._flatten() - if self._ring_mesh is None: - self._ring_mesh = self._mesh["ring"] - if self._ulysses_mesh is None: - self._ulysses_mesh = self._mesh["ulysses"] - if self._ring_local_rank is None: - self._ring_local_rank = self._ring_mesh.get_local_rank() - if self._ulysses_local_rank is None: - self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() + + self._flattened_mesh = self._mesh._flatten() + self._ring_mesh = self._mesh["ring"] + self._ulysses_mesh = self._mesh["ulysses"] + self._ring_local_rank = self._ring_mesh.get_local_rank() + self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() @dataclass @@ -119,7 +137,7 @@ class ParallelConfig: _rank: int = None _world_size: int = None _device: torch.device = None - _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None + _mesh: torch.distributed.device_mesh.DeviceMesh = None def setup( self, @@ -127,14 +145,14 @@ class ParallelConfig: world_size: int, device: torch.device, *, - cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, + mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, ): self._rank = rank self._world_size = world_size self._device = device - self._cp_mesh = cp_mesh + self._mesh = mesh if self.context_parallel_config is not None: - self.context_parallel_config.setup(rank, world_size, device, cp_mesh) + self.context_parallel_config.setup(rank, world_size, device, mesh) @dataclass(frozen=True) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 025cd443f0..efbb3afc5d 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -244,11 +244,7 @@ class _AttentionBackendRegistry: supports_context_parallel = ( backend in cls._supports_context_parallel and cls._supports_context_parallel[backend] ) - is_degree_greater_than_1 = parallel_config is not None and ( - parallel_config.context_parallel_config.ring_degree > 1 - or parallel_config.context_parallel_config.ulysses_degree > 1 - ) - return supports_context_parallel and is_degree_greater_than_1 + return supports_context_parallel and parallel_config.context_parallel_config is not None @contextlib.contextmanager diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 1af7ba9ac5..57c3c8866f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1483,46 +1483,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): config: Union[ParallelConfig, ContextParallelConfig], cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None, ): - from ..hooks.context_parallel import apply_context_parallel - from .attention import AttentionModuleMixin - from .attention_processor import Attention, MochiAttention - logger.warning( "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning." ) - if isinstance(config, ContextParallelConfig): - config = ParallelConfig(context_parallel_config=config) - if not torch.distributed.is_initialized(): raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.") + from ..hooks.context_parallel import apply_context_parallel + from .attention import AttentionModuleMixin + from .attention_processor import Attention, MochiAttention + + if isinstance(config, ContextParallelConfig): + config = ParallelConfig(context_parallel_config=config) + rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() device_type = torch._C._get_accelerator().type device_module = torch.get_device_module(device_type) device = torch.device(device_type, rank % device_module.device_count()) - cp_mesh = None + mesh = None if config.context_parallel_config is not None: cp_config = config.context_parallel_config - if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1: - raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") - if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1: - raise ValueError( - "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." - ) - if cp_config.ring_degree * cp_config.ulysses_degree > world_size: - raise ValueError( - f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})." - ) - cp_mesh = torch.distributed.device_mesh.init_device_mesh( + mesh = torch.distributed.device_mesh.init_device_mesh( device_type=device_type, - mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree), - mesh_dim_names=("ring", "ulysses"), + mesh_shape=cp_config.mesh_shape, + mesh_dim_names=cp_config.mesh_dim_names, ) - - config.setup(rank, world_size, device, cp_mesh=cp_mesh) + config.setup(rank, world_size, device, mesh=mesh) if cp_plan is None and self._cp_plan is None: raise ValueError(