1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add ulysses backward

This commit is contained in:
Aryan
2025-08-14 07:12:43 +02:00
parent 27e1d27233
commit cca53814a3

View File

@@ -1040,7 +1040,33 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
grad_out: torch.Tensor,
*args,
):
raise NotImplementedError("Backward pass is not implemented for TemplatedUlyssesAttention.")
parallel_config = _AttentionBackendRegistry._parallel_config
ulysses_mesh = parallel_config._ulysses_mesh
world_size = parallel_config.ulysses_degree
group = ulysses_mesh.get_group()
B, S_LOCAL, H, D = grad_out.shape
H_LOCAL = H // world_size
grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
grad_out = _wait_tensor(funcol.all_to_all_single(grad_out, None, None, group=group))
grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx.op_ctx, grad_out)
grad_query, grad_key, grad_value = (
x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
for x in (grad_query_op, grad_key_op, grad_value_op)
)
grad_query, grad_key, grad_value = (
_wait_tensor(funcol.all_to_all_single(x, None, None, group=group))
for x in (grad_query, grad_key, grad_value)
)
grad_query, grad_key, grad_value = (
x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
)
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
def _templated_context_parallel_attention(