mirror of https://github.com/huggingface/diffusers.git
make torch compile compatible
@@ -105,19 +105,41 @@ def apply_context_parallel(
         registry = HookRegistry.check_if_exists_or_initialize(m)
         registry.register_hook(hook, hook_name)
 
-    registry = HookRegistry.check_if_exists_or_initialize(module)
-    hook = ContextParallelModelHook(parallel_config)
-    registry.register_hook(hook, _CONTEXT_PARALLEL_MODEL_HOOK)
-
-
-class ContextParallelModelHook(ModelHook):
-    def __init__(self, parallel_config: ParallelConfig) -> None:
-        super().__init__()
-        self.parallel_config = parallel_config
-
-    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
-        with _parallel_context(self.parallel_config):
-            return self.fn_ref.original_forward(*args, **kwargs)
-
-
+    # HACK: we cannot use context managers or setattr or similar solutions in an overwritten forward
+    # diffusers hook method because Dynamo fails to trace it. Instead, we make use of module hooks
+    # available in pytorch to set the parallel context before/after the forward/backward pass.
+    # It is dirty, but fullgraph=True tracing works because of this and I haven't found a better solution yet.
+    # The previous/older implementation simply did this:
+    #     def new_forward(self, ...):
+    #         with _parallel_context(parallel_config):
+    #             return self.fn_ref.original_forward(*args, **kwargs)
+    # TODO: ask help from Pytorch team on how to improve this
+    @torch.compiler.disable
+    def forward_pre_hook(module, args):
+        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
+        module._diffusers_parallel_config_setter_context.__enter__()
+
+    @torch.compiler.disable
+    def forward_hook(module, args, output):
+        if module._diffusers_parallel_config_setter_context is not None:
+            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
+            module._diffusers_parallel_config_setter_context = None
+
+    @torch.compiler.disable
+    def backward_pre_hook(module, grad_output):
+        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
+        module._diffusers_parallel_config_setter_context.__enter__()
+
+    @torch.compiler.disable
+    def backward_hook(module, grad_output, grad_input):
+        if module._diffusers_parallel_config_setter_context is not None:
+            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
+            module._diffusers_parallel_config_setter_context = None
+
+    module.register_forward_pre_hook(forward_pre_hook)
+    module.register_forward_hook(forward_hook)
+    module.register_full_backward_pre_hook(backward_pre_hook)
+    module.register_full_backward_hook(backward_hook)
+
+
 class ContextParallelSplitHook(ModelHook):
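The HACK comment in the hunk above replaces a context manager inside the hook's forward with PyTorch module hooks so that Dynamo does not have to trace the context switch. Below is a minimal standalone sketch of that pattern; `demo_context`, `attach_context_hooks`, and `_demo_ctx` are hypothetical names for illustration, and the snippet is independent of diffusers internals.

import contextlib

import torch


@contextlib.contextmanager
def demo_context():
    # Stand-in for a real context manager such as _parallel_context(parallel_config).
    print("enter context")
    try:
        yield
    finally:
        print("exit context")


def attach_context_hooks(module: torch.nn.Module) -> None:
    # Enter the context in a forward pre-hook and exit it in a forward hook,
    # instead of wrapping the forward in a `with` block.
    @torch.compiler.disable
    def forward_pre_hook(module, args):
        module._demo_ctx = demo_context()
        module._demo_ctx.__enter__()

    @torch.compiler.disable
    def forward_hook(module, args, output):
        if module._demo_ctx is not None:
            module._demo_ctx.__exit__(None, None, None)
            module._demo_ctx = None

    module.register_forward_pre_hook(forward_pre_hook)
    module.register_forward_hook(forward_hook)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 4)
    attach_context_hooks(model)
    # backend="eager" avoids needing a C++ toolchain while still exercising Dynamo.
    compiled = torch.compile(model, backend="eager")
    _ = compiled(torch.randn(2, 4))

The `@torch.compiler.disable` decorator keeps Dynamo from tracing the hook bodies, which is the same trick the diff relies on.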
@@ -234,13 +256,15 @@ class ContextParallelGatherHook(ModelHook):
 
 class EquipartitionSharder:
     @classmethod
-    @torch.compiler.disable
     def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         assert tensor.size()[dim] % mesh.size() == 0
-        return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+
+        # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
+        # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+
+        return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
 
     @classmethod
-    @torch.compiler.disable
     def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         tensor = tensor.contiguous()
         tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
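For intuition, here is a single-process analogue of what `EquipartitionSharder` computes (a sketch, not diffusers code): `world_size` and `rank` stand in for `mesh.size()` and the rank lookup, and local concatenation plays the role of `all_gather_tensor` in `unshard`.

import torch


def shard(tensor: torch.Tensor, dim: int, world_size: int, rank: int) -> torch.Tensor:
    # Each rank keeps one equal-sized chunk along `dim`.
    assert tensor.size(dim) % world_size == 0
    return tensor.chunk(world_size, dim=dim)[rank]


def unshard(shards: list, dim: int) -> torch.Tensor:
    # Across real ranks this is an all-gather; locally it is a concatenation.
    return torch.cat(shards, dim=dim)


x = torch.arange(24).reshape(4, 6)
shards = [shard(x, dim=1, world_size=2, rank=r) for r in range(2)]
assert shards[0].shape == (4, 3)
assert torch.equal(unshard(shards, dim=1), x)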
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional,
 
 import torch
 import torch.distributed._functional_collectives as funcol
+import torch.distributed.tensor
 
 from ..utils import (
     get_logger,
@@ -245,9 +246,6 @@ def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIV
 
 @contextlib.contextmanager
 def _parallel_context(parallel_config: "ParallelConfig"):
-    """
-    Context manager to set the parallel configuration for attention backends that support it.
-    """
     old_parallel_config = _AttentionBackendRegistry._parallel_config
     _AttentionBackendRegistry._parallel_config = parallel_config
 
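The hunk above only shows the first half of `_parallel_context` (saving the previous config and installing the new one). The sketch below shows the usual save/set/restore shape of such a context manager; the `Registry` stand-in and the `try`/`finally` restore step are assumptions for illustration, not lines from this diff.

import contextlib


class Registry:
    # Stand-in for _AttentionBackendRegistry: one class-level slot for the active config.
    _parallel_config = None


@contextlib.contextmanager
def parallel_context(parallel_config):
    old_parallel_config = Registry._parallel_config
    Registry._parallel_config = parallel_config
    try:
        yield
    finally:
        # Restore whatever was active before, even if the body raised.
        Registry._parallel_config = old_parallel_config


with parallel_context({"ring_degree": 2}):
    assert Registry._parallel_config == {"ring_degree": 2}
assert Registry._parallel_config is None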
@@ -789,6 +787,16 @@ class _sage_attention_af(torch.autograd.Function):
 # ===== Context parallel =====
 
 
+# Reference:
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
+# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
+def _wait_tensor(tensor):
+    if isinstance(tensor, funcol.AsyncCollectiveTensor):
+        tensor = tensor.wait()
+    return tensor
+
+
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
     def forward(
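`_wait_tensor` simply unwraps the `AsyncCollectiveTensor` that functional collectives return in eager mode. The snippet below is a hedged, single-process sketch of how such a helper is used; the single-rank gloo group built on a `HashStore` is only a convenience for running it locally and is not part of the diff.

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def _wait_tensor(tensor):
    # Mirrors the helper added above: wait only if the collective returned an async wrapper.
    if isinstance(tensor, funcol.AsyncCollectiveTensor):
        tensor = tensor.wait()
    return tensor


if __name__ == "__main__":
    # Single-rank process group, so the collective is a no-op but still goes through funcol.
    dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1)
    x = torch.randn(4)
    gathered = _wait_tensor(funcol.all_gather_tensor(x, gather_dim=0, group=dist.group.WORLD))
    assert torch.allclose(gathered, x)  # world_size == 1, so the gather returns the same values
    dist.destroy_process_group()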
@@ -875,7 +883,9 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
             x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
             for x in (query, key, value)
         )
-        query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
+        query, key, value = (
+            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group)) for x in (query, key, value)
+        )
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
 
         out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
@@ -883,12 +893,12 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
             out, lse, *_ = out
 
         out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
-        out = funcol.all_to_all_single(out, None, None, group=group).wait()
+        out = _wait_tensor(funcol.all_to_all_single(out, None, None, group=group))
         out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
 
         if return_lse:
             lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
-            lse = funcol.all_to_all_single(lse, None, None, group=group).wait()
+            lse = _wait_tensor(funcol.all_to_all_single(lse, None, None, group=group))
             lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
         else:
             lse = None
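As a sanity check on the Ulysses layout transform used above, the following single-process sketch (arbitrary sizes, with the all-to-all omitted since it needs a process group) verifies that the first reshape/permute puts `world_size` on dim 0, which is the dimension `all_to_all_single` exchanges, and that the final flatten/permute yields the full sequence with local heads.

import torch

# Arbitrary example sizes; H must be divisible by world_size.
B, S_LOCAL, H, D, world_size = 2, 8, 12, 64, 4
H_LOCAL = H // world_size

query = torch.randn(B, S_LOCAL, H, D)

# Pre-all-to-all layout: heads are grouped by destination rank along dim 0.
x = query.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
assert x.shape == (world_size, S_LOCAL, B, H_LOCAL, D)

# Post-all-to-all layout (the exchange itself is skipped here): full sequence, local heads.
y = x.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
assert y.shape == (B, world_size * S_LOCAL, H_LOCAL, D)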