mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
Commit: 215104f49b
Author: Aryan
Date: 2025-08-08 11:01:10 +02:00
Parent: fa5d017e76
2 changed files with 18 additions and 2 deletions

@@ -233,6 +233,21 @@ class ContextParallelGatherHook(ModelHook):
         return output[0] if is_tensor else tuple(output)


+class AllGatherFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, tensor, dim, group):
+        ctx.dim = dim
+        ctx.group = group
+        ctx.world_size = torch.distributed.get_world_size(group)
+        ctx.rank = torch.distributed.get_rank(group)
+        return funcol.all_gather_tensor(tensor, dim, group=group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
+        return grad_chunks[ctx.rank], None, None
+
+
 class EquipartitionSharder:
     @classmethod
     def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
@@ -246,7 +261,7 @@ class EquipartitionSharder:
     @classmethod
     def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         tensor = tensor.contiguous()
-        tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
+        tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
         return tensor
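
The new class makes the gather differentiable: unshard now routes funcol.all_gather_tensor through a custom torch.autograd.Function so that gradients can flow back to the rank-local shard. The backward rule is sound because the gathered tensor is a concatenation of one shard per rank along dim, so each shard influences only its own slice of the output. A minimal single-process sketch of that slicing rule (the world size, rank, and tensor values below are made up for illustration; no process group is involved):

import torch

# Stand-in for the gradient arriving at AllGatherFunction.backward: the
# upstream gradient of a tensor gathered from world_size ranks along dim 0.
world_size, rank, dim = 4, 2, 0                # illustrative values only
grad_output = torch.arange(8.0).unsqueeze(1)   # shape (8, 1): 4 shards of 2 rows each

# backward() splits the upstream gradient back into per-rank shards and
# returns the slice that belongs to this rank -- grad_chunks[ctx.rank] above.
grad_chunks = torch.chunk(grad_output, world_size, dim=dim)
grad_local = grad_chunks[rank]
print(grad_local.squeeze(1))                   # tensor([4., 5.])

The three values returned by backward line up with the three arguments of forward (tensor, dim, group); only the tensor carries a gradient, hence the two trailing Nones.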

@@ -281,7 +281,8 @@ def dispatch_attention_fn(
         and not _AttentionBackendRegistry._is_context_parallel_enabled(backend_name)
     ):
         raise ValueError(
-            f"Backend {backend_name} does not support context parallelism, but a parallel configuration is provided."
+            f"Backend {backend_name} either does not support context parallelism or context parallelism "
+            f"was enabled with a world size of 1."
         )

     kwargs = {
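
On the reworded message: the guard above fires whenever a parallel configuration is present but _AttentionBackendRegistry._is_context_parallel_enabled(backend_name) returns False, and the new text names both situations in which that can happen. A hedged sketch of the logic such a check would need, consistent with the two causes in the message (an illustration only, not the registry's actual code):

# Hypothetical sketch: a predicate of this shape is False -- and so would
# trigger the ValueError above -- in exactly the two cases the message names.
def is_context_parallel_usable(backend_supports_cp: bool, world_size: int) -> bool:
    # False when the backend has no context-parallel implementation,
    # or when context parallelism was "enabled" with a single rank.
    return backend_supports_cp and world_size > 1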