1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Merge branch 'main' into remove-explicit-typing

This commit is contained in:
Sayak Paul
2026-01-13 09:23:41 +05:30
committed by GitHub
3 changed files with 212 additions and 8 deletions

View File

@@ -333,3 +333,31 @@ pipeline = DiffusionPipeline.from_pretrained(
CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)
```
### Unified Attention
[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.
This hybrid approach leverages the strengths of both methods:
- **Ulysses Attention** efficiently parallelizes across attention heads
- **Ring Attention** handles very long sequences with minimal memory overhead
- Together, they enable 2D parallelization across both heads and sequence dimensions
[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping).
Pass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to bigger than 1 to [`~ModelMixin.enable_parallelism`].
```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
```
> [!TIP]
> Unified Attention is to be used when there are enough devices to arrange in a 2D grid (at least 4 devices).
We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows:
| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) |
|--------------------|------------------|-------------|------------------|
| ulysses | 6670.789 | 7.50 | 33.85 |
| ring | 13076.492 | 3.82 | 56.02 |
| unified_balanced | 11068.705 | 4.52 | 33.85 |
From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to number of attention-heads, a limitation that is solved by unified attention.

View File

@@ -90,10 +90,6 @@ class ContextParallelConfig:
)
if self.ring_degree < 1 or self.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."

View File

@@ -1179,6 +1179,103 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
return x
def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
"""
Perform dimension sharding / reassembly across processes using _all_to_all_single.
This utility reshapes and redistributes tensor `x` across the given process group, across sequence dimension or
head dimension flexibly by accepting scatter_idx and gather_idx.
Args:
x (torch.Tensor):
Input tensor. Expected shapes:
- When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
- When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
scatter_idx (int) :
Dimension along which the tensor is partitioned before all-to-all.
gather_idx (int):
Dimension along which the output is reassembled after all-to-all.
group :
Distributed process group for the Ulysses group.
Returns:
torch.Tensor: Tensor with globally exchanged dimensions.
- For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
- For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
"""
group_world_size = torch.distributed.get_world_size(group)
if scatter_idx == 2 and gather_idx == 1:
# Used before Ulysses sequence parallel (SP) attention. Scatters the gathers sequence
# dimension and scatters head dimension
batch_size, seq_len_local, num_heads, head_dim = x.shape
seq_len = seq_len_local * group_world_size
num_heads_local = num_heads // group_world_size
# B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
x_temp = (
x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
.transpose(0, 2)
.contiguous()
)
if group_world_size > 1:
out = _all_to_all_single(x_temp, group=group)
else:
out = x_temp
# group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
return out
elif scatter_idx == 1 and gather_idx == 2:
# Used after ulysses sequence parallel in unified SP. gathers the head dimension
# scatters back the sequence dimension.
batch_size, seq_len, num_heads_local, head_dim = x.shape
num_heads = num_heads_local * group_world_size
seq_len_local = seq_len // group_world_size
# B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
x_temp = (
x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
.permute(1, 3, 2, 0, 4)
.reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
)
if group_world_size > 1:
output = _all_to_all_single(x_temp, group)
else:
output = x_temp
output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
return output
else:
raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
class SeqAllToAllDim(torch.autograd.Function):
"""
all_to_all operation for unified sequence parallelism. uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange
for more info.
"""
@staticmethod
def forward(ctx, group, input, scatter_id=2, gather_id=1):
ctx.group = group
ctx.scatter_id = scatter_id
ctx.gather_id = gather_id
return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
@staticmethod
def backward(ctx, grad_outputs):
grad_input = SeqAllToAllDim.apply(
ctx.group,
grad_outputs,
ctx.gather_id, # reversed
ctx.scatter_id, # reversed
)
return (None, grad_input, None, None)
class TemplatedRingAttention(torch.autograd.Function):
@staticmethod
def forward(
@@ -1239,7 +1336,10 @@ class TemplatedRingAttention(torch.autograd.Function):
out = out.to(torch.float32)
lse = lse.to(torch.float32)
lse = lse.unsqueeze(-1)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
lse = lse.unsqueeze(-1)
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
@@ -1300,7 +1400,7 @@ class TemplatedRingAttention(torch.autograd.Function):
grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
class TemplatedUlyssesAttention(torch.autograd.Function):
@@ -1395,7 +1495,69 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
)
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
def _templated_unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
dropout_p: float,
is_causal: bool,
scale: Optional[float],
enable_gqa: bool,
return_lse: bool,
forward_op,
backward_op,
_parallel_config: Optional["ParallelConfig"] = None,
scatter_idx: int = 2,
gather_idx: int = 1,
):
"""
Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
"""
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
ulysses_group = ulysses_mesh.get_group()
query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
out = TemplatedRingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
if return_lse:
context_layer, lse, *_ = out
else:
context_layer = out
# context_layer is of shape (B, S, H_LOCAL, D)
output = SeqAllToAllDim.apply(
ulysses_group,
context_layer,
gather_idx,
scatter_idx,
)
if return_lse:
# lse is of shape (B, S, H_LOCAL, 1)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
lse = lse.squeeze(-1)
return (output, lse)
return output
def _templated_context_parallel_attention(
@@ -1421,7 +1583,25 @@ def _templated_context_parallel_attention(
raise ValueError("GQA is not yet supported for templated attention.")
# TODO: add support for unified attention with ring/ulysses degree both being > 1
if _parallel_config.context_parallel_config.ring_degree > 1:
if (
_parallel_config.context_parallel_config.ring_degree > 1
and _parallel_config.context_parallel_config.ulysses_degree > 1
):
return _templated_unified_attention(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
elif _parallel_config.context_parallel_config.ring_degree > 1:
return TemplatedRingAttention.apply(
query,
key,