From cca53814a380d2c60c81010c050aa8dbe59bac69 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 14 Aug 2025 07:12:43 +0200
Subject: [PATCH] add ulysses backward

---
 src/diffusers/models/attention_dispatch.py | 28 +++++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index ebf17678ac..43428fc472 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1040,7 +1040,33 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
         grad_out: torch.Tensor,
         *args,
     ):
-        raise NotImplementedError("Backward pass is not implemented for TemplatedUlyssesAttention.")
+        parallel_config = _AttentionBackendRegistry._parallel_config
+        ulysses_mesh = parallel_config._ulysses_mesh
+        world_size = parallel_config.ulysses_degree
+        group = ulysses_mesh.get_group()
+
+        B, S_LOCAL, H, D = grad_out.shape
+        H_LOCAL = H // world_size
+
+        grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        grad_out = _wait_tensor(funcol.all_to_all_single(grad_out, None, None, group=group))
+        grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
+
+        grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx.op_ctx, grad_out)
+
+        grad_query, grad_key, grad_value = (
+            x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+            for x in (grad_query_op, grad_key_op, grad_value_op)
+        )
+        grad_query, grad_key, grad_value = (
+            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group))
+            for x in (grad_query, grad_key, grad_value)
+        )
+        grad_query, grad_key, grad_value = (
+            x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
+        )
+
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
 
 def _templated_context_parallel_attention(
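
Note: the reshape/permute/all_to_all choreography in this backward is the standard Ulysses reshard, run in reverse for grad_out (trade local head groups for local sequence shards) and forward again for grad_query/key/value. Below is a minimal single-process sketch of just that layout math. WORLD, the concrete shapes, and emulated_all_to_all are illustrative stand-ins, not diffusers API; the real code runs one shard per rank and exchanges them with funcol.all_to_all_single instead.

import torch

WORLD = 4                          # stands in for parallel_config.ulysses_degree
B, S_LOCAL, H, D = 2, 8, 16, 64    # per-rank shapes; H must be divisible by WORLD
H_LOCAL = H // WORLD


def emulated_all_to_all(inputs):
    # inputs[j] is rank j's tensor. funcol.all_to_all_single with default
    # splits chunks dim 0 into WORLD pieces and hands rank i the i-th piece
    # from every rank, concatenated in rank order; this reproduces that on
    # one process.
    return [
        torch.cat([inp.chunk(WORLD, dim=0)[i] for inp in inputs], dim=0)
        for i in range(WORLD)
    ]


# Each rank starts with grad_out sharded over the sequence: (B, S_LOCAL, H, D).
grads = [torch.randn(B, S_LOCAL, H, D) for _ in range(WORLD)]

# grad_out transform from the patch: expose head groups as the exchange dim...
staged = [
    g.reshape(B, S_LOCAL, WORLD, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
    for g in grads
]
# ...then fold the gathered shards into the full sequence with local heads only.
resharded = [
    t.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
    for t in emulated_all_to_all(staged)
]
assert resharded[0].shape == (B, WORLD * S_LOCAL, H_LOCAL, D)

# grad_query/key/value transform from the patch: the exact inverse, sending
# head groups back out and re-sharding the sequence.
staged_back = [
    t.reshape(B, WORLD, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
    for t in resharded
]
restored = [
    t.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
    for t in emulated_all_to_all(staged_back)
]
for g, r in zip(grads, restored):
    torch.testing.assert_close(g, r)  # round trip recovers the original layout

The final asserts check that the grad_query/key/value transform is the exact inverse of the grad_out transform, which is what makes the backward consistent with the resharding the forward performed.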