commit 450564563e (parent 56114f46cc)
Author: DN6
Date: 2025-10-30 22:30:47 +05:30

2 changed files with 9 additions and 6 deletions


@@ -44,11 +44,16 @@ class ContextParallelConfig:
     Args:
         ring_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ring Attention within a context parallel region; must be a divisor of the
+            total number of devices in the context parallel mesh. The sequence is split across devices, and each
+            device computes attention between its local Q and KV chunks that are passed sequentially around the ring.
+            Lower memory (each device holds only 1/N of the KV at a time) and compute overlaps with communication,
+            but N iterations are needed to see all tokens. Best for long sequences with limited memory/bandwidth.
         ulysses_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ulysses Attention within a context parallel region; must be a divisor of the
+            total number of devices in the context parallel mesh. The sequence is split across devices; each device
+            computes local QKV, then all-gathers the KV chunks to compute full attention in one pass. Higher memory
+            (stores all KV) and high-bandwidth all-to-all communication, but lower latency. Best for moderate sequence lengths with good interconnect bandwidth.
         convert_to_fp32 (`bool`, *optional*, defaults to `True`):
             Whether to convert output and LSE to float32 for ring attention numerical stability.
         rotate_method (`str`, *optional*, defaults to `"allgather"`):
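
Taken together, the two degrees factor the context parallel mesh: with `mesh_shape == (ring_degree, ulysses_degree)`, their product is the mesh size, which is why each must divide the total device count. A minimal sketch of a valid configuration; the 8-GPU world size is an assumed example, and the top-level import path is an assumption rather than something shown in this diff:

```python
from diffusers import ContextParallelConfig

# Hypothetical 8-GPU job: 2-way ring attention x 4-way Ulysses attention.
# Each degree divides 8, and their product covers the whole mesh.
world_size = 8  # e.g. torch.distributed.get_world_size() under torchrun
config = ContextParallelConfig(ring_degree=2, ulysses_degree=4)

assert config.mesh_shape == (2, 4)
assert config.ring_degree * config.ulysses_degree == world_size
```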
@@ -96,7 +101,6 @@ class ContextParallelConfig:
     @property
     def mesh_shape(self) -> Tuple[int, int]:
         """Shape of the device mesh (ring_degree, ulysses_degree)."""
         return (self.ring_degree, self.ulysses_degree)

     @property
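
For orientation, the shape returned here maps directly onto a 2-D device mesh. A sketch of how such a mesh might be materialized with PyTorch's `init_device_mesh`; the dimension names and the direct reuse of the shape tuple are illustrative assumptions, not code from this commit:

```python
from torch.distributed.device_mesh import init_device_mesh

# Build a (ring_degree, ulysses_degree) = (2, 4) mesh on 8 CUDA devices.
# Dimension names are illustrative only.
mesh = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("ring", "ulysses"))

ring_group = mesh["ring"].get_group()        # peers that rotate KV chunks
ulysses_group = mesh["ulysses"].get_group()  # peers for the all-gather/all-to-all
```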


@@ -1509,7 +1509,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)

         # Step 1: Validate attention backend supports context parallelism if enabled
         if config.context_parallel_config is not None:
             for module in self.modules():
                 if not isinstance(module, attention_classes):
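
The hunk is truncated here, but the visible shape is a guard: walk the module tree and, for each attention module, verify its backend can serve context parallelism. A generic sketch of that pattern; the `continue`/`raise` behavior and the `supports_context_parallel` attribute are assumptions for illustration, not the library's code:

```python
import torch.nn as nn

def validate_context_parallel(model: nn.Module, attention_classes: tuple) -> None:
    # Walk every submodule; skip anything that is not an attention module,
    # then reject attention modules whose backend cannot shard the sequence.
    for module in model.modules():
        if not isinstance(module, attention_classes):
            continue
        backend = getattr(module, "attention_backend", None)  # hypothetical attribute
        if backend is not None and not getattr(backend, "supports_context_parallel", False):
            raise ValueError(
                f"Attention backend {backend!r} does not support context parallelism."
            )
```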