Commit "update" (mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00)
@@ -16,6 +16,7 @@ from ..utils import is_torch_available

if is_torch_available():
    from .context_parallel import apply_context_parallel
    from .faster_cache import FasterCacheConfig, apply_faster_cache
    from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
    from .group_offloading import apply_group_offloading
src/diffusers/hooks/context_parallel.py (new file, 275 lines)
@@ -0,0 +1,275 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Dict, List, Type, Union

import torch
import torch.distributed._functional_collectives as funcol

from ..models._modeling_parallel import (
    ContextParallelInput,
    ContextParallelModelPlan,
    ContextParallelOutput,
    ParallelConfig,
)
from ..models.attention_dispatch import _parallel_context
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_CONTEXT_PARALLEL_MODEL_HOOK = "context_parallel_model_hook"
_CONTEXT_PARALLEL_SUBMODULE_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_SUBMODULE_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"


# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
    cached_parameter_indices: Dict[str, int] = None
    _cls: Type = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        kwargs = kwargs or {}

        if identifier in kwargs:
            return kwargs[identifier], True, None

        if self.cached_parameter_indices is not None:
            index = self.cached_parameter_indices.get(identifier, None)
            if index is None:
                raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
            return args[index], False, index

        if self._cls is None:
            raise ValueError("Model class is not set for metadata.")

        parameters = list(inspect.signature(self._cls.forward).parameters.keys())
        parameters = parameters[1:]  # skip `self`
        self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}

        if identifier not in self.cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")

        index = self.cached_parameter_indices[identifier]

        if index >= len(args):
            raise ValueError(f"Expected {index} arguments but got {len(args)}.")

        return args[index], False, index


def apply_context_parallel(
    module: torch.nn.Module,
    parallel_config: ParallelConfig,
    plan: Dict[str, ContextParallelModelPlan],
) -> None:
    """Apply context parallel on a model."""
    logger.debug(f"Applying context parallel with CP mesh: {parallel_config.cp_mesh} and plan: {plan}")

    for module_id, cp_model_plan in plan.items():
        submodule = _get_submodule_by_name(module, module_id)
        if not isinstance(submodule, list):
            submodule = [submodule]

        logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")

        for m in submodule:
            if isinstance(cp_model_plan, dict):
                hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_SUBMODULE_INPUT_HOOK_TEMPLATE.format(module_id)
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                if isinstance(cp_model_plan, ContextParallelOutput):
                    cp_model_plan = [cp_model_plan]
                if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
                    raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
                hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_SUBMODULE_OUTPUT_HOOK_TEMPLATE.format(module_id)
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry = HookRegistry.check_if_exists_or_initialize(m)
            registry.register_hook(hook, hook_name)

    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = ContextParallelModelHook(parallel_config)
    registry.register_hook(hook, _CONTEXT_PARALLEL_MODEL_HOOK)


class ContextParallelModelHook(ModelHook):
    def __init__(self, parallel_config: ParallelConfig) -> None:
        super().__init__()
        self.parallel_config = parallel_config

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        with _parallel_context(self.parallel_config):
            return self.fn_ref.original_forward(*args, **kwargs)


class ContextParallelSplitHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config
        self.module_forward_metadata = None

    def initialize_hook(self, module):
        cls = unwrap_module(module).__class__
        self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
        return module

    def pre_forward(self, module, *args, **kwargs):
        args_list = list(args)

        for name, cpm in self.metadata.items():
            if isinstance(cpm, ContextParallelInput) and cpm.split_output:
                continue

            # Maybe the parameter was passed as a keyword argument
            input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
                name, args_list, kwargs
            )

            if input_val is None:
                continue

            # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
            # the output instead of input for a particular layer by setting split_output=True
            if isinstance(input_val, torch.Tensor):
                input_val = self._prepare_cp_input(input_val, cpm)
            elif isinstance(input_val, (list, tuple)):
                if len(input_val) != len(cpm):
                    raise ValueError(
                        f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
                    )
                sharded_input_val = []
                for i, x in enumerate(input_val):
                    if torch.is_tensor(x) and not cpm[i].split_output:
                        x = self._prepare_cp_input(x, cpm[i])
                    sharded_input_val.append(x)
                input_val = sharded_input_val
            else:
                raise ValueError(f"Unsupported input type: {type(input_val)}")

            if is_kwarg:
                kwargs[name] = input_val
            elif index is not None and index < len(args_list):
                args_list[index] = input_val
            else:
                raise ValueError(
                    f"An unexpected error occurred while processing the input '{name}'. Please open an "
                    f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
                    f"example along with the full stack trace."
                )

        return tuple(args_list), kwargs

    def post_forward(self, module, output):
        is_tensor = isinstance(output, torch.Tensor)
        is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)

        if not is_tensor and not is_tensor_list:
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = [output] if is_tensor else list(output)
        for index, cpm in self.metadata.items():
            if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
                continue
            if index >= len(output):
                raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
            current_output = output[index]
            current_output = self._prepare_cp_input(current_output, cpm)
            output[index] = current_output

        return output[0] if is_tensor else tuple(output)

    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
        if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
            raise ValueError(
                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
            )
        return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)


class ContextParallelGatherHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config

    def post_forward(self, module, output):
        is_tensor = isinstance(output, torch.Tensor)

        if is_tensor:
            output = [output]
        elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = list(output)

        if len(output) != len(self.metadata):
            raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")

        for i, cpm in enumerate(self.metadata):
            if cpm is None:
                continue
            output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)

        return output[0] if is_tensor else tuple(output)


class EquipartitionSharder:
    @classmethod
    @torch.compiler.disable
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        assert tensor.size()[dim] % mesh.size() == 0
        return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]

    @classmethod
    @torch.compiler.disable
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        tensor = tensor.contiguous()
        tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
        return tensor
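
The two classmethods above rely on a simple invariant: `shard` keeps this rank's equal-sized chunk along `dim`, and the all-gather in `unshard` concatenates those chunks back in rank order. A minimal single-process sketch of that invariant (illustration only, not part of the diff; the mesh is simulated with a plain loop):

import torch

# Illustration: a "flattened mesh" of size 4, simulated without torch.distributed.
world_size = 4
full = torch.randn(2, 8, 16)  # e.g. (batch, sequence, hidden); shard along dim=1

# shard(): each rank keeps its own equal chunk along the split dimension.
shards = [full.chunk(world_size, dim=1)[rank] for rank in range(world_size)]
assert all(s.shape == (2, 2, 16) for s in shards)

# unshard(): the all-gather concatenates the rank-ordered chunks back together.
assert torch.equal(torch.cat(shards, dim=1), full)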


def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
    if name.count("*") > 1:
        raise ValueError("Wildcard '*' can only be used once in the name")
    return _find_submodule_by_name(model, name)


def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
    if name == "":
        return model
    first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
    if first_atom == "*":
        if not isinstance(model, torch.nn.ModuleList):
            raise ValueError("Wildcard '*' can only be used with ModuleList")
        submodules = []
        for submodule in model:
            subsubmodules = _find_submodule_by_name(submodule, remaining_name)
            if not isinstance(subsubmodules, list):
                subsubmodules = [subsubmodules]
            submodules.extend(subsubmodules)
        return submodules
    else:
        if hasattr(model, first_atom):
            submodule = getattr(model, first_atom)
            return _find_submodule_by_name(submodule, remaining_name)
        else:
            raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")

src/diffusers/models/_modeling_parallel.py (new file, 105 lines)
@@ -0,0 +1,105 @@
# Experimental parallelism support for Diffusers.
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Tuple, Union

import torch

from ..utils import get_logger


logger = get_logger(__name__)  # pylint: disable=invalid-name


# TODO(aryan): add support for the following:
# - Unified Attention
# - More dispatcher attention backends
# - CFG/Data Parallel
# - Tensor Parallel


@dataclass
class ParallelConfig:
    rank: int
    world_size: int
    ring_degree: int
    ulysses_degree: int
    device: torch.device
    cp_mesh: torch.distributed.device_mesh.DeviceMesh

    # Whether to convert output and LSE to float32 for ring attention numerical stability
    convert_to_fp32: bool = True
    # TODO: support alltoall
    rotate_method: Literal["allgather", "alltoall"] = "allgather"

    _flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
    _ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
    _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
    _ring_local_rank: int = None
    _ulysses_local_rank: int = None

    def __post_init__(self):
        if self.rotate_method != "allgather":
            raise ValueError(f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}.")
        if self._flattened_mesh is None:
            self._flattened_mesh = self.cp_mesh._flatten()
        if self._ring_mesh is None:
            self._ring_mesh = self.cp_mesh["ring"]
        if self._ulysses_mesh is None:
            self._ulysses_mesh = self.cp_mesh["ulysses"]
        if self._ring_local_rank is None:
            self._ring_local_rank = self._ring_mesh.get_local_rank()
        if self._ulysses_local_rank is None:
            self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
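
For orientation, this is roughly how the fields above get populated; `ModelMixin.parallelize` later in this commit is the authoritative call site, so treat this as a hedged sketch that assumes a process group already started by torchrun and illustrative degrees:

import torch
from torch.distributed.device_mesh import init_device_mesh

# Assumes torch.distributed is initialized with world_size == ring_degree * ulysses_degree.
ring_degree, ulysses_degree = 2, 1
cp_mesh = init_device_mesh(
    device_type="cuda",
    mesh_shape=(ring_degree, ulysses_degree),
    mesh_dim_names=("ring", "ulysses"),
)
config = ParallelConfig(
    rank=torch.distributed.get_rank(),
    world_size=torch.distributed.get_world_size(),
    ring_degree=ring_degree,
    ulysses_degree=ulysses_degree,
    device=torch.device("cuda", torch.cuda.current_device()),
    cp_mesh=cp_mesh,
)
# __post_init__ then derives the "ring"/"ulysses" sub-meshes and the flattened 1-D mesh.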


@dataclass(frozen=True)
class ContextParallelInput:
    split_dim: int
    expected_dims: Optional[int] = None
    split_output: bool = False

    def __repr__(self):
        return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})"


@dataclass(frozen=True)
class ContextParallelOutput:
    gather_dim: int
    expected_dims: Optional[int] = None

    def __repr__(self):
        return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"


# A dictionary where keys denote the input to be split across context parallel region, and the
# value denotes the sharding configuration.
# If the key is a string, it denotes the name of the parameter in the forward function.
# If the key is an integer, split_output must be set to True, and it denotes the index of the output
# to be split across context parallel region.
ContextParallelInputType = Dict[
    Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
]

# A dictionary where keys denote the output to be gathered across context parallel region, and the
# value denotes the gathering configuration.
ContextParallelOutputType = Union[
    ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
]

# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
# the module should be split/gathered across context parallel region.
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
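
As an illustration of how these aliases compose, a hypothetical plan for a model with root-level sequence inputs and a `proj_out` head could look like this (module names are placeholders; the Flux `_cp_plan` added later in this commit is the real example):

example_cp_plan: ContextParallelModelPlan = {
    # Split these forward() arguments of the root module along the sequence dimension.
    "": {
        "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3),
        "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3),
    },
    # Gather this submodule's output along the sequence dimension at the end.
    "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}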

@@ -17,9 +17,10 @@ import functools
import inspect
import math
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.distributed._functional_collectives as funcol

from ..utils import (
    get_logger,
@@ -38,15 +39,22 @@ from ..utils import (
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


if TYPE_CHECKING:
    from ._modeling_parallel import ParallelConfig


logger = get_logger(__name__)  # pylint: disable=invalid-name


if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
else:
    logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
    flash_attn_func = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None
    _flash_attn_backward = None


if is_flash_attn_3_available():
@@ -104,6 +112,27 @@ else:
    xops = None


if torch.__version__ >= "2.4.0":
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:

    def _custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    def _register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    _custom_op = _custom_op_no_op
    _register_fake = _register_fake_no_op


# TODO(aryan): Add support for the following:
# - Sage Attention++
# - block sparse, radial and other attention methods
@@ -154,17 +183,25 @@ class _AttentionBackendRegistry:
    _backends = {}
    _constraints = {}
    _supported_arg_names = {}
    _supports_context_parallel = {}
    _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
    _checks_enabled = DIFFUSERS_ATTN_CHECKS
    _parallel_config: Optional["ParallelConfig"] = None

    @classmethod
    def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
    def register(
        cls,
        backend: AttentionBackendName,
        constraints: Optional[List[Callable]] = None,
        supports_context_parallel: bool = False,
    ):
        logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")

        def decorator(func):
            cls._backends[backend] = func
            cls._constraints[backend] = constraints or []
            cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
            cls._supports_context_parallel[backend] = supports_context_parallel
            return func

        return decorator
@@ -177,6 +214,17 @@ class _AttentionBackendRegistry:
    def list_backends(cls):
        return list(cls._backends.keys())

    @classmethod
    def _is_context_parallel_enabled(cls, backend: AttentionBackendName) -> bool:
        if backend not in cls._supports_context_parallel:
            raise ValueError(f"Backend {backend} is not registered.")
        supports_context_parallel = cls._supports_context_parallel[backend]
        is_degree_greater_than_1 = _AttentionBackendRegistry._parallel_config is not None and (
            _AttentionBackendRegistry._parallel_config.ring_degree > 1
            or _AttentionBackendRegistry._parallel_config.ulysses_degree > 1
        )
        return supports_context_parallel and is_degree_greater_than_1


@contextlib.contextmanager
def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
@@ -195,6 +243,20 @@ def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIV
        _AttentionBackendRegistry._active_backend = old_backend


@contextlib.contextmanager
def _parallel_context(parallel_config: "ParallelConfig"):
    """
    Context manager to set the parallel configuration for attention backends that support it.
    """
    old_parallel_config = _AttentionBackendRegistry._parallel_config
    _AttentionBackendRegistry._parallel_config = parallel_config

    try:
        yield
    finally:
        _AttentionBackendRegistry._parallel_config = old_parallel_config


def dispatch_attention_fn(
    query: torch.Tensor,
    key: torch.Tensor,
@@ -218,6 +280,14 @@ def dispatch_attention_fn(
    backend_name = AttentionBackendName(backend)
    backend_fn = _AttentionBackendRegistry._backends.get(backend_name)

    if (
        _AttentionBackendRegistry._parallel_config is not None
        and not _AttentionBackendRegistry._is_context_parallel_enabled(backend_name)
    ):
        raise ValueError(
            f"Backend {backend_name} does not support context parallelism, but a parallel configuration is provided."
        )

    kwargs = {
        "query": query,
        "key": key,
@@ -415,20 +485,398 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# TODO: library.custom_op and register_fake probably need version guards?
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    qv: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    out, lse = flash_attn_3_func(query, key, value)
    # Hardcoded for now because pytorch does not support tuple/int type hints
    window_size = (-1, -1)
    out, lse, *_ = flash_attn_3_func(
        q=q,
        k=k,
        v=v,
        softmax_scale=softmax_scale,
        causal=causal,
        qv=qv,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        window_size=window_size,
        attention_chunk=attention_chunk,
        softcap=softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        deterministic=deterministic,
        sm_margin=sm_margin,
    )
    lse = lse.permute(0, 2, 1)
    return out, lse


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    batch_size, seq_len, num_heads, head_dim = query.shape
@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
def _(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    qv: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    window_size = (-1, -1)  # noqa: F841
    # A lot of the parameters here are not yet used in any way within diffusers.
    # We can safely ignore for now and keep the fake op shape propagation simple.
    batch_size, seq_len, num_heads, head_dim = q.shape
    lse_shape = (batch_size, seq_len, num_heads)
    return torch.empty_like(query), query.new_empty(lse_shape)
    return torch.empty_like(q), q.new_empty(lse_shape)


# ===== Autograd functions =====


class _cudnn_attention(torch.autograd.Function):
    # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
    # forward declaration:
    # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
    # backward declaration:
    # aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: Optional[float] = None,
        is_causal: bool = False,
        enable_gqa: bool = False,
        return_lse: bool = False,
    ):
        if enable_gqa:
            raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")

        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask

        # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
        # if the input tensors are not contiguous.
        query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
            torch.ops.aten._scaled_dot_product_cudnn_attention(
                query=query,
                key=key,
                value=value,
                attn_bias=attn_mask,
                compute_log_sumexp=return_lse,
                dropout_p=dropout_p,
                is_causal=is_causal,
                return_debug_mask=False,
                scale=scale,
            )
        )

        ctx.max_q = max_q
        ctx.max_k = max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)

        out = out.transpose(1, 2).contiguous()
        if lse is not None:
            lse = lse.transpose(1, 2).contiguous()
        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
        grad_out = grad_out.transpose(1, 2).contiguous()

        # Cannot pass first 5 arguments as kwargs because: https://github.com/pytorch/pytorch/blob/d26ca5de058dbcf56ac52bb43e84dd98df2ace97/torch/_dynamo/variables/torch.py#L1341
        grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
            grad_out,
            query,
            key,
            value,
            out,
            logsumexp=lse,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            attn_bias=ctx.attn_mask,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=ctx.max_q,
            max_k=ctx.max_k,
            dropout_p=ctx.dropout_p,
            is_causal=ctx.is_causal,
            scale=ctx.scale,
        )
        grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value))

        return grad_query, grad_key, grad_value, None, None, None, None, None


# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
class _flash_attention_2(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: Optional[float] = None,
        is_causal: bool = False,
        enable_gqa: bool = False,
        return_lse: bool = False,
    ):
        if attn_mask is not None:
            raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
        if enable_gqa:
            raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")

        # Hardcoded for now
        window_size = (-1, -1)
        softcap = 0.0
        alibi_slopes = None
        deterministic = False

        if scale is None:
            scale = query.shape[-1] ** (-0.5)

        # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround.
        parallel_config = _AttentionBackendRegistry._parallel_config
        if query.requires_grad or (parallel_config is not None and parallel_config.world_size > 1):
            dropout_p = dropout_p if dropout_p > 0 else 1e-30

        ctx.dropout_p = dropout_p
        ctx.scale = scale
        ctx.is_causal = is_causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

        out, lse, S_dmask, rng_state = _flash_attn_forward(
            query,
            key,
            value,
            dropout_p,
            scale,
            is_causal,
            window_size[0],
            window_size[1],
            softcap,
            alibi_slopes,
            return_lse,
        )

        ctx.save_for_backward(query, key, value, out, lse, rng_state)

        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        query, key, value, out, lse, rng_state = ctx.saved_tensors
        grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)

        lse_d = _flash_attn_backward(  # noqa: F841
            grad_out,
            query,
            key,
            value,
            out,
            lse,
            grad_query,
            grad_key,
            grad_value,
            ctx.dropout_p,
            ctx.scale,
            ctx.is_causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state,
        )

        # Head dimension may have been padded
        grad_query = grad_query[..., : grad_out.shape[-1]]
        grad_key = grad_key[..., : grad_out.shape[-1]]
        grad_value = grad_value[..., : grad_out.shape[-1]]

        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None


# ===== Context parallel =====


class TemplatedRingAttention(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        return_lse: bool,
        op: torch.autograd.Function,
    ):
        parallel_config = _AttentionBackendRegistry._parallel_config
        ring_mesh = parallel_config._ring_mesh
        rank = parallel_config._ring_local_rank
        world_size = parallel_config.ring_degree

        next_rank = (rank + 1) % world_size
        prev_out = prev_lse = None

        kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
        kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
        kv_buffer = kv_buffer.chunk(world_size)

        for i in range(world_size):
            if i > 0:
                kv = kv_buffer[next_rank]
                key = kv[: key.numel()].reshape_as(key)
                value = kv[key.numel() :].reshape_as(value)
                next_rank = (next_rank + 1) % world_size

            out, lse = op.apply(query, key, value, None, 0.0, None, False, False, True)

            if parallel_config.convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)

            lse = lse.unsqueeze(-1)
            if prev_out is not None:
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse

        out = out.to(query.dtype)
        lse = lse.squeeze(-1)
        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        raise NotImplementedError("Backward pass is not implemented for TemplatedRingAttention.")


class TemplatedUlyssesAttention(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        return_lse: bool,
        op: torch.autograd.Function,
    ):
        parallel_config = _AttentionBackendRegistry._parallel_config
        ulysses_mesh = parallel_config._ulysses_mesh
        world_size = parallel_config.ulysses_degree
        group = ulysses_mesh.get_group()

        B, S_LOCAL, H, D = query.shape
        H_LOCAL = H // world_size
        query, key, value = (
            x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
            for x in (query, key, value)
        )
        query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
        query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))

        out = op.apply(query, key, value, None, 0.0, None, False, False, return_lse)
        if return_lse:
            out, lse, *_ = out

        out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        out = funcol.all_to_all_single(out, None, None, group=group).wait()
        out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()

        if return_lse:
            lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
            lse = funcol.all_to_all_single(lse, None, None, group=group).wait()
            lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
        else:
            lse = None

        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        raise NotImplementedError("Backward pass is not implemented for TemplatedUlyssesAttention.")


def _templated_context_parallel_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    *,
    op: torch.autograd.Function,
):
    if attn_mask is not None:
        raise ValueError("Attention mask is not yet supported for templated attention.")
    if is_causal:
        raise ValueError("Causal attention is not yet supported for templated attention.")
    if enable_gqa:
        raise ValueError("GQA is not yet supported for templated attention.")

    parallel_config = _AttentionBackendRegistry._parallel_config
    # TODO: add support for unified attention with ring/ulysses degree both being > 1
    if parallel_config.ring_degree > 1:
        return TemplatedRingAttention.apply(query, key, value, return_lse, op)
    elif parallel_config.ulysses_degree > 1:
        return TemplatedUlyssesAttention.apply(query, key, value, return_lse, op)
    else:
        return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)


# ===== Attention backends =====
@@ -445,11 +893,7 @@ def _flash_attention(
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_attn_probs: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    out = flash_attn_func(
        q=query,
@@ -458,11 +902,7 @@ def _flash_attention(
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
        return_attn_probs=return_lse,
    )
    return out

@@ -475,19 +915,11 @@ def _flash_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_attn_probs: bool = False,
    attn_mask: Optional[torch.Tensor] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape
@@ -495,16 +927,11 @@ def _flash_varlen_attention(
    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
            _prepare_for_flash_attn_or_sage_varlen(
                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
            )
    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
        _prepare_for_flash_attn_or_sage_varlen(
            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
        )
    else:
        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
    )

    key_valid, value_valid = [], []
    for b in range(batch_size):
@@ -527,11 +954,7 @@ def _flash_varlen_attention(
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
        return_attn_probs=return_lse,
    )
    out = out.unflatten(0, (batch_size, -1))

@@ -548,30 +971,16 @@ def _flash_attention_3(
    value: torch.Tensor,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    deterministic: bool = False,
    return_attn_probs: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    out, lse, *_ = flash_attn_3_func(
    out, lse = _wrapped_flash_attn_3(
        q=query,
        k=key,
        v=value,
        softmax_scale=scale,
        causal=is_causal,
        qv=None,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        window_size=window_size,
        attention_chunk=0,
        softcap=softcap,
        num_splits=1,
        pack_gqa=None,
        deterministic=deterministic,
        sm_margin=0,
    )
    return (out, lse) if return_attn_probs else out
    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
@@ -582,17 +991,10 @@ def _flash_varlen_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    attn_mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    deterministic: bool = False,
    return_attn_probs: bool = False,
    attn_mask: Optional[torch.Tensor] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape
@@ -600,16 +1002,11 @@ def _flash_varlen_attention_3(
    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
            _prepare_for_flash_attn_or_sage_varlen(
                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
            )
    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
        _prepare_for_flash_attn_or_sage_varlen(
            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
        )
    else:
        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
    )

    key_valid, value_valid = [], []
    for b in range(batch_size):
@@ -629,24 +1026,12 @@ def _flash_varlen_attention_3(
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        seqused_q=None,
        seqused_k=None,
        softmax_scale=scale,
        causal=is_causal,
        qv=None,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        window_size=window_size,
        softcap=softcap,
        num_splits=1,
        pack_gqa=None,
        deterministic=deterministic,
        sm_margin=0,
    )
    out = out.unflatten(0, (batch_size, -1))

    return (out, lse) if return_attn_probs else out
    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
@@ -662,7 +1047,6 @@ def _native_flex_attention(
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    # TODO: should we LRU cache the block mask creation?
    score_mod = None
@@ -707,7 +1091,6 @@ def _native_flex_attention(
        scale=scale,
        enable_gqa=enable_gqa,
        return_lse=return_lse,
        kernel_options=kernel_options,
    )
    out = out.permute(0, 2, 1, 3)
    return out
@@ -726,7 +1109,10 @@ def _native_attention(
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    if return_lse:
        raise ValueError("Native attention backend does not support setting `return_lse=True`.")
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    out = torch.nn.functional.scaled_dot_product_attention(
        query=query,
@@ -745,6 +1131,7 @@ def _native_attention(
@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_CUDNN,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_context_parallel=True,
)
def _native_cudnn_attention(
    query: torch.Tensor,
@@ -755,21 +1142,33 @@ def _native_cudnn_attention(
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
        out = torch.nn.functional.scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
    parallel_config = _AttentionBackendRegistry._parallel_config

    lse = None
    if parallel_config is None and not return_lse:
        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
            out = torch.nn.functional.scaled_dot_product_attention(
                query=query,
                key=key,
                value=value,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
                scale=scale,
                enable_gqa=enable_gqa,
            )
        out = out.permute(0, 2, 1, 3)
    else:
        out = _templated_context_parallel_attention(
            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op=_cudnn_attention
        )
        out = out.permute(0, 2, 1, 3)
    return out
    if return_lse:
        out, lse = out

    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
@@ -785,7 +1184,10 @@ def _native_efficient_attention(
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    if return_lse:
        raise ValueError("Native efficient attention backend does not support setting `return_lse=True`.")
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
        out = torch.nn.functional.scaled_dot_product_attention(
@@ -814,7 +1216,10 @@ def _native_flash_attention(
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    if return_lse:
        raise ValueError("Native flash attention backend does not support setting `return_lse=True`.")
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        out = torch.nn.functional.scaled_dot_product_attention(
@@ -844,7 +1249,10 @@ def _native_math_attention(
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    if return_lse:
        raise ValueError("Native math attention backend does not support setting `return_lse=True`.")
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
        out = torch.nn.functional.scaled_dot_product_attention(
@@ -871,7 +1279,10 @@ def _native_npu_attention(
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    if return_lse:
        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
    return npu_fusion_attention(
        query,
        key,
@@ -898,7 +1309,10 @@ def _native_xla_attention(
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    if return_lse:
        raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    query = query / math.sqrt(query.shape[-1])
    out = xla_flash_attention(
@@ -942,31 +1356,25 @@ def _sage_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    attn_mask: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    smooth_k: bool = True,
    attn_mask: Optional[torch.Tensor] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    if return_lse:
        raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")

    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
            _prepare_for_flash_attn_or_sage_varlen(
                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
            )
    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
        _prepare_for_flash_attn_or_sage_varlen(
            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
        )
    else:
        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
    )

    key_valid, value_valid = [], []
    for b in range(batch_size):
@@ -988,7 +1396,6 @@ def _sage_varlen_attention(
        max_seqlen_k=max_seqlen_k,
        is_causal=is_causal,
        sm_scale=scale,
        smooth_k=smooth_k,
    )
    out = out.unflatten(0, (batch_size, -1))
@@ -1005,10 +1412,6 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
|
||||
value: torch.Tensor,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
|
||||
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
|
||||
smooth_k: bool = True,
|
||||
smooth_v: bool = False,
|
||||
return_lse: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return sageattn_qk_int8_pv_fp8_cuda(
|
||||
@@ -1017,11 +1420,7 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
qk_quant_gran=qk_quant_gran,
|
||||
sm_scale=scale,
|
||||
pv_accum_dtype=pv_accum_dtype,
|
||||
smooth_k=smooth_k,
|
||||
smooth_v=smooth_v,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
|
||||
@@ -1036,9 +1435,6 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
|
||||
value: torch.Tensor,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
|
||||
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
|
||||
smooth_k: bool = True,
|
||||
return_lse: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return sageattn_qk_int8_pv_fp8_cuda_sm90(
|
||||
@@ -1047,10 +1443,7 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
qk_quant_gran=qk_quant_gran,
|
||||
sm_scale=scale,
|
||||
pv_accum_dtype=pv_accum_dtype,
|
||||
smooth_k=smooth_k,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
|
||||
@@ -1065,10 +1458,6 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
|
||||
value: torch.Tensor,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
|
||||
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
|
||||
smooth_k: bool = True,
|
||||
smooth_v: bool = False,
|
||||
return_lse: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return sageattn_qk_int8_pv_fp16_cuda(
|
||||
@@ -1077,11 +1466,7 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
qk_quant_gran=qk_quant_gran,
|
||||
sm_scale=scale,
|
||||
pv_accum_dtype=pv_accum_dtype,
|
||||
smooth_k=smooth_k,
|
||||
smooth_v=smooth_v,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
|
||||
@@ -1096,8 +1481,6 @@ def _sage_qk_int8_pv_fp16_triton_attention(
|
||||
value: torch.Tensor,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
|
||||
smooth_k: bool = True,
|
||||
return_lse: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return sageattn_qk_int8_pv_fp16_triton(
|
||||
@@ -1105,10 +1488,8 @@ def _sage_qk_int8_pv_fp16_triton_attention(
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
quantization_backend=quantization_backend,
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
smooth_k=smooth_k,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
|
||||
@@ -1126,7 +1507,11 @@ def _xformers_attention(
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if return_lse:
|
||||
raise ValueError("xformers attention backend does not support setting `return_lse=True`.")
|
||||
|
||||
batch_size, seq_len_q, num_heads_q, _ = query.shape
|
||||
_, seq_len_kv, num_heads_kv, _ = key.shape
|
||||
|
||||
|
||||
@@ -271,6 +271,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
    _skip_layerwise_casting_patterns = None
    _supports_group_offloading = True
    _repeated_blocks = []
    _cp_plan = None

    def __init__(self):
        super().__init__()
@@ -1492,6 +1493,52 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
            )

    def parallelize(self, *, ring_degree: int = 1, ulysses_degree: int = 1, cp_plan=None):
        from ..hooks.context_parallel import ParallelConfig, apply_context_parallel

        # TODO(aryan): add cp_plan type hint
        logger.warning(
            "`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
        )

        if not torch.distributed.is_initialized():
            raise RuntimeError("torch.distributed must be initialized before calling `parallelize`.")
        if ring_degree < 1 or ulysses_degree < 1:
            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
        if ring_degree > 1 and ulysses_degree > 1:
            raise ValueError(
                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
            )

        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

        if ring_degree * ulysses_degree > world_size:
            raise ValueError(
                f"The product of `ring_degree` ({ring_degree}) and `ulysses_degree` ({ulysses_degree}) must not exceed the world size ({world_size})."
            )

        device_type = torch._C._get_accelerator().type
        device_module = torch.get_device_module(device_type)
        device = torch.device(device_type, rank % device_module.device_count())

        cp_mesh = torch.distributed.device_mesh.init_device_mesh(
            device_type=device_type,
            mesh_shape=(ring_degree, ulysses_degree),
            mesh_dim_names=("ring", "ulysses"),
        )
        parallel_config = ParallelConfig(
            rank=rank,
            world_size=world_size,
            ring_degree=ring_degree,
            ulysses_degree=ulysses_degree,
            device=device,
            cp_mesh=cp_mesh,
        )
        cp_plan = cp_plan if cp_plan is not None else self._cp_plan

        apply_context_parallel(self, parallel_config, cp_plan)

    @classmethod
    def _load_pretrained_model(
        cls,
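
A hedged sketch of how this experimental API might be exercised end to end; the checkpoint id, backend, and degrees below are illustrative assumptions, and only the model's forward pass is context-parallelized:

# Illustrative usage, e.g. launched with: torchrun --nproc-per-node=2 run_cp.py
import torch
import torch.distributed as dist

from diffusers import FluxTransformer2DModel

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

# Uses the class-level `_cp_plan` unless an explicit plan is passed via `cp_plan=`.
transformer.parallelize(ring_degree=2)  # or ulysses_degree=2; both > 1 is not yet supported

# Every rank receives the same full-length inputs; the registered hooks shard the
# sequence dimension, run ring/Ulysses attention, and gather the `proj_out` output.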

@@ -25,6 +25,7 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
@@ -569,6 +570,15 @@ class FluxTransformer2DModel(
    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
    _cp_plan = {
        "": {
            "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
            "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
            "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
        },
        "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
    }

    @register_to_config
    def __init__(