commit 51fed50837 (parent 84d2c84ea4)
Author: Aryan
Date: 2025-07-16 17:19:30 +02:00
6 changed files with 972 additions and 149 deletions

View File: src/diffusers/hooks/__init__.py

@@ -16,6 +16,7 @@ from ..utils import is_torch_available
if is_torch_available():
from .context_parallel import apply_context_parallel
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading

View File: src/diffusers/hooks/context_parallel.py

@@ -0,0 +1,275 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Dict, List, Type, Union
import torch
import torch.distributed._functional_collectives as funcol
from ..models._modeling_parallel import (
ContextParallelInput,
ContextParallelModelPlan,
ContextParallelOutput,
ParallelConfig,
)
from ..models.attention_dispatch import _parallel_context
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_CONTEXT_PARALLEL_MODEL_HOOK = "context_parallel_model_hook"
_CONTEXT_PARALLEL_SUBMODULE_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_SUBMODULE_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
cached_parameter_indices: Dict[str, int] = None
_cls: Type = None
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
kwargs = kwargs or {}
if identifier in kwargs:
return kwargs[identifier], True, None
if self.cached_parameter_indices is not None:
index = self.cached_parameter_indices.get(identifier, None)
if index is None:
raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
return args[index], False, index
if self._cls is None:
raise ValueError("Model class is not set for metadata.")
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
parameters = parameters[1:] # skip `self`
self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
if identifier not in self.cached_parameter_indices:
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
index = self.cached_parameter_indices[identifier]
if index >= len(args):
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
return args[index], False, index
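
As an aside, the name-to-index resolution above can be illustrated with a minimal self-contained sketch (the toy class below is hypothetical, not part of the commit):

```python
# Hypothetical illustration of ModuleForwardMetadata's name-to-index resolution.
import inspect

class ToyBlock:
    def forward(self, hidden_states, encoder_hidden_states=None):
        return hidden_states

# Skip `self`, then map each forward parameter name to its positional index.
params = list(inspect.signature(ToyBlock.forward).parameters.keys())[1:]
indices = {param: i for i, param in enumerate(params)}
assert indices == {"hidden_states": 0, "encoder_hidden_states": 1}
```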
def apply_context_parallel(
module: torch.nn.Module,
parallel_config: ParallelConfig,
plan: Dict[str, ContextParallelModelPlan],
) -> None:
"""Apply context parallel on a model."""
logger.debug(f"Applying context parallel with CP mesh: {parallel_config.cp_mesh} and plan: {plan}")
for module_id, cp_model_plan in plan.items():
submodule = _get_submodule_by_name(module, module_id)
if not isinstance(submodule, list):
submodule = [submodule]
logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")
for m in submodule:
if isinstance(cp_model_plan, dict):
hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
hook_name = _CONTEXT_PARALLEL_SUBMODULE_INPUT_HOOK_TEMPLATE.format(module_id)
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
if isinstance(cp_model_plan, ContextParallelOutput):
cp_model_plan = [cp_model_plan]
if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
hook_name = _CONTEXT_PARALLEL_SUBMODULE_OUTPUT_HOOK_TEMPLATE.format(module_id)
else:
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
registry = HookRegistry.check_if_exists_or_initialize(m)
registry.register_hook(hook, hook_name)
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = ContextParallelModelHook(parallel_config)
registry.register_hook(hook, _CONTEXT_PARALLEL_MODEL_HOOK)
class ContextParallelModelHook(ModelHook):
def __init__(self, parallel_config: ParallelConfig) -> None:
super().__init__()
self.parallel_config = parallel_config
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
with _parallel_context(self.parallel_config):
return self.fn_ref.original_forward(*args, **kwargs)
class ContextParallelSplitHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
super().__init__()
self.metadata = metadata
self.parallel_config = parallel_config
self.module_forward_metadata = None
def initialize_hook(self, module):
cls = unwrap_module(module).__class__
self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
return module
def pre_forward(self, module, *args, **kwargs):
args_list = list(args)
for name, cpm in self.metadata.items():
if isinstance(cpm, ContextParallelInput) and cpm.split_output:
continue
# Maybe the parameter was passed as a keyword argument
input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
name, args_list, kwargs
)
if input_val is None:
continue
# The input_val may be a tensor or a list/tuple of tensors. In certain cases, the user may choose to
# shard the output instead of the input of a particular layer by setting split_output=True
if isinstance(input_val, torch.Tensor):
input_val = self._prepare_cp_input(input_val, cpm)
elif isinstance(input_val, (list, tuple)):
if len(input_val) != len(cpm):
raise ValueError(
f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
)
sharded_input_val = []
for i, x in enumerate(input_val):
if torch.is_tensor(x) and not cpm[i].split_output:
x = self._prepare_cp_input(x, cpm[i])
sharded_input_val.append(x)
input_val = sharded_input_val
else:
raise ValueError(f"Unsupported input type: {type(input_val)}")
if is_kwarg:
kwargs[name] = input_val
elif index is not None and index < len(args_list):
args_list[index] = input_val
else:
raise ValueError(
f"An unexpected error occurred while processing the input '{name}'. Please open an "
f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
f"example along with the full stack trace."
)
return tuple(args_list), kwargs
def post_forward(self, module, output):
is_tensor = isinstance(output, torch.Tensor)
is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
if not is_tensor and not is_tensor_list:
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
output = [output] if is_tensor else list(output)
for index, cpm in self.metadata.items():
if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
continue
if index >= len(output):
raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
current_output = output[index]
current_output = self._prepare_cp_input(current_output, cpm)
output[index] = current_output
return output[0] if is_tensor else tuple(output)
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
raise ValueError(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
class ContextParallelGatherHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
super().__init__()
self.metadata = metadata
self.parallel_config = parallel_config
def post_forward(self, module, output):
is_tensor = isinstance(output, torch.Tensor)
if is_tensor:
output = [output]
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
output = list(output)
if len(output) != len(self.metadata):
raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")
for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
return output[0] if is_tensor else tuple(output)
class EquipartitionSharder:
@classmethod
@torch.compiler.disable
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
assert tensor.size()[dim] % mesh.size() == 0, "Tensor must be divisible by mesh size along the split dimension."
return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
@classmethod
@torch.compiler.disable
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
tensor = tensor.contiguous()
tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
return tensor
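
The shard/unshard contract can be checked single-process with plain tensor ops — a sketch under the assumption that all_gather concatenates rank chunks in order:

```python
# Hypothetical single-process check of the EquipartitionSharder contract:
# shard() keeps one contiguous chunk per rank; gathering the chunks back
# along the same dim reconstructs the original tensor.
import torch

world_size, dim = 4, 1
t = torch.arange(32, dtype=torch.float32).reshape(2, 16)
shards = t.chunk(world_size, dim=dim)  # shards[r] is what rank r keeps
assert torch.equal(torch.cat(shards, dim=dim), t)
```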
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name.count("*") > 1:
raise ValueError("Wildcard '*' can only be used once in the name")
return _find_submodule_by_name(model, name)
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name == "":
return model
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
if first_atom == "*":
if not isinstance(model, torch.nn.ModuleList):
raise ValueError("Wildcard '*' can only be used with ModuleList")
submodules = []
for submodule in model:
subsubmodules = _find_submodule_by_name(submodule, remaining_name)
if not isinstance(subsubmodules, list):
subsubmodules = [subsubmodules]
submodules.extend(subsubmodules)
return submodules
else:
if hasattr(model, first_atom):
submodule = getattr(model, first_atom)
return _find_submodule_by_name(submodule, remaining_name)
else:
raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")

View File: src/diffusers/models/_modeling_parallel.py

@@ -0,0 +1,105 @@
# Experimental parallelism support for Diffusers.
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Tuple, Union
import torch
from ..utils import get_logger
logger = get_logger(__name__) # pylint: disable=invalid-name
# TODO(aryan): add support for the following:
# - Unified Attention
# - More dispatcher attention backends
# - CFG/Data Parallel
# - Tensor Parallel
@dataclass
class ParallelConfig:
rank: int
world_size: int
ring_degree: int
ulysses_degree: int
device: torch.device
cp_mesh: torch.distributed.device_mesh.DeviceMesh
# Whether to convert output and LSE to float32 for ring attention numerical stability
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
_flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ring_local_rank: int = None
_ulysses_local_rank: int = None
def __post_init__(self):
if self.rotate_method != "allgather":
raise ValueError(f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}.")
if self._flattened_mesh is None:
self._flattened_mesh = self.cp_mesh._flatten()
if self._ring_mesh is None:
self._ring_mesh = self.cp_mesh["ring"]
if self._ulysses_mesh is None:
self._ulysses_mesh = self.cp_mesh["ulysses"]
if self._ring_local_rank is None:
self._ring_local_rank = self._ring_mesh.get_local_rank()
if self._ulysses_local_rank is None:
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
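
A minimal construction sketch (hypothetical 2-rank run under torchrun; assumes torch.distributed is already initialized):

```python
# Hypothetical: build a ring-only ParallelConfig on 2 GPUs.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

cp_mesh = init_device_mesh("cuda", mesh_shape=(2, 1), mesh_dim_names=("ring", "ulysses"))
config = ParallelConfig(
    rank=dist.get_rank(),
    world_size=dist.get_world_size(),
    ring_degree=2,
    ulysses_degree=1,
    device=torch.device("cuda", dist.get_rank() % torch.cuda.device_count()),
    cp_mesh=cp_mesh,
)
# __post_init__ derives the flattened/ring/ulysses sub-meshes and local ranks.
```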
@dataclass(frozen=True)
class ContextParallelInput:
split_dim: int
expected_dims: Optional[int] = None
split_output: bool = False
def __repr__(self):
return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})"
@dataclass(frozen=True)
class ContextParallelOutput:
gather_dim: int
expected_dims: Optional[int] = None
def __repr__(self):
return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"
# A dictionary whose keys denote the inputs to be split across the context parallel region, and
# whose values denote the sharding configuration.
# If a key is a string, it denotes the name of the parameter in the forward function.
# If a key is an integer, split_output must be set to True, and it denotes the index of the output
# to be split across the context parallel region.
ContextParallelInputType = Dict[
Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
]
# The output(s) of a module to be gathered across the context parallel region: either a single
# gathering configuration or a list/tuple with one entry per output.
ContextParallelOutputType = Union[
ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
]
# A dictionary whose keys denote module ids, and whose values denote how the inputs/outputs of
# the module should be split/gathered across the context parallel region.
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
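
A hypothetical plan, mirroring the Flux plan added later in this commit, could look like:

```python
# Hypothetical example of a ContextParallelModelPlan: shard the root module's
# `hidden_states` along its sequence dim on entry, gather the `proj_out`
# output back along the same dim on exit.
example_plan: ContextParallelModelPlan = {
    "": {"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3)},
    "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
```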

View File: src/diffusers/models/attention_dispatch.py

@@ -17,9 +17,10 @@ import functools
import inspect
import math
from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.distributed._functional_collectives as funcol
from ..utils import (
get_logger,
@@ -38,15 +39,22 @@ from ..utils import (
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
if TYPE_CHECKING:
from ._modeling_parallel import ParallelConfig
logger = get_logger(__name__) # pylint: disable=invalid-name
if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
else:
logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
flash_attn_func = None
flash_attn_varlen_func = None
_flash_attn_forward = None
_flash_attn_backward = None
if is_flash_attn_3_available():
@@ -104,6 +112,27 @@ else:
xops = None
if torch.__version__ >= "2.4.0":
_custom_op = torch.library.custom_op
_register_fake = torch.library.register_fake
else:
def _custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
def wrap(func):
return func
return wrap if fn is None else fn
def _register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
def wrap(func):
return func
return wrap if fn is None else fn
_custom_op = _custom_op_no_op
_register_fake = _register_fake_no_op
# TODO(aryan): Add support for the following:
# - Sage Attention++
# - block sparse, radial and other attention methods
@@ -154,17 +183,25 @@ class _AttentionBackendRegistry:
_backends = {}
_constraints = {}
_supported_arg_names = {}
_supports_context_parallel = {}
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
_checks_enabled = DIFFUSERS_ATTN_CHECKS
_parallel_config: Optional["ParallelConfig"] = None
@classmethod
-def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
+def register(
+cls,
+backend: AttentionBackendName,
+constraints: Optional[List[Callable]] = None,
+supports_context_parallel: bool = False,
+):
logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
def decorator(func):
cls._backends[backend] = func
cls._constraints[backend] = constraints or []
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
cls._supports_context_parallel[backend] = supports_context_parallel
return func
return decorator
@@ -177,6 +214,17 @@ class _AttentionBackendRegistry:
def list_backends(cls):
return list(cls._backends.keys())
@classmethod
def _is_context_parallel_enabled(cls, backend: AttentionBackendName) -> bool:
if backend not in cls._supports_context_parallel:
raise ValueError(f"Backend {backend} is not registered.")
supports_context_parallel = cls._supports_context_parallel[backend]
is_degree_greater_than_1 = _AttentionBackendRegistry._parallel_config is not None and (
_AttentionBackendRegistry._parallel_config.ring_degree > 1
or _AttentionBackendRegistry._parallel_config.ulysses_degree > 1
)
return supports_context_parallel and is_degree_greater_than_1
@contextlib.contextmanager
def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
@@ -195,6 +243,20 @@ def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIV
_AttentionBackendRegistry._active_backend = old_backend
@contextlib.contextmanager
def _parallel_context(parallel_config: "ParallelConfig"):
"""
Context manager to set the parallel configuration for attention backends that support it.
"""
old_parallel_config = _AttentionBackendRegistry._parallel_config
_AttentionBackendRegistry._parallel_config = parallel_config
try:
yield
finally:
_AttentionBackendRegistry._parallel_config = old_parallel_config
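
The save/restore pattern above reinstates the previous config even when the wrapped forward raises; a self-contained sketch with a hypothetical registry stand-in:

```python
# Hypothetical stand-in demonstrating the try/finally save-restore pattern.
import contextlib

class _Registry:
    _parallel_config = None

@contextlib.contextmanager
def _ctx(cfg):
    old = _Registry._parallel_config
    _Registry._parallel_config = cfg
    try:
        yield
    finally:
        _Registry._parallel_config = old

try:
    with _ctx("dummy-config"):
        raise RuntimeError("forward failed")
except RuntimeError:
    pass
assert _Registry._parallel_config is None  # restored despite the exception
```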
def dispatch_attention_fn(
query: torch.Tensor,
key: torch.Tensor,
@@ -218,6 +280,14 @@ def dispatch_attention_fn(
backend_name = AttentionBackendName(backend)
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
if (
_AttentionBackendRegistry._parallel_config is not None
and not _AttentionBackendRegistry._is_context_parallel_enabled(backend_name)
):
raise ValueError(
f"Backend {backend_name} does not support context parallelism, but a parallel configuration is provided."
)
kwargs = {
"query": query,
"key": key,
@@ -415,20 +485,398 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# TODO: library.custom_op and register_fake probably need version guards?
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
qv: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
attention_chunk: int = 0,
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
out, lse = flash_attn_3_func(query, key, value)
# Hardcoded for now because pytorch does not support tuple/int type hints
window_size = (-1, -1)
out, lse, *_ = flash_attn_3_func(
q=q,
k=k,
v=v,
softmax_scale=softmax_scale,
causal=causal,
qv=qv,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
window_size=window_size,
attention_chunk=attention_chunk,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
)
lse = lse.permute(0, 2, 1)
return out, lse
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len, num_heads, head_dim = query.shape
@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
def _(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
qv: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
attention_chunk: int = 0,
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
window_size = (-1, -1) # noqa: F841
# A lot of the parameters here are not yet used in any way within diffusers.
# We can safely ignore for now and keep the fake op shape propagation simple.
batch_size, seq_len, num_heads, head_dim = q.shape
lse_shape = (batch_size, seq_len, num_heads)
return torch.empty_like(query), query.new_empty(lse_shape)
return torch.empty_like(q), q.new_empty(lse_shape)
# ===== Autograd functions =====
class _cudnn_attention(torch.autograd.Function):
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
# forward declaration:
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
# backward declaration:
# aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
enable_gqa: bool = False,
return_lse: bool = False,
):
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
ctx.attn_mask = attn_mask
# Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
# if the input tensors are not contiguous.
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
torch.ops.aten._scaled_dot_product_cudnn_attention(
query=query,
key=key,
value=value,
attn_bias=attn_mask,
compute_log_sumexp=return_lse,
dropout_p=dropout_p,
is_causal=is_causal,
return_debug_mask=False,
scale=scale,
)
)
ctx.max_q = max_q
ctx.max_k = max_k
ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
out = out.transpose(1, 2).contiguous()
if lse is not None:
lse = lse.transpose(1, 2).contiguous()
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args: torch.Tensor,
):
query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
grad_out = grad_out.transpose(1, 2).contiguous()
# Cannot pass first 5 arguments as kwargs because: https://github.com/pytorch/pytorch/blob/d26ca5de058dbcf56ac52bb43e84dd98df2ace97/torch/_dynamo/variables/torch.py#L1341
grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
grad_out,
query,
key,
value,
out,
logsumexp=lse,
philox_seed=philox_seed,
philox_offset=philox_offset,
attn_bias=ctx.attn_mask,
cum_seq_q=cum_seq_q,
cum_seq_k=cum_seq_k,
max_q=ctx.max_q,
max_k=ctx.max_k,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
)
grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value))
return grad_query, grad_key, grad_value, None, None, None, None, None
# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
class _flash_attention_2(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
enable_gqa: bool = False,
return_lse: bool = False,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")
# Hardcoded for now
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
if scale is None:
scale = query.shape[-1] ** (-0.5)
# flash-attn only returns the LSE when dropout_p > 0, so we work around it with a tiny non-zero dropout.
parallel_config = _AttentionBackendRegistry._parallel_config
if query.requires_grad or (parallel_config is not None and parallel_config.world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
out, lse, S_dmask, rng_state = _flash_attn_forward(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
ctx.save_for_backward(query, key, value, out, lse, rng_state)
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args: torch.Tensor,
):
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
lse_d = _flash_attn_backward( # noqa: F841
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
# Head dimension may have been padded
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
# ===== Context parallel =====
class TemplatedRingAttention(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
return_lse: bool,
op: torch.autograd.Function,
):
parallel_config = _AttentionBackendRegistry._parallel_config
ring_mesh = parallel_config._ring_mesh
rank = parallel_config._ring_local_rank
world_size = parallel_config.ring_degree
next_rank = (rank + 1) % world_size
prev_out = prev_lse = None
kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
kv_buffer = kv_buffer.chunk(world_size)
for i in range(world_size):
if i > 0:
kv = kv_buffer[next_rank]
key = kv[: key.numel()].reshape_as(key)
value = kv[key.numel() :].reshape_as(value)
next_rank = (next_rank + 1) % world_size
out, lse = op.apply(query, key, value, None, 0.0, None, False, False, True)
if parallel_config.convert_to_fp32:
out = out.to(torch.float32)
lse = lse.to(torch.float32)
lse = lse.unsqueeze(-1)
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
prev_out = out
prev_lse = lse
out = out.to(query.dtype)
lse = lse.squeeze(-1)
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args: torch.Tensor,
):
raise NotImplementedError("Backward pass is not implemented for TemplatedRingAttention.")
class TemplatedUlyssesAttention(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
return_lse: bool,
op: torch.autograd.Function,
):
parallel_config = _AttentionBackendRegistry._parallel_config
ulysses_mesh = parallel_config._ulysses_mesh
world_size = parallel_config.ulysses_degree
group = ulysses_mesh.get_group()
B, S_LOCAL, H, D = query.shape
H_LOCAL = H // world_size
query, key, value = (
x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
for x in (query, key, value)
)
query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
out = op.apply(query, key, value, None, 0.0, None, False, False, return_lse)
if return_lse:
out, lse, *_ = out
out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
out = funcol.all_to_all_single(out, None, None, group=group).wait()
out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
if return_lse:
lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
lse = funcol.all_to_all_single(lse, None, None, group=group).wait()
lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
else:
lse = None
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args: torch.Tensor,
):
raise NotImplementedError("Backward pass is not implemented for TemplatedUlyssesAttention.")
def _templated_context_parallel_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
*,
op: torch.autograd.Function,
):
if attn_mask is not None:
raise ValueError("Attention mask is not yet supported for templated attention.")
if is_causal:
raise ValueError("Causal attention is not yet supported for templated attention.")
if enable_gqa:
raise ValueError("GQA is not yet supported for templated attention.")
parallel_config = _AttentionBackendRegistry._parallel_config
# TODO: add support for unified attention with ring/ulysses degree both being > 1
if parallel_config.ring_degree > 1:
return TemplatedRingAttention.apply(query, key, value, return_lse, op)
elif parallel_config.ulysses_degree > 1:
return TemplatedUlyssesAttention.apply(query, key, value, return_lse, op)
else:
return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
# ===== Attention backends =====
@@ -445,11 +893,7 @@ def _flash_attention(
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
-window_size: Tuple[int, int] = (-1, -1),
-softcap: float = 0.0,
-alibi_slopes: Optional[torch.Tensor] = None,
-deterministic: bool = False,
-return_attn_probs: bool = False,
+return_lse: bool = False,
) -> torch.Tensor:
out = flash_attn_func(
q=query,
@@ -458,11 +902,7 @@ def _flash_attention(
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
-window_size=window_size,
-softcap=softcap,
-alibi_slopes=alibi_slopes,
-deterministic=deterministic,
-return_attn_probs=return_attn_probs,
+return_attn_probs=return_lse,
)
return out
@@ -475,19 +915,11 @@ def _flash_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
-cu_seqlens_q: Optional[torch.Tensor] = None,
-cu_seqlens_k: Optional[torch.Tensor] = None,
-max_seqlen_q: Optional[int] = None,
-max_seqlen_k: Optional[int] = None,
-attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
-window_size: Tuple[int, int] = (-1, -1),
-softcap: float = 0.0,
-alibi_slopes: Optional[torch.Tensor] = None,
-deterministic: bool = False,
-return_attn_probs: bool = False,
+attn_mask: Optional[torch.Tensor] = None,
+return_lse: bool = False,
) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
@@ -495,16 +927,11 @@ def _flash_varlen_attention(
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
-if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-_prepare_for_flash_attn_or_sage_varlen(
-batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-)
+(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+_prepare_for_flash_attn_or_sage_varlen(
+batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+)
-else:
-seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
)
key_valid, value_valid = [], []
for b in range(batch_size):
@@ -527,11 +954,7 @@ def _flash_varlen_attention(
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
-window_size=window_size,
-softcap=softcap,
-alibi_slopes=alibi_slopes,
-deterministic=deterministic,
-return_attn_probs=return_attn_probs,
+return_attn_probs=return_lse,
)
out = out.unflatten(0, (batch_size, -1))
@@ -548,30 +971,16 @@ def _flash_attention_3(
value: torch.Tensor,
scale: Optional[float] = None,
is_causal: bool = False,
-window_size: Tuple[int, int] = (-1, -1),
-softcap: float = 0.0,
-deterministic: bool = False,
-return_attn_probs: bool = False,
+return_lse: bool = False,
) -> torch.Tensor:
-out, lse, *_ = flash_attn_3_func(
+out, lse = _wrapped_flash_attn_3(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
-qv=None,
-q_descale=None,
-k_descale=None,
-v_descale=None,
-window_size=window_size,
-attention_chunk=0,
-softcap=softcap,
-num_splits=1,
-pack_gqa=None,
-deterministic=deterministic,
-sm_margin=0,
)
-return (out, lse) if return_attn_probs else out
+return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -582,17 +991,10 @@ def _flash_varlen_attention_3(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
-cu_seqlens_q: Optional[torch.Tensor] = None,
-cu_seqlens_k: Optional[torch.Tensor] = None,
-max_seqlen_q: Optional[int] = None,
-max_seqlen_k: Optional[int] = None,
-attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
is_causal: bool = False,
-window_size: Tuple[int, int] = (-1, -1),
-softcap: float = 0.0,
-deterministic: bool = False,
-return_attn_probs: bool = False,
+attn_mask: Optional[torch.Tensor] = None,
+return_lse: bool = False,
) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
@@ -600,16 +1002,11 @@ def _flash_varlen_attention_3(
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
-if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-_prepare_for_flash_attn_or_sage_varlen(
-batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-)
+(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+_prepare_for_flash_attn_or_sage_varlen(
+batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+)
-else:
-seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
)
key_valid, value_valid = [], []
for b in range(batch_size):
@@ -629,24 +1026,12 @@ def _flash_varlen_attention_3(
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
-seqused_q=None,
-seqused_k=None,
softmax_scale=scale,
causal=is_causal,
-qv=None,
-q_descale=None,
-k_descale=None,
-v_descale=None,
-window_size=window_size,
-softcap=softcap,
-num_splits=1,
-pack_gqa=None,
-deterministic=deterministic,
-sm_margin=0,
)
out = out.unflatten(0, (batch_size, -1))
-return (out, lse) if return_attn_probs else out
+return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -662,7 +1047,6 @@ def _native_flex_attention(
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
-kernel_options: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# TODO: should we LRU cache the block mask creation?
score_mod = None
@@ -707,7 +1091,6 @@ def _native_flex_attention(
scale=scale,
enable_gqa=enable_gqa,
return_lse=return_lse,
-kernel_options=kernel_options,
)
out = out.permute(0, 2, 1, 3)
return out
@@ -726,7 +1109,10 @@ def _native_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
if return_lse:
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
@@ -745,6 +1131,7 @@ def _native_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_CUDNN,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
)
def _native_cudnn_attention(
query: torch.Tensor,
@@ -755,21 +1142,33 @@ def _native_cudnn_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
-query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
-out = torch.nn.functional.scaled_dot_product_attention(
-query=query,
-key=key,
-value=value,
-attn_mask=attn_mask,
-dropout_p=dropout_p,
-is_causal=is_causal,
-scale=scale,
-enable_gqa=enable_gqa,
+parallel_config = _AttentionBackendRegistry._parallel_config
+lse = None
+if parallel_config is None and not return_lse:
+query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+out = torch.nn.functional.scaled_dot_product_attention(
+query=query,
+key=key,
+value=value,
+attn_mask=attn_mask,
+dropout_p=dropout_p,
+is_causal=is_causal,
+scale=scale,
+enable_gqa=enable_gqa,
)
+out = out.permute(0, 2, 1, 3)
+else:
+out = _templated_context_parallel_attention(
+query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op=_cudnn_attention
+)
-out = out.permute(0, 2, 1, 3)
-return out
+if return_lse:
+out, lse = out
+return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -785,7 +1184,10 @@ def _native_efficient_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
if return_lse:
raise ValueError("Native efficient attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention(
@@ -814,7 +1216,10 @@ def _native_flash_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
if return_lse:
raise ValueError("Native flash attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention(
@@ -844,7 +1249,10 @@ def _native_math_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
if return_lse:
raise ValueError("Native math attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
out = torch.nn.functional.scaled_dot_product_attention(
@@ -871,7 +1279,10 @@ def _native_npu_attention(
value: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
return_lse: bool = False,
) -> torch.Tensor:
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
return npu_fusion_attention(
query,
key,
@@ -898,7 +1309,10 @@ def _native_xla_attention(
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
if return_lse:
raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
query = query / math.sqrt(query.shape[-1])
out = xla_flash_attention(
@@ -942,31 +1356,25 @@ def _sage_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
-cu_seqlens_q: Optional[torch.Tensor] = None,
-cu_seqlens_k: Optional[torch.Tensor] = None,
-max_seqlen_q: Optional[int] = None,
-max_seqlen_k: Optional[int] = None,
-attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
-smooth_k: bool = True,
+attn_mask: Optional[torch.Tensor] = None,
+return_lse: bool = False,
) -> torch.Tensor:
if return_lse:
raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
-if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-_prepare_for_flash_attn_or_sage_varlen(
-batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-)
+(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+_prepare_for_flash_attn_or_sage_varlen(
+batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+)
-else:
-seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
)
key_valid, value_valid = [], []
for b in range(batch_size):
@@ -988,7 +1396,6 @@ def _sage_varlen_attention(
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scale,
-smooth_k=smooth_k,
)
out = out.unflatten(0, (batch_size, -1))
@@ -1005,10 +1412,6 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
-qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
-pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
-smooth_k: bool = True,
-smooth_v: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp8_cuda(
@@ -1017,11 +1420,7 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
v=value,
tensor_layout="NHD",
is_causal=is_causal,
-qk_quant_gran=qk_quant_gran,
sm_scale=scale,
-pv_accum_dtype=pv_accum_dtype,
-smooth_k=smooth_k,
-smooth_v=smooth_v,
return_lse=return_lse,
)
@@ -1036,9 +1435,6 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
-qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
-pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
-smooth_k: bool = True,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp8_cuda_sm90(
@@ -1047,10 +1443,7 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
v=value,
tensor_layout="NHD",
is_causal=is_causal,
-qk_quant_gran=qk_quant_gran,
sm_scale=scale,
-pv_accum_dtype=pv_accum_dtype,
-smooth_k=smooth_k,
return_lse=return_lse,
)
@@ -1065,10 +1458,6 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
-qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
-pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
-smooth_k: bool = True,
-smooth_v: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp16_cuda(
@@ -1077,11 +1466,7 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
v=value,
tensor_layout="NHD",
is_causal=is_causal,
-qk_quant_gran=qk_quant_gran,
sm_scale=scale,
-pv_accum_dtype=pv_accum_dtype,
-smooth_k=smooth_k,
-smooth_v=smooth_v,
return_lse=return_lse,
)
@@ -1096,8 +1481,6 @@ def _sage_qk_int8_pv_fp16_triton_attention(
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
-quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
-smooth_k: bool = True,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp16_triton(
@@ -1105,10 +1488,8 @@ def _sage_qk_int8_pv_fp16_triton_attention(
k=key,
v=value,
tensor_layout="NHD",
-quantization_backend=quantization_backend,
is_causal=is_causal,
sm_scale=scale,
-smooth_k=smooth_k,
return_lse=return_lse,
)
@@ -1126,7 +1507,11 @@ def _xformers_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
if return_lse:
raise ValueError("xformers attention backend does not support setting `return_lse=True`.")
batch_size, seq_len_q, num_heads_q, _ = query.shape
_, seq_len_kv, num_heads_kv, _ = key.shape

View File: src/diffusers/models/modeling_utils.py

@@ -271,6 +271,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True
_repeated_blocks = []
_cp_plan = None
def __init__(self):
super().__init__()
@@ -1492,6 +1493,52 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
)
def parallelize(self, *, ring_degree: int = 1, ulysses_degree: int = 1, cp_plan=None):
from ..hooks.context_parallel import ParallelConfig, apply_context_parallel
# TODO(aryan): add cp_plan type hint
logger.warning(
"`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
)
if not torch.distributed.is_initialized():
raise RuntimeError("torch.distributed must be initialized before calling `parallelize`.")
if ring_degree < 1 or ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if ring_degree > 1 and ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
if ring_degree * ulysses_degree > world_size:
raise ValueError(
f"The product of `ring_degree` ({ring_degree}) and `ulysses_degree` ({ulysses_degree}) must not exceed the world size ({world_size})."
)
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
device = torch.device(device_type, rank % device_module.device_count())
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=(ring_degree, ulysses_degree),
mesh_dim_names=("ring", "ulysses"),
)
parallel_config = ParallelConfig(
rank=rank,
world_size=world_size,
ring_degree=ring_degree,
ulysses_degree=ulysses_degree,
device=device,
cp_mesh=cp_mesh,
)
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
apply_context_parallel(self, parallel_config, cp_plan)
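
A usage sketch for the new API (hypothetical script, launched with torchrun; the checkpoint name is an assumption, not part of this commit):

```python
# Hypothetical: context-parallel inference with the new `parallelize` API.
# Run with: torchrun --nproc-per-node=2 this_script.py
import torch
import torch.distributed as dist
from diffusers import FluxTransformer2DModel

dist.init_process_group("nccl")
device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
torch.cuda.set_device(device)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to(device)
transformer.parallelize(ring_degree=2)  # falls back to the model's own `_cp_plan`
```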
@classmethod
def _load_pretrained_model(
cls,

View File: src/diffusers/models/transformers/transformer_flux.py

@@ -25,6 +25,7 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
@@ -569,6 +570,15 @@ class FluxTransformer2DModel(
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_cp_plan = {
"": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
"txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
},
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
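
Reading the plan: `hidden_states` and `encoder_hidden_states` enter Flux as batched 3D `(batch, seq, dim)` tensors, so they are split along dim 1 (the sequence dimension), while `img_ids` and `txt_ids` are unbatched 2D `(seq, 3)` position-id tensors and are therefore split along dim 0. The `proj_out` output is gathered back along the sequence dimension so that every rank returns the full-length result.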
@register_to_config
def __init__(