From 215104f49b012440cf23773b8cddb3dcd8dd5bb7 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 8 Aug 2025 11:01:10 +0200
Subject: [PATCH] update

---
 src/diffusers/hooks/context_parallel.py    | 17 ++++++++++++++++-
 src/diffusers/models/attention_dispatch.py |  3 ++-
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
index 24644c3ebe..9121475fe8 100644
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -233,6 +233,21 @@ class ContextParallelGatherHook(ModelHook):
         return output[0] if is_tensor else tuple(output)
 
 
+class AllGatherFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, tensor, dim, group):
+        ctx.dim = dim
+        ctx.group = group
+        ctx.world_size = torch.distributed.get_world_size(group)
+        ctx.rank = torch.distributed.get_rank(group)
+        return funcol.all_gather_tensor(tensor, dim, group=group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
+        return grad_chunks[ctx.rank], None, None
+
+
 class EquipartitionSharder:
     @classmethod
     def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
@@ -246,7 +261,7 @@ class EquipartitionSharder:
     @classmethod
     def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         tensor = tensor.contiguous()
-        tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
+        tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
         return tensor
 
 
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index f8b8256e51..3d21ddfbcf 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -281,7 +281,8 @@ def dispatch_attention_fn(
         and not _AttentionBackendRegistry._is_context_parallel_enabled(backend_name)
     ):
         raise ValueError(
-            f"Backend {backend_name} does not support context parallelism, but a parallel configuration is provided."
+            f"Backend {backend_name} either does not support context parallelism or context parallelism "
+            f"was enabled with a world size of 1."
         )
 
     kwargs = {
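
Note, not part of the patch: below is a minimal, self-contained sketch of the differentiable all-gather pattern that AllGatherFunction introduces above. It assumes a single-process "gloo" group purely so the script can run locally; the class name NaiveAllGather and the use of the public torch.distributed.all_gather (instead of the private funcol.all_gather_tensor) are illustrative stand-ins, not the patch's actual code. What it demonstrates is the backward rule: the gradient of the gathered tensor is chunked along the gather dimension and each rank keeps only the chunk that corresponds to its own shard.

import os

import torch
import torch.distributed as dist


class NaiveAllGather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, dim, group):
        ctx.dim = dim
        ctx.world_size = dist.get_world_size(group)
        ctx.rank = dist.get_rank(group)
        # Gather every rank's shard and concatenate them along `dim`.
        gathered = [torch.empty_like(tensor) for _ in range(ctx.world_size)]
        dist.all_gather(gathered, tensor.detach().contiguous(), group=group)
        return torch.cat(gathered, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output covers all shards; each rank owns only the chunk that
        # corresponds to its own input, so slice it out and drop the rest.
        grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
        return grad_chunks[ctx.rank].contiguous(), None, None


if __name__ == "__main__":
    # Single-process process group, only to exercise the autograd path.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    x = torch.randn(2, 4, requires_grad=True)
    y = NaiveAllGather.apply(x, 0, dist.group.WORLD)
    y.sum().backward()
    print(x.grad.shape)  # torch.Size([2, 4]): gradients flow back through the gather

    dist.destroy_process_group()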