mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
Aryan
2025-07-30 07:22:08 +02:00
parent 1ffc03e0ad
commit fa5d017e76
5 changed files with 164 additions and 99 deletions

View File

@@ -23,9 +23,8 @@ from ..models._modeling_parallel import (
ContextParallelInput,
ContextParallelModelPlan,
ContextParallelOutput,
ParallelConfig,
_InternalParallelConfig,
)
from ..models.attention_dispatch import _parallel_context
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook
@@ -33,7 +32,6 @@ from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_CONTEXT_PARALLEL_MODEL_HOOK = "context_parallel_model_hook"
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
@@ -76,7 +74,7 @@ class ModuleForwardMetadata:
def apply_context_parallel(
module: torch.nn.Module,
parallel_config: ParallelConfig,
parallel_config: _InternalParallelConfig,
plan: Dict[str, ContextParallelModelPlan],
) -> None:
"""Apply context parallel on a model."""
@@ -105,45 +103,26 @@ def apply_context_parallel(
registry = HookRegistry.check_if_exists_or_initialize(m)
registry.register_hook(hook, hook_name)
# HACK: we cannot use context managers or setattr or similar solutions in an overwritten forward
# diffusers hook method because Dynamo fails to trace it. Instead, we make use of module hooks
# available in pytorch to set the parallel context before/after the forward/backward pass.
# It is dirty, but fullgraph=True tracing works because of this and I haven't found a better solution yet.
# The previous/older implementation simply did this:
# def new_forward(self, ...):
# with _parallel_context(parallel_config):
# return self.fn_ref.original_forward(*args, **kwargs)
# TODO: ask help from Pytorch team on how to improve this
@torch.compiler.disable
def forward_pre_hook(module, args):
module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
module._diffusers_parallel_config_setter_context.__enter__()
@torch.compiler.disable
def forward_hook(module, args, output):
if module._diffusers_parallel_config_setter_context is not None:
module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
module._diffusers_parallel_config_setter_context = None
def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
for module_id, cp_model_plan in plan.items():
submodule = _get_submodule_by_name(module, module_id)
if not isinstance(submodule, list):
submodule = [submodule]
@torch.compiler.disable
def backward_pre_hook(module, grad_output):
module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
module._diffusers_parallel_config_setter_context.__enter__()
@torch.compiler.disable
def backward_hook(module, grad_output, grad_input):
if module._diffusers_parallel_config_setter_context is not None:
module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
module._diffusers_parallel_config_setter_context = None
module.register_forward_pre_hook(forward_pre_hook)
module.register_forward_hook(forward_hook)
module.register_full_backward_pre_hook(backward_pre_hook)
module.register_full_backward_hook(backward_hook)
for m in submodule:
registry = HookRegistry.check_if_exists_or_initialize(m)
if isinstance(cp_model_plan, dict):
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
else:
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
registry.remove_hook(hook_name)
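The HACK comment above describes replacing a context-manager-wrapped forward with PyTorch module hooks so that torch.compile(fullgraph=True) can still trace the module. A minimal standalone sketch of that hook pattern, independent of the diffusers hook registry (_dummy_context is an illustrative stand-in for _parallel_context):
import contextlib
import torch
@contextlib.contextmanager
def _dummy_context(tag):
    # Stand-in for _parallel_context: set some global state around the forward call.
    print(f"enter {tag}")
    try:
        yield
    finally:
        print(f"exit {tag}")
module = torch.nn.Linear(4, 4)
@torch.compiler.disable
def forward_pre_hook(module, args):
    module._ctx = _dummy_context("forward")
    module._ctx.__enter__()
@torch.compiler.disable
def forward_hook(module, args, output):
    module._ctx.__exit__(None, None, None)
    module._ctx = None
module.register_forward_pre_hook(forward_pre_hook)
module.register_forward_hook(forward_hook)
module(torch.randn(2, 4))  # prints "enter forward" / "exit forward" around the forward pass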
class ContextParallelSplitHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: _InternalParallelConfig) -> None:
super().__init__()
self.metadata = metadata
self.parallel_config = parallel_config
@@ -228,7 +207,7 @@ class ContextParallelSplitHook(ModelHook):
class ContextParallelGatherHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: _InternalParallelConfig) -> None:
super().__init__()
self.metadata = metadata
self.parallel_config = parallel_config
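ContextParallelSplitHook and ContextParallelGatherHook shard module inputs across ranks before the wrapped forward and re-assemble the outputs afterwards. A single-process illustration of the splitting step only (not the library code; the dimension and shapes are arbitrary):
import torch
def split_for_rank(x, dim, rank, world_size):
    # Each rank keeps only its slice of the sharded dimension (typically the sequence dim).
    return x.chunk(world_size, dim=dim)[rank]
hidden_states = torch.randn(2, 8, 16)  # [batch, seq, dim]
shard = split_for_rank(hidden_states, dim=1, rank=0, world_size=4)
assert shard.shape == (2, 2, 16)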

View File

@@ -25,6 +25,7 @@ from ..utils import (
_import_structure = {}
if is_torch_available():
_import_structure["_modeling_parallel"] = ["ParallelConfig"]
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
_import_structure["auto_model"] = ["AutoModel"]
@@ -112,6 +113,7 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from ._modeling_parallel import ParallelConfig
from .adapter import MultiAdapter, T2IAdapter
from .attention_dispatch import AttentionBackendName, attention_backend
from .auto_model import AutoModel

View File

@@ -35,6 +35,18 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class ParallelConfig:
ring_degree: Optional[int] = None
ulysses_degree: Optional[int] = None
def __post_init__(self):
if self.ring_degree is None:
self.ring_degree = 1
if self.ulysses_degree is None:
self.ulysses_degree = 1
@dataclass
class _InternalParallelConfig:
rank: int
world_size: int
ring_degree: int
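A quick check of the defaulting behavior that __post_init__ gives the public ParallelConfig: degrees left unset fall back to 1, i.e. no parallelism along that axis. The import assumes the export added to the models __init__ in the previous hunk of this commit:
from diffusers.models import ParallelConfig  # export added earlier in this commit
config = ParallelConfig(ring_degree=2)
assert config.ring_degree == 2
assert config.ulysses_degree == 1  # defaulted by __post_init__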

View File

@@ -38,8 +38,9 @@ from ..utils import (
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
if TYPE_CHECKING:
from ._modeling_parallel import ParallelConfig
from ._modeling_parallel import _InternalParallelConfig
_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_SAGE_VERSION = "2.1.1"
@@ -58,12 +59,12 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
if _CAN_USE_FLASH_ATTN:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
else:
flash_attn_func = None
flash_attn_varlen_func = None
_flash_attn_forward = None
_flash_attn_backward = None
_wrapped_flash_attn_backward = None
_wrapped_flash_attn_forward = None
if _CAN_USE_FLASH_ATTN_3:
@@ -192,7 +193,7 @@ class _AttentionBackendRegistry:
_supports_context_parallel = {}
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
_checks_enabled = DIFFUSERS_ATTN_CHECKS
_parallel_config: Optional["ParallelConfig"] = None
_parallel_config: Optional["_InternalParallelConfig"] = None
@classmethod
def register(
@@ -252,17 +253,6 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
_AttentionBackendRegistry._active_backend = old_backend
@contextlib.contextmanager
def _parallel_context(parallel_config: "ParallelConfig"):
old_parallel_config = _AttentionBackendRegistry._parallel_config
_AttentionBackendRegistry._parallel_config = parallel_config
try:
yield
finally:
_AttentionBackendRegistry._parallel_config = old_parallel_config
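_parallel_context above is a plain save-and-restore of a class-level registry attribute. The same pattern in isolation (Registry and use_config are illustrative names, not library API):
import contextlib
class Registry:
    _active = None
@contextlib.contextmanager
def use_config(config):
    # Stash the previous value, install the new one, and always restore on exit,
    # even if the wrapped block raises.
    previous = Registry._active
    Registry._active = config
    try:
        yield
    finally:
        Registry._active = previous
with use_config({"ring_degree": 2}):
    assert Registry._active == {"ring_degree": 2}
assert Registry._active is None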
def dispatch_attention_fn(
query: torch.Tensor,
key: torch.Tensor,
@@ -637,14 +627,15 @@ class _cudnn_attention_af(torch.autograd.Function):
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
ctx.attn_mask = attn_mask
tensors_to_save = ()
# Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
# if the input tensors are not contiguous.
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
query = query.transpose(1, 2).contiguous()
tensors_to_save += (query, key, value)
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
torch.ops.aten._scaled_dot_product_cudnn_attention(
query=query,
@@ -659,9 +650,14 @@ class _cudnn_attention_af(torch.autograd.Function):
)
)
tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
ctx.save_for_backward(*tensors_to_save)
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
ctx.attn_mask = attn_mask
ctx.max_q = max_q
ctx.max_k = max_k
ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
out = out.transpose(1, 2).contiguous()
if lse is not None:
@@ -674,8 +670,12 @@ class _cudnn_attention_af(torch.autograd.Function):
grad_out: torch.Tensor,
*args,
):
query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
saved_tensors = ctx.saved_tensors if hasattr(ctx, "saved_tensors") else ctx.to_save
query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = saved_tensors
grad_out = grad_out.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
# Cannot pass first 5 arguments as kwargs because: https://github.com/pytorch/pytorch/blob/d26ca5de058dbcf56ac52bb43e84dd98df2ace97/torch/_dynamo/variables/torch.py#L1341
grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
@@ -726,15 +726,33 @@ class _flash_attention_2_af(torch.autograd.Function):
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if scale is None:
scale = query.shape[-1] ** (-0.5)
# flash-attn only returns LSE if dropout_p > 0. So, we need to workaround.
parallel_config = _AttentionBackendRegistry._parallel_config
if query.requires_grad or (parallel_config is not None and parallel_config.world_size > 1):
if grad_enabled or (parallel_config is not None and parallel_config.world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1)
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
@@ -743,22 +761,6 @@ class _flash_attention_2_af(torch.autograd.Function):
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
out, lse, S_dmask, rng_state = _flash_attn_forward(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
ctx.save_for_backward(query, key, value, out, lse, rng_state)
return (out, lse) if return_lse else out
@staticmethod
@@ -767,10 +769,11 @@ class _flash_attention_2_af(torch.autograd.Function):
grad_out: torch.Tensor,
*args,
):
query, key, value, out, lse, rng_state = ctx.saved_tensors
saved_tensors = ctx.saved_tensors if hasattr(ctx, "saved_tensors") else ctx.to_save
query, key, value, out, lse, rng_state = saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
lse_d = _flash_attn_backward( # noqa: F841
lse_d = _wrapped_flash_attn_backward( # noqa: F841
grad_out,
query,
key,
@@ -876,10 +879,18 @@ class TemplatedRingAttention(torch.autograd.Function):
ring_mesh = parallel_config._ring_mesh
rank = parallel_config._ring_local_rank
world_size = parallel_config.ring_degree
next_rank = (rank + 1) % world_size
prev_out = prev_lse = None
ctx.save_for_backward(query, key, value)
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
ctx.enable_gqa = enable_gqa
ctx.return_lse = return_lse
ctx.op = op
ctx.op_ctx = torch.autograd.function.FunctionCtx()
kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
kv_buffer = kv_buffer.chunk(world_size)
@@ -887,11 +898,14 @@ class TemplatedRingAttention(torch.autograd.Function):
for i in range(world_size):
if i > 0:
kv = kv_buffer[next_rank]
key = kv[: key.numel()].reshape_as(key)
value = kv[key.numel() :].reshape_as(value)
key_numel = key.numel()
key = kv[:key_numel].reshape_as(key)
value = kv[key_numel:].reshape_as(value)
next_rank = (next_rank + 1) % world_size
out, lse = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True)
out, lse = op.forward(
ctx.op_ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True
)
if parallel_config.convert_to_fp32:
out = out.to(torch.float32)
@@ -906,6 +920,7 @@ class TemplatedRingAttention(torch.autograd.Function):
out = out.to(query.dtype)
lse = lse.squeeze(-1)
return (out, lse) if return_lse else out
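The hunk above converts partial results to fp32 before accumulation; the accumulation itself is not shown in the diff context, but ring attention relies on the standard log-sum-exp merge of partials computed over disjoint key/value blocks. A standalone sketch of that merge (not a copy of the library code), assuming lse is broadcastable against out, e.g. via a trailing singleton dim:
import torch
def merge_attention_partials(prev_out, prev_lse, out, lse):
    # Online-softmax combine: each partial is a softmax over its own KV block,
    # so the block results are re-weighted by their share of the global normalizer.
    if prev_out is None:
        return out, lse
    new_lse = torch.logaddexp(prev_lse, lse)
    new_out = prev_out * torch.exp(prev_lse - new_lse) + out * torch.exp(lse - new_lse)
    return new_out, new_lse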
@staticmethod
@@ -914,7 +929,55 @@ class TemplatedRingAttention(torch.autograd.Function):
grad_out: torch.Tensor,
*args,
):
raise NotImplementedError("Backward pass is not implemented for TemplatedRingAttention.")
parallel_config = _AttentionBackendRegistry._parallel_config
ring_mesh = parallel_config._ring_mesh
rank = parallel_config._ring_local_rank
world_size = parallel_config.ring_degree
next_rank = (rank + 1) % world_size
next_ranks = list(range(1, world_size)) + [0]
query, key, value = ctx.saved_tensors
accum_dtype = torch.float32 if parallel_config.convert_to_fp32 else query.dtype
grad_query = torch.zeros_like(query, dtype=accum_dtype)
grad_key = torch.zeros_like(key, dtype=accum_dtype)
grad_value = torch.zeros_like(value, dtype=accum_dtype)
next_grad_kv = None
kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
kv_buffer = kv_buffer.chunk(world_size)
for i in range(world_size):
if i > 0:
kv = kv_buffer[next_rank]
key_numel = key.numel()
key = kv[:key_numel].reshape_as(key)
value = kv[key_numel:].reshape_as(value)
next_rank = (next_rank + 1) % world_size
saved_tensors = list(ctx.op_ctx.to_save)
saved_tensors[1] = key
saved_tensors[2] = value
ctx.op_ctx.to_save = tuple(saved_tensors)
grad_query_op, grad_key_op, grad_value_op, *_ = ctx.op.backward(ctx.op_ctx, grad_out)
if i > 0:
grad_kv_buffer = _wait_tensor(next_grad_kv)
grad_key_numel = grad_key.numel()
grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key)
grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value)
grad_query += grad_query_op
grad_key += grad_key_op
grad_value += grad_value_op
if i < world_size - 1:
grad_kv_buffer = torch.cat([grad_key.flatten(), grad_value.flatten()]).contiguous()
next_grad_kv = funcol.permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh.get_group())
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None
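TemplatedRingAttention drives the per-backend autograd op manually through a bare FunctionCtx (ctx.op_ctx), which is also why the backend backward methods earlier in this file fall back from ctx.saved_tensors to ctx.to_save. A self-contained sketch of that manual-invocation pattern (Square is purely illustrative):
import torch
class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x
    @staticmethod
    def backward(ctx, grad_out):
        # When autograd calls backward, saved tensors live on ctx.saved_tensors; when the op
        # is driven manually with a bare FunctionCtx, save_for_backward only fills ctx.to_save.
        (x,) = ctx.saved_tensors if hasattr(ctx, "saved_tensors") else ctx.to_save
        return 2 * x * grad_out
op_ctx = torch.autograd.function.FunctionCtx()
x = torch.randn(3)
out = Square.forward(op_ctx, x)                       # forward without building an autograd node
grad_x = Square.backward(op_ctx, torch.ones_like(x))  # backward reuses the same manually-held ctx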
class TemplatedUlyssesAttention(torch.autograd.Function):
@@ -937,6 +1000,8 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
world_size = parallel_config.ulysses_degree
group = ulysses_mesh.get_group()
ctx.op_ctx = torch.autograd.function.FunctionCtx()
B, S_LOCAL, H, D = query.shape
H_LOCAL = H // world_size
query, key, value = (
@@ -948,7 +1013,7 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
)
query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
out = op.forward(ctx.op_ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
if return_lse:
out, lse, *_ = out
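Not part of the diff: a single-process sketch of the data movement Ulysses attention performs on each projection. Each rank starts with all heads for its local sequence shard and, after the all-to-all, holds a head shard for the full sequence (names and shapes are illustrative, the real code uses functional collectives):
import torch
def ulysses_all_to_all_sim(shards, world_size):
    # shards: per-rank tensors of shape [B, S_LOCAL, H, D] (sequence-sharded, all heads).
    # Returns per-rank tensors of shape [B, S_LOCAL * world_size, H // world_size, D].
    outputs = []
    for rank in range(world_size):
        # Rank `rank` receives its head slice from every rank and concatenates along the sequence dim.
        pieces = [shard.chunk(world_size, dim=2)[rank] for shard in shards]
        outputs.append(torch.cat(pieces, dim=1))
    return outputs
B, S_LOCAL, H, D, world_size = 1, 4, 8, 16, 2
shards = [torch.randn(B, S_LOCAL, H, D) for _ in range(world_size)]
gathered = ulysses_all_to_all_sim(shards, world_size)
assert gathered[0].shape == (B, S_LOCAL * world_size, H // world_size, D)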
@@ -1291,7 +1356,7 @@ def _native_cudnn_attention(
lse = None
if parallel_config is None and not return_lse:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention(
query=query,

View File

@@ -63,6 +63,7 @@ from ..utils.hub_utils import (
populate_model_card,
)
from ..utils.torch_utils import empty_device_cache
from ._modeling_parallel import ContextParallelModelPlan, ParallelConfig, _InternalParallelConfig
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
@@ -1501,19 +1502,20 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
)
def parallelize(self, *, ring_degree: int = 1, ulysses_degree: int = 1, cp_plan=None):
from ..hooks.context_parallel import ParallelConfig, apply_context_parallel
@contextmanager
def parallelize(self, *, config: ParallelConfig, cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None):
from ..hooks.context_parallel import apply_context_parallel, remove_context_parallel
from .attention_dispatch import _AttentionBackendRegistry
# TODO(aryan): add cp_plan type hint
logger.warning(
"`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
)
if not torch.distributed.is_initialized():
raise RuntimeError("torch.distributed must be initialized before calling `parallelize`.")
if ring_degree < 1 or ulysses_degree < 1:
if config.ring_degree < 1 or config.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if ring_degree > 1 and ulysses_degree > 1:
if config.ring_degree > 1 and config.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
@@ -1521,9 +1523,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
if ring_degree * ulysses_degree > world_size:
if config.ring_degree * config.ulysses_degree > world_size:
raise ValueError(
f"The product of `ring_degree` ({ring_degree}) and `ulysses_degree` ({ulysses_degree}) must not exceed the world size ({world_size})."
f"The product of `ring_degree` ({config.ring_degree}) and `ulysses_degree` ({config.ulysses_degree}) must not exceed the world size ({world_size})."
)
device_type = torch._C._get_accelerator().type
@@ -1532,14 +1534,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=(ring_degree, ulysses_degree),
mesh_shape=(config.ring_degree, config.ulysses_degree),
mesh_dim_names=("ring", "ulysses"),
)
parallel_config = ParallelConfig(
parallel_config = _InternalParallelConfig(
rank=rank,
world_size=world_size,
ring_degree=ring_degree,
ulysses_degree=ulysses_degree,
ring_degree=config.ring_degree,
ulysses_degree=config.ulysses_degree,
device=device,
cp_mesh=cp_mesh,
)
@@ -1550,6 +1552,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
apply_context_parallel(self, parallel_config, cp_plan)
_AttentionBackendRegistry._parallel_config = parallel_config
yield
remove_context_parallel(self, cp_plan)
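A hedged usage sketch of the new context-manager form of parallelize introduced here. The checkpoint id is a placeholder, and ParallelConfig is assumed to be importable from diffusers.models as added earlier in this commit:
import torch.distributed as dist
from diffusers import AutoModel
from diffusers.models import ParallelConfig
dist.init_process_group("nccl")
model = AutoModel.from_pretrained("some/checkpoint", subfolder="transformer")  # placeholder id
# Context parallelism is applied on enter and torn down again on exit via remove_context_parallel.
with model.parallelize(config=ParallelConfig(ring_degree=dist.get_world_size())):
    ...  # run the sharded forward passes here
dist.destroy_process_group()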
@classmethod
def _load_pretrained_model(