mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
add ulysses backward
@@ -1040,7 +1040,33 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
         grad_out: torch.Tensor,
         *args,
     ):
-        raise NotImplementedError("Backward pass is not implemented for TemplatedUlyssesAttention.")
+        parallel_config = _AttentionBackendRegistry._parallel_config
+        ulysses_mesh = parallel_config._ulysses_mesh
+        world_size = parallel_config.ulysses_degree
+        group = ulysses_mesh.get_group()
+
+        B, S_LOCAL, H, D = grad_out.shape
+        H_LOCAL = H // world_size
+
+        grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        grad_out = _wait_tensor(funcol.all_to_all_single(grad_out, None, None, group=group))
+        grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
+
+        grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx.op_ctx, grad_out)
+
+        grad_query, grad_key, grad_value = (
+            x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+            for x in (grad_query_op, grad_key_op, grad_value_op)
+        )
+        grad_query, grad_key, grad_value = (
+            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group))
+            for x in (grad_query, grad_key, grad_value)
+        )
+        grad_query, grad_key, grad_value = (
+            x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
+        )
+
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None


 def _templated_context_parallel_attention(
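
Since the diff itself carries no commentary, here is a small single-process sketch of what the first half of this backward does to grad_out. It replays the reshape, all_to_all_single, and flatten/permute steps from the diff, but emulates the collective with plain indexing over a list of per-rank tensors, so no process group is needed. The shapes B, S, H, D and the world size W, and names like grad_full and local_grads, are made up for illustration; only the layout logic is taken from the commit.

import torch

B, S, H, D, W = 2, 8, 4, 16, 2        # batch, full sequence, heads, head dim, hypothetical world size
S_LOCAL, H_LOCAL = S // W, H // W

grad_full = torch.randn(B, S, H, D)   # the logical, unsharded gradient of the attention output
# what each rank holds when backward is entered: its sequence shard, all heads
local_grads = [grad_full[:, r * S_LOCAL:(r + 1) * S_LOCAL] for r in range(W)]

# step 1: per-rank reshape/permute into (W, S_LOCAL, B, H_LOCAL, D), as in the diff
staged = [
    g.reshape(B, S_LOCAL, W, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
    for g in local_grads
]

# step 2: emulate all_to_all_single along dim 0: rank r gathers chunk r from every rank j
exchanged = [
    torch.cat([staged[j][r:r + 1] for j in range(W)], dim=0) for r in range(W)
]

# step 3: per-rank flatten/permute into (B, S, H_LOCAL, D), as in the diff
redistributed = [x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in exchanged]

# each rank now sees the full sequence and only its own head shard,
# which is the layout the diff then feeds to ctx.backward_op
for r in range(W):
    expected = grad_full[:, :, r * H_LOCAL:(r + 1) * H_LOCAL]
    assert torch.equal(redistributed[r], expected)
print("sequence-sharded -> head-sharded redistribution verified")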
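
The second half of the backward undoes that layout for the query/key/value gradients returned by ctx.backward_op. The same kind of emulation (again with made-up shapes and a hand-rolled stand-in for all_to_all_single) checks that the second reshape/permute pair from the diff brings each rank back to its own sequence shard with all heads:

import torch

B, S, H, D, W = 2, 8, 4, 16, 2        # batch, full sequence, heads, head dim, hypothetical world size
S_LOCAL, H_LOCAL = S // W, H // W

grad_q_full = torch.randn(B, S, H, D)  # the logical, unsharded query gradient
# what rank r receives from ctx.backward_op: full sequence, its head shard
op_grads = [grad_q_full[:, :, r * H_LOCAL:(r + 1) * H_LOCAL] for r in range(W)]

# step 1: per-rank reshape/permute into (W, H_LOCAL, B, S_LOCAL, D), as in the diff
staged = [
    g.reshape(B, W, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
    for g in op_grads
]

# step 2: emulate all_to_all_single along dim 0, exactly as above
exchanged = [
    torch.cat([staged[j][r:r + 1] for j in range(W)], dim=0) for r in range(W)
]

# step 3: per-rank flatten/permute into (B, S_LOCAL, H, D), as in the diff
returned = [x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in exchanged]

# each rank ends up with the gradient for its own sequence shard and all heads,
# the same layout its query/key/value inputs had in the forward pass
for r in range(W):
    expected = grad_q_full[:, r * S_LOCAL:(r + 1) * S_LOCAL]
    assert torch.equal(returned[r], expected)
print("head-sharded -> sequence-sharded redistribution verified")

The eight trailing None values in the new return statement are the gradient slots autograd requires for the remaining forward inputs, none of which receive gradients here.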