diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py
index f48b4c4969..85661336d0 100644
--- a/src/diffusers/models/_modeling_parallel.py
+++ b/src/diffusers/models/_modeling_parallel.py
@@ -44,11 +44,17 @@ class ContextParallelConfig:
 
     Args:
         ring_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ring Attention within a context parallel region. The sequence is split across
+            devices; each device computes attention between its local queries and KV chunks passed sequentially around
+            the ring. Memory is lower (each device holds only 1/N of the KV at a time) and compute overlaps with
+            communication, but N ring steps are needed for each query to attend to all tokens. Best for long sequences
+            with limited memory/bandwidth. Must be a divisor of the total number of devices in the context parallel mesh.
         ulysses_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ulysses Attention within a context parallel region. After computing local QKV,
+            an all-to-all exchange gives each device the full sequence for a subset of attention heads, so attention is
+            computed in one pass. Memory is higher (full-sequence KV is stored) and high-bandwidth all-to-all
+            communication is required, but latency is lower. Best for moderate sequence lengths with good interconnect
+            bandwidth. Must be a divisor of the total number of devices in the context parallel mesh.
         convert_to_fp32 (`bool`, *optional*, defaults to `True`):
             Whether to convert output and LSE to float32 for ring attention numerical stability.
         rotate_method (`str`, *optional*, defaults to `"allgather"`):
@@ -96,7 +102,6 @@ class ContextParallelConfig:
 
     @property
     def mesh_shape(self) -> Tuple[int, int]:
-        """Shape of the device mesh (ring_degree, ulysses_degree)."""
         return (self.ring_degree, self.ulysses_degree)
 
     @property
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 5d4c4ab187..e4a8f30e72 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1509,7 +1509,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
 
-        # Step 1: Validate attention backend supports context parallelism if enabled
         if config.context_parallel_config is not None:
             for module in self.modules():
                 if not isinstance(module, attention_classes):
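
A minimal usage sketch of how the two degrees compose into a device mesh. `ContextParallelConfig` and its `mesh_shape` property come from the patch above; the `enable_parallelism` entry point, the model class, and the checkpoint id are assumptions about the surrounding diffusers API, included for illustration only:

    # Sketch only: assumes a `torchrun --nproc-per-node 8` launch and that
    # ModelMixin exposes `enable_parallelism` (the code path touched in
    # modeling_utils.py above); only ContextParallelConfig appears in this patch.
    import torch
    import torch.distributed as dist

    from diffusers import FluxTransformer2DModel
    from diffusers.models._modeling_parallel import ContextParallelConfig

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # 8 devices as a (ring_degree, ulysses_degree) = (2, 4) mesh: KV chunks
    # rotate around rings of 2 devices, while groups of 4 devices trade
    # sequence shards for head shards via all-to-all (Ulysses).
    cp_config = ContextParallelConfig(ring_degree=2, ulysses_degree=4)
    assert cp_config.mesh_shape == (2, 4)

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    transformer.enable_parallelism(config=cp_config)  # assumed entry point

The degenerate cases follow directly from the docstring: `ring_degree=N, ulysses_degree=1` is pure ring attention, `ring_degree=1, ulysses_degree=N` is pure Ulysses, and the product of the two degrees gives the size of the context parallel mesh.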