mirror of https://github.com/huggingface/diffusers.git
update
@@ -233,6 +233,21 @@ class ContextParallelGatherHook(ModelHook):
         return output[0] if is_tensor else tuple(output)
 
 
+class AllGatherFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, tensor, dim, group):
+        ctx.dim = dim
+        ctx.group = group
+        ctx.world_size = torch.distributed.get_world_size(group)
+        ctx.rank = torch.distributed.get_rank(group)
+        return funcol.all_gather_tensor(tensor, dim, group=group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
+        return grad_chunks[ctx.rank], None, None
+
+
 class EquipartitionSharder:
     @classmethod
     def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
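The backward rule above hands each rank only its own slice of the incoming gradient, which is exactly the gradient of a concatenation along `dim` with respect to the shard that rank contributed. A minimal single-process sketch of that rule (not part of the commit; `torch.cat` stands in for `funcol.all_gather_tensor`, and the shapes and variable names are illustrative assumptions):

```python
import torch

dim, world_size = 0, 2
shard_rank0 = torch.randn(4, 8, requires_grad=True)  # shard held by rank 0
shard_rank1 = torch.randn(4, 8, requires_grad=True)  # shard held by rank 1

# Stand-in for funcol.all_gather_tensor: concatenate every rank's shard along `dim`.
gathered = torch.cat([shard_rank0, shard_rank1], dim=dim)

# Pretend this gradient arrives from later layers during backprop.
upstream = torch.randn_like(gathered)
gathered.backward(upstream)

# Each shard receives exactly its own chunk of the upstream gradient,
# the same slice that backward() selects with grad_chunks[ctx.rank].
chunks = torch.chunk(upstream, world_size, dim=dim)
assert torch.equal(shard_rank0.grad, chunks[0])
assert torch.equal(shard_rank1.grad, chunks[1])
```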
@@ -246,7 +261,7 @@ class EquipartitionSharder:
     @classmethod
     def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         tensor = tensor.contiguous()
-        tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
+        tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
         return tensor
 
 
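`unshard` now routes the gather through `AllGatherFunction.apply` rather than calling the functional collective directly, presumably so that autograd can propagate gradients from the gathered tensor back to the local shard during context-parallel training. A hedged single-process sketch of the difference, with `.detach()` standing in for a gather that autograd cannot see through (the detach stand-in, shapes, and names are illustrative assumptions, not repository code):

```python
import torch

local_shard = torch.randn(4, 2, requires_grad=True)  # this rank's shard
remote_shard = torch.randn(4, 2)                      # what another rank would contribute

# Gather that autograd cannot see through (stand-in): the graph is cut,
# so a loss on the gathered tensor cannot reach local_shard.
gathered_detached = torch.cat([local_shard, remote_shard], dim=0).detach()
assert not gathered_detached.requires_grad

# Gather kept on the autograd graph, as AllGatherFunction.apply does:
# gradients from the loss flow back to the local shard.
gathered = torch.cat([local_shard, remote_shard], dim=0)
gathered.sum().backward()
assert local_shard.grad is not None
```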
@@ -281,7 +281,8 @@ def dispatch_attention_fn(
         and not _AttentionBackendRegistry._is_context_parallel_enabled(backend_name)
     ):
         raise ValueError(
-            f"Backend {backend_name} does not support context parallelism, but a parallel configuration is provided."
+            f"Backend {backend_name} either does not support context parallelism or context parallelism "
+            f"was enabled with a world size of 1."
         )
 
     kwargs = {