mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
update
@@ -44,11 +44,16 @@ class ContextParallelConfig:
 
     Args:
         ring_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ring Attention. The sequence is split across devices; each device computes
+            attention between its local Q and the KV chunks passed sequentially around the ring. Lower memory (only
+            holds 1/N of the KV at a time), overlaps compute with communication, but requires N iterations to see
+            all tokens. Best for long sequences with limited memory/bandwidth. Must be a divisor of the total number
+            of devices in the context parallel mesh.
         ulysses_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ulysses Attention. The sequence is split across devices; each device
+            computes its local QKV, then all-gathers every KV chunk to compute full attention in one pass. Higher
+            memory (stores all KV), requires high-bandwidth all-to-all communication, but lower latency. Best for
+            moderate sequence lengths with good interconnect bandwidth.
         convert_to_fp32 (`bool`, *optional*, defaults to `True`):
             Whether to convert output and LSE to float32 for ring attention numerical stability.
         rotate_method (`str`, *optional*, defaults to `"allgather"`):
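The tradeoff the new docstring text describes can be sketched in plain PyTorch. Below is a minimal single-process simulation, not the diffusers implementation: the function names, shapes, and chunking are illustrative assumptions. `ring_attention_sim` folds in one KV chunk at a time with a running log-sum-exp (the 1/N-memory, N-iteration schedule), while `ulysses_attention_sim` gathers all KV first and attends in one pass.

```python
# Minimal sketch of the two schedules described above; names and shapes are
# assumptions for illustration, not diffusers APIs.
import torch
import torch.nn.functional as F

def ring_attention_sim(q, k_chunks, v_chunks):
    # One rank's view of Ring Attention: local q attends to KV chunks that
    # arrive one at a time; results are merged with a running log-sum-exp so
    # only one KV chunk is resident per step (N steps in total).
    out, lse = None, None
    scale = q.shape[-1] ** -0.5
    for k, v in zip(k_chunks, v_chunks):  # one ring rotation per chunk
        scores = (q @ k.transpose(-2, -1)) * scale
        chunk_lse = torch.logsumexp(scores, dim=-1, keepdim=True)
        chunk_out = torch.softmax(scores, dim=-1) @ v
        if out is None:
            out, lse = chunk_out, chunk_lse
        else:
            new_lse = torch.logaddexp(lse, chunk_lse)
            # Rescale the running output and the new chunk onto the merged LSE.
            out = out * (lse - new_lse).exp() + chunk_out * (chunk_lse - new_lse).exp()
            lse = new_lse
    return out

def ulysses_attention_sim(q, k_chunks, v_chunks):
    # One rank's view of Ulysses Attention as framed above: gather every KV
    # chunk up front, then a single full-attention pass (higher memory).
    k, v = torch.cat(k_chunks, dim=-2), torch.cat(v_chunks, dim=-2)
    return F.scaled_dot_product_attention(q, k, v)

torch.manual_seed(0)
q = torch.randn(1, 8, 64)                      # this rank's query chunk
kc = list(torch.randn(4, 1, 8, 64).unbind(0))  # 4 ranks' worth of K chunks
vc = list(torch.randn(4, 1, 8, 64).unbind(0))  # 4 ranks' worth of V chunks
assert torch.allclose(ring_attention_sim(q, kc, vc),
                      ulysses_attention_sim(q, kc, vc), atol=1e-5)
```

The running log-sum-exp in the ring path is the LSE that `convert_to_fp32` keeps in float32 for numerical stability.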
@@ -96,7 +101,6 @@ class ContextParallelConfig:
 
     @property
     def mesh_shape(self) -> Tuple[int, int]:
         """Shape of the device mesh (ring_degree, ulysses_degree)."""
         return (self.ring_degree, self.ulysses_degree)
 
     @property
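As a usage sketch of the property above: the two degrees multiply to the total number of devices in the context parallel mesh, and `mesh_shape` reports them as a 2D mesh. The import path below is an assumption.

```python
from diffusers import ContextParallelConfig  # import path assumed for illustration

# 8 devices split into rings of 2, each ring position an Ulysses group of 4.
cp = ContextParallelConfig(ring_degree=2, ulysses_degree=4)
print(cp.mesh_shape)  # (2, 4), i.e. (ring_degree, ulysses_degree)
```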
@@ -1509,7 +1509,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
 
         # Step 1: Validate attention backend supports context parallelism if enabled
         if config.context_parallel_config is not None:
             for module in self.modules():
                 if not isinstance(module, attention_classes):
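The hunk cuts off mid-loop, but the pattern it shows is a walk over `self.modules()` that ignores everything except attention modules before checking backend support. A standalone sketch of that pattern, with `_supports_context_parallel` as a hypothetical marker attribute:

```python
# Sketch of the validation walk above; `_supports_context_parallel` is a
# hypothetical marker attribute, not a real diffusers API.
import torch.nn as nn

def validate_context_parallel_support(model: nn.Module, attention_classes) -> None:
    for module in model.modules():
        if not isinstance(module, attention_classes):
            continue  # only attention modules take part in context parallelism
        if not getattr(module, "_supports_context_parallel", False):
            raise NotImplementedError(
                f"{type(module).__name__} does not support context parallelism; "
                "choose an attention backend that provides it."
            )
```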