diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
index 788d030afa..c3697f967d 100644
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -105,19 +105,41 @@ def apply_context_parallel(
         registry = HookRegistry.check_if_exists_or_initialize(m)
         registry.register_hook(hook, hook_name)
 
-    registry = HookRegistry.check_if_exists_or_initialize(module)
-    hook = ContextParallelModelHook(parallel_config)
-    registry.register_hook(hook, _CONTEXT_PARALLEL_MODEL_HOOK)
+    # HACK: we cannot use context managers, setattr, or similar solutions in an overridden forward
+    # method of a diffusers hook because Dynamo fails to trace it. Instead, we make use of the module
+    # hooks available in PyTorch to set the parallel context before/after the forward/backward pass.
+    # It is dirty, but fullgraph=True tracing works because of this, and I haven't found a better solution yet.
+    # The previous/older implementation simply did this:
+    #     def new_forward(self, ...):
+    #         with _parallel_context(parallel_config):
+    #             return self.fn_ref.original_forward(*args, **kwargs)
+    # TODO: ask the PyTorch team for help on how to improve this
+    @torch.compiler.disable
+    def forward_pre_hook(module, args):
+        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
+        module._diffusers_parallel_config_setter_context.__enter__()
 
+    @torch.compiler.disable
+    def forward_hook(module, args, output):
+        if module._diffusers_parallel_config_setter_context is not None:
+            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
+        module._diffusers_parallel_config_setter_context = None
 
-class ContextParallelModelHook(ModelHook):
-    def __init__(self, parallel_config: ParallelConfig) -> None:
-        super().__init__()
-        self.parallel_config = parallel_config
+    @torch.compiler.disable
+    def backward_pre_hook(module, grad_output):
+        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
+        module._diffusers_parallel_config_setter_context.__enter__()
 
-    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
-        with _parallel_context(self.parallel_config):
-            return self.fn_ref.original_forward(*args, **kwargs)
+    @torch.compiler.disable
+    def backward_hook(module, grad_output, grad_input):
+        if module._diffusers_parallel_config_setter_context is not None:
+            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
+        module._diffusers_parallel_config_setter_context = None
+
+    module.register_forward_pre_hook(forward_pre_hook)
+    module.register_forward_hook(forward_hook)
+    module.register_full_backward_pre_hook(backward_pre_hook)
+    module.register_full_backward_hook(backward_hook)
 
 
 class ContextParallelSplitHook(ModelHook):
@@ -234,13 +256,15 @@ class ContextParallelGatherHook(ModelHook):
 
 class EquipartitionSharder:
     @classmethod
-    @torch.compiler.disable
     def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         assert tensor.size()[dim] % mesh.size() == 0
-        return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+
+        # The following is not fullgraph-compatible with Dynamo (fails in DeviceMesh.get_rank)
+        # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+
+        return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
 
     @classmethod
-    @torch.compiler.disable
     def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         tensor = tensor.contiguous()
         tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
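
The change above swaps a wrapped forward for plain PyTorch module hooks that enter and exit _parallel_context around each forward/backward pass. As a rough illustration of that pattern, here is a minimal, self-contained sketch under assumed names (parallel_state, _set_parallel_state, and attach_parallel_context are stand-ins invented for this example, not diffusers APIs); it only shows a context manager being entered in a register_forward_pre_hook callback and exited in a register_forward_hook callback, the same way the diff does for _parallel_context. The backward hooks follow the same shape and are omitted here; in the diff the callbacks are additionally decorated with @torch.compiler.disable so Dynamo does not trace the context-manager bookkeeping itself.

import contextlib

import torch


# Hypothetical stand-in for _AttentionBackendRegistry._parallel_config.
parallel_state = {"config": None}


@contextlib.contextmanager
def _set_parallel_state(config):
    # Same shape as _parallel_context in the diff: save, set, restore.
    old = parallel_state["config"]
    parallel_state["config"] = config
    try:
        yield
    finally:
        parallel_state["config"] = old


def attach_parallel_context(module: torch.nn.Module, config) -> None:
    # Enter the context right before forward runs and exit right after it returns,
    # instead of wrapping forward in a `with` block inside an overridden method.
    def forward_pre_hook(module, args):
        module._parallel_ctx = _set_parallel_state(config)
        module._parallel_ctx.__enter__()

    def forward_hook(module, args, output):
        if module._parallel_ctx is not None:
            module._parallel_ctx.__exit__(None, None, None)
        module._parallel_ctx = None

    module.register_forward_pre_hook(forward_pre_hook)
    module.register_forward_hook(forward_hook)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 4)
    attach_parallel_context(model, config={"ring_degree": 2})
    model(torch.randn(1, 4))  # parallel_state holds the config only while forward runs
    assert parallel_state["config"] is None
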
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 9323c45acb..28837c06b8 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional,
 
 import torch
 import torch.distributed._functional_collectives as funcol
+import torch.distributed.tensor
 
 from ..utils import (
     get_logger,
@@ -245,9 +246,6 @@ def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIV
 
 @contextlib.contextmanager
 def _parallel_context(parallel_config: "ParallelConfig"):
-    """
-    Context manager to set the parallel configuration for attention backends that support it.
-    """
     old_parallel_config = _AttentionBackendRegistry._parallel_config
     _AttentionBackendRegistry._parallel_config = parallel_config
 
@@ -789,6 +787,16 @@ class _sage_attention_af(torch.autograd.Function):
 # ===== Context parallel =====
 
 
+# Reference:
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
+# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
+def _wait_tensor(tensor):
+    if isinstance(tensor, funcol.AsyncCollectiveTensor):
+        tensor = tensor.wait()
+    return tensor
+
+
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
     def forward(
@@ -875,7 +883,9 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
             x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
             for x in (query, key, value)
         )
-        query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
+        query, key, value = (
+            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group)) for x in (query, key, value)
+        )
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
 
         out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
@@ -883,12 +893,12 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
             out, lse, *_ = out
 
         out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
-        out = funcol.all_to_all_single(out, None, None, group=group).wait()
+        out = _wait_tensor(funcol.all_to_all_single(out, None, None, group=group))
        out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
 
         if return_lse:
             lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
-            lse = funcol.all_to_all_single(lse, None, None, group=group).wait()
+            lse = _wait_tensor(funcol.all_to_all_single(lse, None, None, group=group))
             lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
         else:
             lse = None
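
The _wait_tensor helper introduced above exists because eager-mode functional collectives return AsyncCollectiveTensor objects that must be waited on, while the fake tensors seen during fullgraph tracing may not expose a `wait` method. The sketch below is a standalone illustration of that guard (the wait_tensor name and the demo code are ad-hoc for this example, not part of the diff); it shows that plain tensors pass through unchanged and how a collective result would be resolved the same way once a process group is initialized.

import torch
import torch.distributed._functional_collectives as funcol


def wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
    # Eager collectives from funcol return AsyncCollectiveTensor, which must be resolved
    # before the data is used; anything else (e.g., a FakeTensor produced while tracing)
    # is returned as-is, so the same code path works eagerly and under torch.compile.
    if isinstance(tensor, funcol.AsyncCollectiveTensor):
        tensor = tensor.wait()
    return tensor


if __name__ == "__main__":
    x = torch.randn(4, 8)
    assert wait_tensor(x) is x  # non-async tensors are returned unchanged

    # With torch.distributed initialized, a collective result is resolved the same way, e.g.:
    # out = wait_tensor(funcol.all_to_all_single(x, None, None, group=torch.distributed.group.WORLD))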