
make torch compile compatible

This commit is contained in:
Aryan
2025-07-16 23:55:20 +02:00
parent 171152f275
commit 62f164d04d
2 changed files with 53 additions and 19 deletions

View File

@@ -105,19 +105,41 @@ def apply_context_parallel(
         registry = HookRegistry.check_if_exists_or_initialize(m)
         registry.register_hook(hook, hook_name)
 
-    registry = HookRegistry.check_if_exists_or_initialize(module)
-    hook = ContextParallelModelHook(parallel_config)
-    registry.register_hook(hook, _CONTEXT_PARALLEL_MODEL_HOOK)
-
-
-class ContextParallelModelHook(ModelHook):
-    def __init__(self, parallel_config: ParallelConfig) -> None:
-        super().__init__()
-        self.parallel_config = parallel_config
-
-    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
-        with _parallel_context(self.parallel_config):
-            return self.fn_ref.original_forward(*args, **kwargs)
+    # HACK: we cannot use context managers or setattr or similar solutions in an overwritten forward
+    # diffusers hook method because Dynamo fails to trace it. Instead, we make use of module hooks
+    # available in pytorch to set the parallel context before/after the forward/backward pass.
+    # It is dirty, but fullgraph=True tracing works because of this and I haven't found a better solution yet.
+    # The previous/older implementation simply did this:
+    #     def new_forward(self, ...):
+    #         with _parallel_context(parallel_config):
+    #             return self.fn_ref.original_forward(*args, **kwargs)
+    # TODO: ask help from Pytorch team on how to improve this
+    @torch.compiler.disable
+    def forward_pre_hook(module, args):
+        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
+        module._diffusers_parallel_config_setter_context.__enter__()
+
+    @torch.compiler.disable
+    def forward_hook(module, args, output):
+        if module._diffusers_parallel_config_setter_context is not None:
+            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
+            module._diffusers_parallel_config_setter_context = None
+
+    @torch.compiler.disable
+    def backward_pre_hook(module, grad_output):
+        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
+        module._diffusers_parallel_config_setter_context.__enter__()
+
+    @torch.compiler.disable
+    def backward_hook(module, grad_output, grad_input):
+        if module._diffusers_parallel_config_setter_context is not None:
+            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
+            module._diffusers_parallel_config_setter_context = None
+
+    module.register_forward_pre_hook(forward_pre_hook)
+    module.register_forward_hook(forward_hook)
+    module.register_full_backward_pre_hook(backward_pre_hook)
+    module.register_full_backward_hook(backward_hook)
 
 
 class ContextParallelSplitHook(ModelHook):
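
For illustration only (not part of the diff): the same pre/post-hook trick can be reproduced on a plain nn.Module. In the sketch below, _example_context and attach_context_hooks are hypothetical stand-ins for _parallel_context(parallel_config) and the registration done in apply_context_parallel, and backend="eager" is used only so the snippet runs without a codegen toolchain.

import contextlib
import torch

@contextlib.contextmanager
def _example_context():
    # Hypothetical stand-in for _parallel_context(parallel_config): sets some
    # global state for the duration of the forward pass.
    print("enter parallel context")
    try:
        yield
    finally:
        print("exit parallel context")

def attach_context_hooks(module: torch.nn.Module) -> None:
    # The context-manager bookkeeping lives in hooks marked with
    # @torch.compiler.disable, so Dynamo does not have to trace it.
    @torch.compiler.disable
    def forward_pre_hook(module, args):
        module._ctx = _example_context()
        module._ctx.__enter__()

    @torch.compiler.disable
    def forward_hook(module, args, output):
        if getattr(module, "_ctx", None) is not None:
            module._ctx.__exit__(None, None, None)
            module._ctx = None

    module.register_forward_pre_hook(forward_pre_hook)
    module.register_forward_hook(forward_hook)

model = torch.nn.Linear(8, 8)
attach_context_hooks(model)
compiled = torch.compile(model, backend="eager")  # "eager" backend: no codegen needed
print(compiled(torch.randn(2, 8)).shape)  # hooks run eagerly around the traced forward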
@@ -234,13 +256,15 @@ class ContextParallelGatherHook(ModelHook):
 class EquipartitionSharder:
     @classmethod
     @torch.compiler.disable
     def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         assert tensor.size()[dim] % mesh.size() == 0
-        return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+        # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
+        # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+        return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
 
     @classmethod
     @torch.compiler.disable
     def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         tensor = tensor.contiguous()
         tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
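
For illustration only (not part of the diff): the rank-lookup swap can be exercised standalone. The sketch below assumes a single-process gloo group (real context-parallel runs launch multiple ranks with torchrun); shard here is a hypothetical free-function copy of the method above, not diffusers API.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Single-process group purely for illustration; real runs use torchrun with N ranks.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

def shard(tensor: torch.Tensor, dim: int, mesh) -> torch.Tensor:
    assert tensor.size()[dim] % mesh.size() == 0
    # torch.distributed.get_rank(group) is the rank lookup the commit switches to;
    # the comment above reports that DeviceMesh.get_rank() breaks fullgraph tracing.
    rank = torch.distributed.get_rank(mesh.get_group())
    return tensor.chunk(mesh.size(), dim=dim)[rank]

x = torch.randn(4, 8)
print(shard(x, dim=0, mesh=mesh).shape)  # torch.Size([4, 8]) since world_size == 1

dist.destroy_process_group()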

View File

@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional,
 import torch
 import torch.distributed._functional_collectives as funcol
+import torch.distributed.tensor
 
 from ..utils import (
     get_logger,
@@ -245,9 +246,6 @@ def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIV
 @contextlib.contextmanager
 def _parallel_context(parallel_config: "ParallelConfig"):
-    """
-    Context manager to set the parallel configuration for attention backends that support it.
-    """
     old_parallel_config = _AttentionBackendRegistry._parallel_config
     _AttentionBackendRegistry._parallel_config = parallel_config
@@ -789,6 +787,16 @@ class _sage_attention_af(torch.autograd.Function):
 # ===== Context parallel =====
 
+# Reference:
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
+# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
+def _wait_tensor(tensor):
+    if isinstance(tensor, funcol.AsyncCollectiveTensor):
+        tensor = tensor.wait()
+    return tensor
+
+
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
     def forward(
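
For illustration only (not part of the diff): the guard is meant to wrap any functional collective whose result may or may not be an AsyncCollectiveTensor. The sketch below uses all_reduce as a stand-in collective because it runs on a single-process gloo group; the Ulysses hunks below apply the same pattern to all_to_all_single.

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# Single-process group purely for illustration; real runs use torchrun with N ranks.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)

def _wait_tensor(tensor):
    # Same guard as in the diff: only resolve the async handle when one is returned.
    if isinstance(tensor, funcol.AsyncCollectiveTensor):
        tensor = tensor.wait()
    return tensor

x = torch.ones(2, 4)
# In eager mode funcol returns an AsyncCollectiveTensor; under fake-tensor tracing
# it may be a plain tensor, which is why the isinstance guard matters.
out = _wait_tensor(funcol.all_reduce(x, "sum", group=dist.group.WORLD))
print(out)  # identical to x with a single rank

dist.destroy_process_group()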
@@ -875,7 +883,9 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
             x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
             for x in (query, key, value)
         )
-        query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
+        query, key, value = (
+            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group)) for x in (query, key, value)
+        )
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
 
         out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
@@ -883,12 +893,12 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
             out, lse, *_ = out
 
         out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
-        out = funcol.all_to_all_single(out, None, None, group=group).wait()
+        out = _wait_tensor(funcol.all_to_all_single(out, None, None, group=group))
         out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
 
         if return_lse:
             lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
-            lse = funcol.all_to_all_single(lse, None, None, group=group).wait()
+            lse = _wait_tensor(funcol.all_to_all_single(lse, None, None, group=group))
             lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
         else:
             lse = None