diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
index c3697f967d..24644c3ebe 100644
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -23,9 +23,8 @@ from ..models._modeling_parallel import (
     ContextParallelInput,
     ContextParallelModelPlan,
     ContextParallelOutput,
-    ParallelConfig,
+    _InternalParallelConfig,
 )
-from ..models.attention_dispatch import _parallel_context
 from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
 from .hooks import HookRegistry, ModelHook
@@ -33,7 +32,6 @@ from .hooks import HookRegistry, ModelHook
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
-_CONTEXT_PARALLEL_MODEL_HOOK = "context_parallel_model_hook"
 _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
 _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
 
@@ -76,7 +74,7 @@ class ModuleForwardMetadata:
 
 def apply_context_parallel(
     module: torch.nn.Module,
-    parallel_config: ParallelConfig,
+    parallel_config: _InternalParallelConfig,
     plan: Dict[str, ContextParallelModelPlan],
 ) -> None:
     """Apply context parallel on a model."""
@@ -105,45 +103,26 @@ def apply_context_parallel(
             registry = HookRegistry.check_if_exists_or_initialize(m)
             registry.register_hook(hook, hook_name)
 
-    # HACK: we cannot use context managers or setattr or similar solutions in an overwritten forward
-    # diffusers hook method because Dynamo fails to trace it. Instead, we make use of module hooks
-    # available in pytorch to set the parallel context before/after the forward/backward pass.
-    # It is dirty, but fullgraph=True tracing works because of this and I haven't found a better solution yet.
-    # The previous/older implementation simply did this:
-    #     def new_forward(self, ...):
-    #         with _parallel_context(parallel_config):
-    #             return self.fn_ref.original_forward(*args, **kwargs)
-    # TODO: ask help from Pytorch team on how to improve this
-    @torch.compiler.disable
-    def forward_pre_hook(module, args):
-        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
-        module._diffusers_parallel_config_setter_context.__enter__()
 
-    @torch.compiler.disable
-    def forward_hook(module, args, output):
-        if module._diffusers_parallel_config_setter_context is not None:
-            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
-        module._diffusers_parallel_config_setter_context = None
 
+def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
+    for module_id, cp_model_plan in plan.items():
+        submodule = _get_submodule_by_name(module, module_id)
+        if not isinstance(submodule, list):
+            submodule = [submodule]
 
-    @torch.compiler.disable
-    def backward_pre_hook(module, grad_output):
-        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
-        module._diffusers_parallel_config_setter_context.__enter__()
-
-    @torch.compiler.disable
-    def backward_hook(module, grad_output, grad_input):
-        if module._diffusers_parallel_config_setter_context is not None:
-            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
-        module._diffusers_parallel_config_setter_context = None
-
-    module.register_forward_pre_hook(forward_pre_hook)
-    module.register_forward_hook(forward_hook)
-    module.register_full_backward_pre_hook(backward_pre_hook)
-    module.register_full_backward_hook(backward_hook)
+        for m in submodule:
+            registry = HookRegistry.check_if_exists_or_initialize(m)
+            if isinstance(cp_model_plan, dict):
+                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
+            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
+                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
+            else:
+                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
+            registry.remove_hook(hook_name)
 
 
 class ContextParallelSplitHook(ModelHook):
-    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
+    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: _InternalParallelConfig) -> None:
         super().__init__()
         self.metadata = metadata
         self.parallel_config = parallel_config
@@ -228,7 +207,7 @@ class ContextParallelGatherHook(ModelHook):
-    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
+    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: _InternalParallelConfig) -> None:
         super().__init__()
         self.metadata = metadata
         self.parallel_config = parallel_config
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index cd1df3667a..6cc4965465 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -25,6 +25,7 @@ from ..utils import (
 _import_structure = {}
 
 if is_torch_available():
+    _import_structure["_modeling_parallel"] = ["ParallelConfig"]
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
@@ -112,6 +113,7 @@ if is_flax_available():
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
+        from ._modeling_parallel import ParallelConfig
        from .adapter import MultiAdapter, T2IAdapter
         from .attention_dispatch import AttentionBackendName, attention_backend
         from .auto_model import AutoModel
diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py
index 7df474f04b..2a4e62a6e5 100644
--- a/src/diffusers/models/_modeling_parallel.py
+++ b/src/diffusers/models/_modeling_parallel.py
@@ -35,6 +35,18 @@ logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 @dataclass
 class ParallelConfig:
+    ring_degree: Optional[int] = None
+    ulysses_degree: Optional[int] = None
+
+    def __post_init__(self):
+        if self.ring_degree is None:
+            self.ring_degree = 1
+        if self.ulysses_degree is None:
+            self.ulysses_degree = 1
+
+
+@dataclass
+class _InternalParallelConfig:
     rank: int
     world_size: int
     ring_degree: int
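Note on the hunk above: the public `ParallelConfig` now carries only the user-facing `ring_degree` / `ulysses_degree` knobs and normalizes `None` to 1 in `__post_init__`, while rank, world size, device, and mesh bookkeeping move to the internal `_InternalParallelConfig`. A minimal sketch of the resulting behaviour (illustrative only; it assumes the re-export added to `diffusers/models/__init__.py` above):

    from diffusers.models import ParallelConfig

    cfg = ParallelConfig()                  # __post_init__ turns None into 1
    assert (cfg.ring_degree, cfg.ulysses_degree) == (1, 1)

    cfg = ParallelConfig(ulysses_degree=2)  # ring_degree still defaults to 1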
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index eca2b4824e..f8b8256e51 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -38,8 +38,9 @@ from ..utils import (
 )
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
 
+
 if TYPE_CHECKING:
-    from ._modeling_parallel import ParallelConfig
+    from ._modeling_parallel import _InternalParallelConfig
 
 _REQUIRED_FLASH_VERSION = "2.6.3"
 _REQUIRED_SAGE_VERSION = "2.1.1"
@@ -58,12 +59,12 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
 
 if _CAN_USE_FLASH_ATTN:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
+    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
 else:
     flash_attn_func = None
     flash_attn_varlen_func = None
-    _flash_attn_forward = None
-    _flash_attn_backward = None
+    _wrapped_flash_attn_backward = None
+    _wrapped_flash_attn_forward = None
 
 
 if _CAN_USE_FLASH_ATTN_3:
@@ -192,7 +193,7 @@ class _AttentionBackendRegistry:
     _supports_context_parallel = {}
     _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
     _checks_enabled = DIFFUSERS_ATTN_CHECKS
-    _parallel_config: Optional["ParallelConfig"] = None
+    _parallel_config: Optional["_InternalParallelConfig"] = None
 
     @classmethod
     def register(
@@ -252,17 +253,6 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
         _AttentionBackendRegistry._active_backend = old_backend
 
 
-@contextlib.contextmanager
-def _parallel_context(parallel_config: "ParallelConfig"):
-    old_parallel_config = _AttentionBackendRegistry._parallel_config
-    _AttentionBackendRegistry._parallel_config = parallel_config
-
-    try:
-        yield
-    finally:
-        _AttentionBackendRegistry._parallel_config = old_parallel_config
-
-
 def dispatch_attention_fn(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -637,14 +627,15 @@ class _cudnn_attention_af(torch.autograd.Function):
         if enable_gqa:
             raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")
 
-        ctx.dropout_p = dropout_p
-        ctx.is_causal = is_causal
-        ctx.scale = scale
-        ctx.attn_mask = attn_mask
+        tensors_to_save = ()
 
         # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
         # if the input tensors are not contiguous.
-        query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+        query = query.transpose(1, 2).contiguous()
+        tensors_to_save += (query, key, value)
+        key = key.transpose(1, 2).contiguous()
+        value = value.transpose(1, 2).contiguous()
+
         out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
             torch.ops.aten._scaled_dot_product_cudnn_attention(
                 query=query,
@@ -659,9 +650,14 @@
             )
         )
 
+        tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
+        ctx.save_for_backward(*tensors_to_save)
+        ctx.dropout_p = dropout_p
+        ctx.is_causal = is_causal
+        ctx.scale = scale
+        ctx.attn_mask = attn_mask
         ctx.max_q = max_q
         ctx.max_k = max_k
-        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
 
         out = out.transpose(1, 2).contiguous()
         if lse is not None:
@@ -674,8 +670,12 @@
         grad_out: torch.Tensor,
         *args,
     ):
-        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
+        saved_tensors = ctx.saved_tensors if hasattr(ctx, "saved_tensors") else ctx.to_save
+        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = saved_tensors
+
         grad_out = grad_out.transpose(1, 2).contiguous()
+        key = key.transpose(1, 2).contiguous()
+        value = value.transpose(1, 2).contiguous()
 
         # Cannot pass first 5 arguments as kwargs because: https://github.com/pytorch/pytorch/blob/d26ca5de058dbcf56ac52bb43e84dd98df2ace97/torch/_dynamo/variables/torch.py#L1341
         grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
@@ -726,15 +726,33 @@ class _flash_attention_2_af(torch.autograd.Function):
         softcap = 0.0
         alibi_slopes = None
         deterministic = False
+        grad_enabled = any(x.requires_grad for x in (query, key, value))
 
         if scale is None:
             scale = query.shape[-1] ** (-0.5)
 
         # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround.
         parallel_config = _AttentionBackendRegistry._parallel_config
-        if query.requires_grad or (parallel_config is not None and parallel_config.world_size > 1):
+        if grad_enabled or (parallel_config is not None and parallel_config.world_size > 1):
             dropout_p = dropout_p if dropout_p > 0 else 1e-30
 
+        with torch.set_grad_enabled(grad_enabled):
+            out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
+                query,
+                key,
+                value,
+                dropout_p,
+                scale,
+                is_causal,
+                window_size[0],
+                window_size[1],
+                softcap,
+                alibi_slopes,
+                return_lse,
+            )
+            lse = lse.permute(0, 2, 1)
+
+        ctx.save_for_backward(query, key, value, out, lse, rng_state)
         ctx.dropout_p = dropout_p
         ctx.scale = scale
         ctx.is_causal = is_causal
@@ -743,22 +761,6 @@
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
 
-        out, lse, S_dmask, rng_state = _flash_attn_forward(
-            query,
-            key,
-            value,
-            dropout_p,
-            scale,
-            is_causal,
-            window_size[0],
-            window_size[1],
-            softcap,
-            alibi_slopes,
-            return_lse,
-        )
-
-        ctx.save_for_backward(query, key, value, out, lse, rng_state)
-
         return (out, lse) if return_lse else out
 
     @staticmethod
@@ -767,10 +769,11 @@
         grad_out: torch.Tensor,
         *args,
     ):
-        query, key, value, out, lse, rng_state = ctx.saved_tensors
+        saved_tensors = ctx.saved_tensors if hasattr(ctx, "saved_tensors") else ctx.to_save
+        query, key, value, out, lse, rng_state = saved_tensors
         grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
 
-        lse_d = _flash_attn_backward(  # noqa: F841
+        lse_d = _wrapped_flash_attn_backward(  # noqa: F841
             grad_out,
             query,
             key,
@@ -876,10 +879,18 @@ class TemplatedRingAttention(torch.autograd.Function):
         ring_mesh = parallel_config._ring_mesh
         rank = parallel_config._ring_local_rank
         world_size = parallel_config.ring_degree
-        next_rank = (rank + 1) % world_size
         prev_out = prev_lse = None
 
+        ctx.save_for_backward(query, key, value)
+        ctx.dropout_p = dropout_p
+        ctx.is_causal = is_causal
+        ctx.scale = scale
+        ctx.enable_gqa = enable_gqa
+        ctx.return_lse = return_lse
+        ctx.op = op
+        ctx.op_ctx = torch.autograd.function.FunctionCtx()
+
         kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
         kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
         kv_buffer = kv_buffer.chunk(world_size)
@@ -887,11 +898,14 @@
         for i in range(world_size):
             if i > 0:
                 kv = kv_buffer[next_rank]
-                key = kv[: key.numel()].reshape_as(key)
-                value = kv[key.numel() :].reshape_as(value)
+                key_numel = key.numel()
+                key = kv[:key_numel].reshape_as(key)
+                value = kv[key_numel:].reshape_as(value)
                 next_rank = (next_rank + 1) % world_size
 
-            out, lse = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True)
+            out, lse = op.forward(
+                ctx.op_ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True
+            )
 
             if parallel_config.convert_to_fp32:
                 out = out.to(torch.float32)
@@ -906,6 +920,7 @@
         out = out.to(query.dtype)
         lse = lse.squeeze(-1)
 
+
         return (out, lse) if return_lse else out
 
     @staticmethod
@@ -914,7 +929,55 @@
         grad_out: torch.Tensor,
         *args,
     ):
-        raise NotImplementedError("Backward pass is not implemented for TemplatedRingAttention.")
+        parallel_config = _AttentionBackendRegistry._parallel_config
+        ring_mesh = parallel_config._ring_mesh
+        rank = parallel_config._ring_local_rank
+        world_size = parallel_config.ring_degree
+        next_rank = (rank + 1) % world_size
+        next_ranks = list(range(1, world_size)) + [0]
+
+        query, key, value = ctx.saved_tensors
+
+        accum_dtype = torch.float32 if parallel_config.convert_to_fp32 else query.dtype
+        grad_query = torch.zeros_like(query, dtype=accum_dtype)
+        grad_key = torch.zeros_like(key, dtype=accum_dtype)
+        grad_value = torch.zeros_like(value, dtype=accum_dtype)
+        next_grad_kv = None
+
+        kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
+        kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
+        kv_buffer = kv_buffer.chunk(world_size)
+
+        for i in range(world_size):
+            if i > 0:
+                kv = kv_buffer[next_rank]
+                key_numel = key.numel()
+                key = kv[:key_numel].reshape_as(key)
+                value = kv[key_numel:].reshape_as(value)
+                next_rank = (next_rank + 1) % world_size
+
+            saved_tensors = list(ctx.op_ctx.to_save)
+            saved_tensors[1] = key
+            saved_tensors[2] = value
+            ctx.op_ctx.to_save = tuple(saved_tensors)
+
+            grad_query_op, grad_key_op, grad_value_op, *_ = ctx.op.backward(ctx.op_ctx, grad_out)
+
+            if i > 0:
+                grad_kv_buffer = _wait_tensor(next_grad_kv)
+                grad_key_numel = grad_key.numel()
+                grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key)
+                grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value)
+
+            grad_query += grad_query_op
+            grad_key += grad_key_op
+            grad_value += grad_value_op
+
+            if i < world_size - 1:
+                grad_kv_buffer = torch.cat([grad_key.flatten(), grad_value.flatten()]).contiguous()
+                next_grad_kv = funcol.permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh.get_group())
+
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None
 
 
 class TemplatedUlyssesAttention(torch.autograd.Function):
@@ -937,6 +1000,8 @@
         world_size = parallel_config.ulysses_degree
         group = ulysses_mesh.get_group()
 
+        ctx.op_ctx = torch.autograd.function.FunctionCtx()
+
         B, S_LOCAL, H, D = query.shape
         H_LOCAL = H // world_size
         query, key, value = (
@@ -948,7 +1013,7 @@
         )
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
 
-        out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
+        out = op.forward(ctx.op_ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
 
         if return_lse:
             out, lse, *_ = out
@@ -1291,7 +1356,7 @@ def _native_cudnn_attention(
 
     lse = None
     if parallel_config is None and not return_lse:
-        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
         with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
             out = torch.nn.functional.scaled_dot_product_attention(
                 query=query,
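The attention changes above replace `op.apply(...)` with direct calls to the wrapped op's `forward`/`backward` staticmethods on a manually created `torch.autograd.function.FunctionCtx` (`ctx.op_ctx`), so the ring backward can re-run the inner op's backward once per rank after swapping the saved key/value in `ctx.op_ctx.to_save`. Because a bare `FunctionCtx` stores tensors on `to_save` rather than exposing `saved_tensors`, the reworked backwards fall back to `ctx.to_save`. A self-contained sketch of that pattern (the `Square` op below is a toy stand-in, not part of diffusers):

    import torch

    class Square(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)  # on a bare FunctionCtx this populates ctx.to_save
            return x * x

        @staticmethod
        def backward(ctx, grad_out):
            saved = ctx.saved_tensors if hasattr(ctx, "saved_tensors") else ctx.to_save
            (x,) = saved
            return 2.0 * x * grad_out

    x = torch.randn(4)
    op_ctx = torch.autograd.function.FunctionCtx()    # manual ctx, no autograd graph is built
    y = Square.forward(op_ctx, x)                     # call the staticmethod directly, not .apply
    gx = Square.backward(op_ctx, torch.ones_like(y))  # backward can be re-invoked at will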
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index cd11602afa..b93f2eb551 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -63,6 +63,7 @@ from ..utils.hub_utils import (
     populate_model_card,
 )
 from ..utils.torch_utils import empty_device_cache
+from ._modeling_parallel import ContextParallelModelPlan, ParallelConfig, _InternalParallelConfig
 from .model_loading_utils import (
     _caching_allocator_warmup,
     _determine_device_map,
@@ -1501,19 +1502,20 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
             )
 
-    def parallelize(self, *, ring_degree: int = 1, ulysses_degree: int = 1, cp_plan=None):
-        from ..hooks.context_parallel import ParallelConfig, apply_context_parallel
+    @contextmanager
+    def parallelize(self, *, config: ParallelConfig, cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None):
+        from ..hooks.context_parallel import apply_context_parallel, remove_context_parallel
+        from .attention_dispatch import _AttentionBackendRegistry
 
-        # TODO(aryan): add cp_plan type hint
         logger.warning(
             "`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
         )
 
         if not torch.distributed.is_initialized():
             raise RuntimeError("torch.distributed must be initialized before calling `parallelize`.")
-        if ring_degree < 1 or ulysses_degree < 1:
+        if config.ring_degree < 1 or config.ulysses_degree < 1:
             raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
-        if ring_degree > 1 and ulysses_degree > 1:
+        if config.ring_degree > 1 and config.ulysses_degree > 1:
             raise ValueError(
                 "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
             )
@@ -1521,9 +1523,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         rank = torch.distributed.get_rank()
         world_size = torch.distributed.get_world_size()
 
-        if ring_degree * ulysses_degree > world_size:
+        if config.ring_degree * config.ulysses_degree > world_size:
             raise ValueError(
-                f"The product of `ring_degree` ({ring_degree}) and `ulysses_degree` ({ulysses_degree}) must not exceed the world size ({world_size})."
+                f"The product of `ring_degree` ({config.ring_degree}) and `ulysses_degree` ({config.ulysses_degree}) must not exceed the world size ({world_size})."
             )
 
         device_type = torch._C._get_accelerator().type
@@ -1532,14 +1534,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         cp_mesh = torch.distributed.device_mesh.init_device_mesh(
             device_type=device_type,
-            mesh_shape=(ring_degree, ulysses_degree),
+            mesh_shape=(config.ring_degree, config.ulysses_degree),
             mesh_dim_names=("ring", "ulysses"),
         )
-        parallel_config = ParallelConfig(
+        parallel_config = _InternalParallelConfig(
             rank=rank,
             world_size=world_size,
-            ring_degree=ring_degree,
-            ulysses_degree=ulysses_degree,
+            ring_degree=config.ring_degree,
+            ulysses_degree=config.ulysses_degree,
             device=device,
             cp_mesh=cp_mesh,
         )
@@ -1550,6 +1552,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         cp_plan = cp_plan if cp_plan is not None else self._cp_plan
         apply_context_parallel(self, parallel_config, cp_plan)
+        _AttentionBackendRegistry._parallel_config = parallel_config
+
+        yield
+
+        remove_context_parallel(self, cp_plan)
 
     @classmethod
     def _load_pretrained_model(
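Taken together, `parallelize` is now a context manager driven by the public `ParallelConfig`: on entry it installs the split/gather hooks and publishes the internal config on `_AttentionBackendRegistry`, and on exit it removes the hooks again via `remove_context_parallel`. A hedged usage sketch (launched with torchrun; the checkpoint id, subfolder, and inputs below are placeholders, not taken from this diff):

    # torchrun --nproc-per-node=2 example.py
    import torch
    import torch.distributed as dist

    from diffusers import AutoModel
    from diffusers.models import ParallelConfig

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

    model = AutoModel.from_pretrained("<repo-id>", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")

    config = ParallelConfig(ulysses_degree=dist.get_world_size())  # ring_degree defaults to 1
    with model.parallelize(config=config):
        # run the usual forward / denoising calls here; attention is sharded
        # across ranks only while this context is active
        ...

    dist.destroy_process_group()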