diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 21e3390584..bf84183bcb 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -11,36 +11,817 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
-from ..utils import deprecate, logging
+
+# Import xformers only if it's available
+try:
+ import xformers
+ import xformers.ops
+except ImportError:
+ xformers = None
+
+from ..utils import logging
+from ..utils.import_utils import (
+ is_torch_npu_available,
+ is_torch_xla_available,
+ is_xformers_available,
+)
from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
-from .attention_processor import Attention, JointAttnProcessor2_0
-from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+from .attention_processor import (
+ AttentionProcessor,
+ AttnProcessor,
+)
+from .normalization import RMSNorm
logger = logging.get_logger(__name__)
-def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
- # "feed_forward_chunk_size" can be used to save memory
- if hidden_states.shape[chunk_dim] % chunk_size != 0:
- raise ValueError(
- f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
- )
+class AttentionMixin:
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
- num_chunks = hidden_states.shape[chunk_dim] // chunk_size
- ff_output = torch.cat(
- [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
- dim=chunk_dim,
- )
- return ff_output
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, AttentionModuleMixin):
+ module.fuse_projections(fuse=True)
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+
+class AttentionModuleMixin:
+ """
+ A mixin class that provides common methods for attention modules.
+
+ This mixin adds functionality to set different attention processors, handle attention masks, compute attention
+ scores, and manage projections.
+ """
+
+ # Default processor classes to be overridden by subclasses
+ default_processor_cls = None
+ _available_processors = []
+
+ fused_projections = False
+ is_cross_attention = False
+
+ def _get_compatible_processor(self, backend):
+ for processor_cls in self._available_processors:
+ if backend in processor_cls.compatible_backends:
+ processor = processor_cls()
+ return processor
+
+ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+ """
+ Set whether to use NPU flash attention from `torch_npu` or not.
+
+ Args:
+ use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
+ """
+ processor = self.default_processor_cls()
+
+ if use_npu_flash_attention:
+ if not is_torch_npu_available():
+ raise ImportError("torch_npu is not available")
+ processor = self._get_compatible_processor("npu")
+
+ self.set_processor(processor)
+
+ def set_use_xla_flash_attention(
+ self,
+ use_xla_flash_attention: bool,
+ partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+ is_flux=False,
+ ) -> None:
+ """
+ Set whether to use XLA flash attention from `torch_xla` or not.
+
+ Args:
+ use_xla_flash_attention (`bool`):
+ Whether to use pallas flash attention kernel from `torch_xla` or not.
+ partition_spec (`Tuple[]`, *optional*):
+ Specify the partition specification if using SPMD. Otherwise None.
+ is_flux (`bool`, *optional*, defaults to `False`):
+ Whether the model is a Flux model.
+ """
+ processor = self.default_processor_cls()
+ if use_xla_flash_attention:
+ if not is_torch_xla_available():
+ raise ImportError("torch_xla is not available")
+ processor = self._get_compatible_processor("xla")
+
+ self.set_processor(processor)
+
+ @torch.no_grad()
+ def fuse_projections(self, fuse=True):
+ """
+ Fuse the query, key, and value projections into a single projection for efficiency.
+
+ Args:
+ fuse (`bool`): Whether to fuse the projections or not.
+ """
+ # Skip if already in desired state
+ if getattr(self, "fused_projections", False) == fuse:
+ return
+
+ device = self.to_q.weight.data.device
+ dtype = self.to_q.weight.data.dtype
+
+ if not self.is_cross_attention:
+ # Fuse self-attention projections
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_qkv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ self.to_qkv.bias.copy_(concatenated_bias)
+
+ else:
+ # Fuse cross-attention key-value projections
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_kv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ self.to_kv.bias.copy_(concatenated_bias)
+
+ # Handle added projections for models like SD3, Flux, etc.
+ if (
+ getattr(self, "add_q_proj", None) is not None
+ and getattr(self, "add_k_proj", None) is not None
+ and getattr(self, "add_v_proj", None) is not None
+ ):
+ concatenated_weights = torch.cat(
+ [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+ )
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_added_qkv = nn.Linear(
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+ )
+ self.to_added_qkv.weight.copy_(concatenated_weights)
+ if self.added_proj_bias:
+ concatenated_bias = torch.cat(
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+ )
+ self.to_added_qkv.bias.copy_(concatenated_bias)
+
+ self.fused_projections = fuse
+
+ def set_use_memory_efficient_attention_xformers(
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ """
+ Set whether to use memory efficient attention from `xformers` or not.
+
+ Args:
+ use_memory_efficient_attention_xformers (`bool`):
+ Whether to use memory efficient attention from `xformers` or not.
+ attention_op (`Callable`, *optional*):
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
+ `xformers`.
+ """
+ if use_memory_efficient_attention_xformers:
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ if xformers is not None:
+ dtype = None
+ if attention_op is not None:
+ op_fw, op_bw = attention_op
+ dtype, *_ = op_fw.SUPPORTED_DTYPES
+ q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+ _ = xformers.ops.memory_efficient_attention(q, q, q)
+ except Exception as e:
+ raise e
+
+ processor = self._get_compatible_processor("xformers")
+ else:
+ # Set default processor
+ processor = self.default_processor_cls()
+
+ if processor is not None:
+ self.set_processor(processor)
+
+ def set_attention_slice(self, slice_size: int) -> None:
+ """
+ Set the slice size for attention computation.
+
+ Args:
+ slice_size (`int`):
+ The slice size for attention computation.
+ """
+ if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ processor = None
+
+ # Try to get a compatible processor for sliced attention
+ if slice_size is not None:
+ processor = self._get_compatible_processor("sliced")
+
+ # If no processor was found or slice_size is None, use default processor
+ if processor is None:
+ processor = self.default_processor_cls()
+
+ self.set_processor(processor)
+
+ def set_processor(self, processor: "AttnProcessor") -> None:
+ """
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+ """
+ Get the attention processor in use.
+
+ Args:
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to return the deprecated LoRA attention processor.
+
+ Returns:
+ "AttentionProcessor": The attention processor in use.
+ """
+ if not return_deprecated_lora:
+ return self.processor
+
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+ """
+ Reshape the tensor for multi-head attention processing.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ if tensor.ndim == 3:
+ batch_size, seq_len, dim = tensor.shape
+ extra_dim = 1
+ else:
+ batch_size, extra_dim, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+
+ if out_dim == 3:
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
+
+ return tensor
+
+ def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Compute the attention scores.
+
+ Args:
+ query (`torch.Tensor`): The query tensor.
+ key (`torch.Tensor`): The key tensor.
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
+
+ Returns:
+ `torch.Tensor`: The attention probabilities/scores.
+ """
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ """
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`): The attention mask to prepare.
+ target_length (`int`): The target length of the attention mask.
+ batch_size (`int`): The batch size for repeating the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`): Output dimension.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
+ """
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize the encoder hidden states.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+ Returns:
+ `torch.Tensor`: The normalized encoder hidden states.
+ """
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+ if isinstance(self.norm_cross, nn.LayerNorm):
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ elif isinstance(self.norm_cross, nn.GroupNorm):
+ # Group norm norms along the channels dimension and expects
+ # input to be in the shape of (N, C, *). In this case, we want
+ # to norm along the hidden dimension, so we need to move
+ # (batch_size, sequence_length, hidden_size) ->
+ # (batch_size, hidden_size, sequence_length)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ else:
+ assert False
+
+ return encoder_hidden_states
+
+
+@maybe_allow_in_graph
+class Attention(nn.Module, AttentionModuleMixin):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`):
+ The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8):
+ The number of heads to use for multi-head attention.
+ kv_heads (`int`, *optional*, defaults to `None`):
+ The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
+ `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
+ Query Attention (MQA) otherwise GQA is used.
+ dim_head (`int`, *optional*, defaults to 64):
+ The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the attention computation to `float32`.
+ upcast_softmax (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the softmax computation to `float32`.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the group norm in the cross attention.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ norm_num_groups (`int`, *optional*, defaults to `None`):
+ The number of groups to use for the group norm in the attention.
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the spatial normalization.
+ out_bias (`bool`, *optional*, defaults to `True`):
+ Set to `True` to use a bias in the output linear layer.
+ scale_qk (`bool`, *optional*, defaults to `True`):
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+ `added_kv_proj_dim` is not `None`.
+ eps (`float`, *optional*, defaults to 1e-5):
+ An additional value added to the denominator in group normalization that is used for numerical stability.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
+ A factor to rescale the output by dividing it with this value.
+ residual_connection (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add the residual connection to the output.
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+ Set to `True` if the attention block is loaded from a deprecated state dict.
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
+ The attention processor to use. If `None`, defaults to `AttnProcessorSDPA` if `torch 2.x` is used and
+ `AttnProcessor` otherwise.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ kv_heads: Optional[int] = None,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ cross_attention_norm_num_groups: int = 32,
+ qk_norm: Optional[str] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ norm_num_groups: Optional[int] = None,
+ spatial_norm_dim: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ _from_deprecated_attn_block: bool = False,
+ processor: Optional["AttnProcessor"] = None,
+ out_dim: int = None,
+ out_context_dim: int = None,
+ context_pre_only=None,
+ pre_only=False,
+ elementwise_affine: bool = True,
+ is_causal: bool = False,
+ ):
+ super().__init__()
+
+ # To prevent circular import.
+ from .normalization import FP32LayerNorm, LpNorm
+
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.is_cross_attention = cross_attention_dim is not None
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+ self.dropout = dropout
+ self.fused_projections = False
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+ self.is_causal = is_causal
+
+ # we make use of this private variable to know whether this class is loaded
+ # with an deprecated state dict so that we can convert it on the fly
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+ else:
+ self.group_norm = None
+
+ if spatial_norm_dim is not None:
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+ else:
+ self.spatial_norm = None
+
+ if qk_norm is None:
+ self.norm_q = None
+ self.norm_k = None
+ elif qk_norm == "layer_norm":
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ elif qk_norm == "fp32_layer_norm":
+ self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "layer_norm_across_heads":
+ # Lumina applies qk norm across all heads
+ self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
+ self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_q = RMSNorm(dim_head, eps=eps)
+ self.norm_k = RMSNorm(dim_head, eps=eps)
+ elif qk_norm == "rms_norm_across_heads":
+ # LTX applies qk norm across all heads
+ self.norm_q = RMSNorm(dim_head * heads, eps=eps)
+ self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
+ elif qk_norm == "l2":
+ self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
+ self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
+ else:
+ raise ValueError(
+ f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
+ )
+
+ if cross_attention_norm is None:
+ self.norm_cross = None
+ elif cross_attention_norm == "layer_norm":
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+ elif cross_attention_norm == "group_norm":
+ if self.added_kv_proj_dim is not None:
+ # The given `encoder_hidden_states` are initially of shape
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+ # before the projection, so we need to use `added_kv_proj_dim` as
+ # the number of channels for the group norm.
+ norm_cross_num_channels = added_kv_proj_dim
+ else:
+ norm_cross_num_channels = self.cross_attention_dim
+
+ self.norm_cross = nn.GroupNorm(
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+ )
+ else:
+ raise ValueError(
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+ )
+
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.only_cross_attention:
+ # only relevant for the `AddedKVProcessor` classes
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ else:
+ self.to_k = None
+ self.to_v = None
+
+ self.added_proj_bias = added_proj_bias
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ if self.context_pre_only is not None:
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ else:
+ self.add_q_proj = None
+ self.add_k_proj = None
+ self.add_v_proj = None
+
+ if not self.pre_only:
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
+ else:
+ self.to_out = None
+
+ if self.context_pre_only is not None and not self.context_pre_only:
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+ else:
+ self.to_add_out = None
+
+ if qk_norm is not None and added_kv_proj_dim is not None:
+ if qk_norm == "layer_norm":
+ self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ elif qk_norm == "fp32_layer_norm":
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
+ elif qk_norm == "rms_norm_across_heads":
+ # Wan applies qk norm across all heads
+ # Wan also doesn't apply a q norm
+ self.norm_added_q = None
+ self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
+ else:
+ raise ValueError(
+ f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
+ )
+ else:
+ self.norm_added_q = None
+ self.norm_added_k = None
+
+ # set attention processor
+ # We use the AttnProcessorSDPA by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ if processor is None:
+ processor = self.default_processor_class()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+ unused_kwargs = [
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+ ]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
@maybe_allow_in_graph
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
new file mode 100644
index 0000000000..c6c78a44a6
--- /dev/null
+++ b/src/diffusers/models/attention_dispatch.py
@@ -0,0 +1,1098 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import functools
+import inspect
+import math
+from enum import Enum
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+
+from ..utils import (
+ get_logger,
+ is_flash_attn_3_available,
+ is_flash_attn_available,
+ is_flash_attn_version,
+ is_sageattention_available,
+ is_sageattention_version,
+ is_torch_npu_available,
+ is_torch_version,
+ is_torch_xla_available,
+ is_torch_xla_version,
+ is_xformers_available,
+ is_xformers_version,
+)
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+else:
+ logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
+ flash_attn_func = None
+ flash_attn_varlen_func = None
+
+
+if is_flash_attn_3_available():
+ from flash_attn_interface import flash_attn_func as flash_attn_3_func
+ from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+else:
+ flash_attn_3_func = None
+ flash_attn_3_varlen_func = None
+
+
+if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
+ from sageattention import (
+ sageattn,
+ sageattn_qk_int8_pv_fp8_cuda,
+ sageattn_qk_int8_pv_fp8_cuda_sm90,
+ sageattn_qk_int8_pv_fp16_cuda,
+ sageattn_qk_int8_pv_fp16_triton,
+ sageattn_varlen,
+ )
+else:
+ logger.warning(
+ "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
+ )
+ sageattn = None
+ sageattn_qk_int8_pv_fp16_cuda = None
+ sageattn_qk_int8_pv_fp16_triton = None
+ sageattn_qk_int8_pv_fp8_cuda = None
+ sageattn_qk_int8_pv_fp8_cuda_sm90 = None
+ sageattn_varlen = None
+
+
+if is_torch_version(">=", "2.5.0"):
+ # We cannot import the flex_attention function from the package directly because it is expected (from the
+ # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
+ # compiled function.
+ import torch.nn.attention.flex_attention as flex_attention
+
+
+if is_torch_npu_available():
+ from torch_npu import npu_fusion_attention
+else:
+ npu_fusion_attention = None
+
+
+if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
+ from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+else:
+ xla_flash_attention = None
+
+
+if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
+ import xformers.ops as xops
+else:
+ logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
+ xops = None
+
+
+_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
+_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
+_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
+
+
+class AttentionBackendName(str, Enum):
+ # EAGER = "eager"
+
+ # `flash-attn`
+ FLASH = "flash"
+ FLASH_VARLEN = "flash_varlen"
+ _FLASH_3 = "_flash_3"
+ _FLASH_VARLEN_3 = "_flash_varlen_3"
+
+ # PyTorch native
+ FLEX = "flex"
+ NATIVE = "native"
+ _NATIVE_CUDNN = "_native_cudnn"
+ _NATIVE_EFFICIENT = "_native_efficient"
+ _NATIVE_FLASH = "_native_flash"
+ _NATIVE_MATH = "_native_math"
+ _NATIVE_NPU = "_native_npu"
+ _NATIVE_XLA = "_native_xla"
+
+ # `sageattention`
+ SAGE = "sage"
+ SAGE_VARLEN = "sage_varlen"
+ _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
+ _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
+ _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
+ _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
+ # TODO: let's not add support for Sparge Attention now because it requires tuning per model
+ # We can look into supporting something "autotune"-ing in the future
+ # SPARGE = "sparge"
+
+ # `xformers`
+ XFORMERS = "xformers"
+
+
+class _AttentionBackendRegistry:
+ _backends = {}
+ _constraints = {}
+ _supported_arg_names = {}
+ _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
+ _checks_enabled = DIFFUSERS_ATTN_CHECKS
+
+ @classmethod
+ def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
+ logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
+
+ def decorator(func):
+ cls._backends[backend] = func
+ cls._constraints[backend] = constraints or []
+ cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
+ return func
+
+ return decorator
+
+ @classmethod
+ def get_active_backend(cls):
+ return cls._active_backend, cls._backends[cls._active_backend]
+
+ @classmethod
+ def list_backends(cls):
+ return list(cls._backends.keys())
+
+
+@contextlib.contextmanager
+def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
+ """
+ Context manager to set the active attention backend.
+ """
+ if backend not in _AttentionBackendRegistry._backends:
+ raise ValueError(f"Backend {backend} is not registered.")
+
+ old_backend = _AttentionBackendRegistry._active_backend
+ _AttentionBackendRegistry._active_backend = backend
+
+ try:
+ yield
+ finally:
+ _AttentionBackendRegistry._active_backend = old_backend
+
+
+def dispatch_attention_fn(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ *,
+ backend: Optional[AttentionBackendName] = None,
+) -> torch.Tensor:
+ attention_kwargs = attention_kwargs or {}
+
+ if backend is None:
+ # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment
+ # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager
+ backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend()
+ else:
+ backend_name = AttentionBackendName(backend)
+ backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
+
+ kwargs = {
+ "query": query,
+ "key": key,
+ "value": value,
+ "attn_mask": attn_mask,
+ "dropout_p": dropout_p,
+ "is_causal": is_causal,
+ "scale": scale,
+ "enable_gqa": enable_gqa,
+ **attention_kwargs,
+ }
+
+ if _AttentionBackendRegistry._checks_enabled:
+ removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
+ if removed_kwargs:
+ logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.")
+ for check in _AttentionBackendRegistry._constraints.get(backend_name):
+ check(**kwargs)
+
+ kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
+ return backend_fn(**kwargs)
+
+
+# ===== Checks =====
+# A list of very simple functions to catch common errors quickly when debugging.
+
+
+def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None:
+ if attn_mask is not None:
+ raise ValueError("Attention mask must be None for this backend.")
+
+
+def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None:
+ if attn_mask is not None and is_causal:
+ raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")
+
+
+def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ if query.device != key.device or query.device != value.device:
+ raise ValueError("Query, key, and value must be on the same device.")
+ if query.dtype != key.dtype or query.dtype != value.dtype:
+ raise ValueError("Query, key, and value must have the same dtype.")
+
+
+def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_device(query, key, value)
+ if query.device.type != "cuda":
+ raise ValueError("Query, key, and value must be on a CUDA device.")
+
+
+def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
+ def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_device_cuda(query, key, value)
+ if torch.cuda.get_device_capability(query.device) < (major, minor):
+ raise ValueError(
+ f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
+ )
+
+ return check_device_cuda
+
+
+def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ if query.dtype != key.dtype:
+ raise ValueError("Query and key must have the same dtype.")
+ if query.dtype != value.dtype:
+ raise ValueError("Query and value must have the same dtype.")
+
+
+def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_qkv_dtype_match(query, key, value)
+ if query.dtype not in (torch.bfloat16, torch.float16):
+ raise ValueError("Query, key, and value must be either bfloat16 or float16.")
+
+
+def _check_shape(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+) -> None:
+ if query.shape[-1] != key.shape[-1]:
+ raise ValueError("Query and key must have the same last dimension.")
+ if query.shape[-2] != value.shape[-2]:
+ raise ValueError("Query and value must have the same second to last dimension.")
+ if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
+ raise ValueError("Attention mask must match the key's second to last dimension.")
+
+
+# ===== Helper functions =====
+
+
+@functools.lru_cache(maxsize=8)
+def _prepare_for_flash_attn_or_sage_varlen(
+ batch_size: int,
+ seq_len_q: int,
+ seq_len_kv: int,
+ attn_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+) -> None:
+ seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+ if attn_mask is None:
+ seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
+ else:
+ seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
+ cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+ cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+ max_seqlen_q = seqlens_q.max().item()
+ max_seqlen_k = seqlens_k.max().item()
+ return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
+
+def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
+ """
+ Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
+ FlashAttention/Sage varlen.
+
+ Supports 1D to 4D shapes and common broadcasting patterns.
+ """
+ if attn_mask.dtype != torch.bool:
+ raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")
+
+ if attn_mask.ndim == 1:
+ # [seq_len_k] -> broadcast across batch
+ attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 2:
+ # [batch_size, seq_len_k]. Maybe broadcast across batch
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
+ )
+ attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 3:
+ # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension
+ # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen.
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
+ )
+ attn_mask = attn_mask.any(dim=1)
+ attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 4:
+ # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
+ )
+ attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K]
+ attn_mask = attn_mask.any(dim=(1, 2)) # [B, K]
+
+ else:
+ raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}")
+
+ if attn_mask.shape != (batch_size, seq_len_k):
+ raise ValueError(
+ f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})"
+ )
+
+ return attn_mask
+
+
+def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+ return q_idx >= kv_idx
+
+
+# ===== Attention backends =====
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLASH,
+ constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ out = flash_attn_func(
+ q=query,
+ k=key,
+ v=value,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ window_size=window_size,
+ softcap=softcap,
+ alibi_slopes=alibi_slopes,
+ deterministic=deterministic,
+ return_attn_probs=return_attn_probs,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLASH_VARLEN,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_varlen_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_k: Optional[int] = None,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+ attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ batch_size, _, seq_len_q, _ = query.shape
+ _, _, seq_len_kv, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+ else:
+ seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+ cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+ cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out = flash_attn_varlen_func(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ window_size=window_size,
+ softcap=softcap,
+ alibi_slopes=alibi_slopes,
+ deterministic=deterministic,
+ return_attn_probs=return_attn_probs,
+ )
+ out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)
+
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._FLASH_3,
+ constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ out, lse, *_ = flash_attn_3_func(
+ q=query,
+ k=key,
+ v=value,
+ softmax_scale=scale,
+ causal=is_causal,
+ qv=None,
+ q_descale=None,
+ k_descale=None,
+ v_descale=None,
+ window_size=window_size,
+ attention_chunk=0,
+ softcap=softcap,
+ num_splits=1,
+ pack_gqa=None,
+ deterministic=deterministic,
+ sm_margin=0,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return (out, lse) if return_attn_probs else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._FLASH_VARLEN_3,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_varlen_attention_3(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_k: Optional[int] = None,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+ attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ batch_size, _, seq_len_q, _ = query.shape
+ _, _, seq_len_kv, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+ else:
+ seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+ cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+ cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out, lse, *_ = flash_attn_3_varlen_func(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ seqused_q=None,
+ seqused_k=None,
+ softmax_scale=scale,
+ causal=is_causal,
+ qv=None,
+ q_descale=None,
+ k_descale=None,
+ v_descale=None,
+ window_size=window_size,
+ softcap=softcap,
+ num_splits=1,
+ pack_gqa=None,
+ deterministic=deterministic,
+ sm_margin=0,
+ )
+ out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)
+
+ return (out, lse) if return_attn_probs else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLEX,
+ constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+)
+def _native_flex_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ kernel_options: Optional[Dict[str, Any]] = None,
+) -> torch.Tensor:
+ # TODO: should we LRU cache the block mask creation?
+ score_mod = None
+ block_mask = None
+ batch_size, num_heads, seq_len_q, _ = query.shape
+ _, _, seq_len_kv, _ = key.shape
+
+ if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask):
+ block_mask = attn_mask
+ elif is_causal:
+ block_mask = flex_attention.create_block_mask(
+ _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
+ )
+ elif torch.is_tensor(attn_mask):
+ if attn_mask.ndim == 2:
+ attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+
+ attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)
+
+ if attn_mask.dtype == torch.bool:
+ # TODO: this probably does not work but verify!
+ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+ return attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+
+ block_mask = flex_attention.create_block_mask(
+ mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
+ )
+ else:
+
+ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
+ return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+ else:
+ raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")
+
+ return flex_attention.flex_attention(
+ query=query,
+ key=key,
+ value=value,
+ score_mod=score_mod,
+ block_mask=block_mask,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ return_lse=return_lse,
+ kernel_options=kernel_options,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.NATIVE,
+ constraints=[_check_device, _check_shape],
+)
+def _native_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ return torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_CUDNN,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_cudnn_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+ return torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_EFFICIENT,
+ constraints=[_check_device, _check_shape],
+)
+def _native_efficient_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
+ return torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_FLASH,
+ constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_flash_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+ return torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_MATH,
+ constraints=[_check_device, _check_shape],
+)
+def _native_math_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
+ return torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_NPU,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_npu_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+) -> torch.Tensor:
+ return npu_fusion_attention(
+ query,
+ key,
+ value,
+ query.size(1), # num_heads
+ input_layout="BNSD",
+ pse=None,
+ scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+ pre_tockens=65536,
+ next_tokens=65536,
+ keep_prob=1.0 - dropout_p,
+ sync=False,
+ inner_precise=0,
+ )[0]
+
+
+# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_XLA,
+ constraints=[_check_device, _check_shape],
+)
+def _native_xla_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+) -> torch.Tensor:
+ query = query / math.sqrt(query.shape[-1])
+ return xla_flash_attention(
+ q=query,
+ k=key,
+ v=value,
+ causal=is_causal,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.SAGE,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _sage_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="HND",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.SAGE_VARLEN,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _sage_varlen_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_k: Optional[int] = None,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ smooth_k: bool = True,
+ attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ batch_size, _, seq_len_q, _ = query.shape
+ _, _, seq_len_kv, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+ else:
+ seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+ cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+ cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out = sageattn_varlen(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ is_causal=is_causal,
+ sm_scale=scale,
+ smooth_k=smooth_k,
+ )
+ out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)
+
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+ constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp8_cuda_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+ pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+ smooth_k: bool = True,
+ smooth_v: bool = False,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp8_cuda(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="HND",
+ is_causal=is_causal,
+ qk_quant_gran=qk_quant_gran,
+ sm_scale=scale,
+ pv_accum_dtype=pv_accum_dtype,
+ smooth_k=smooth_k,
+ smooth_v=smooth_v,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+ constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+ pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+ smooth_k: bool = True,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="HND",
+ is_causal=is_causal,
+ qk_quant_gran=qk_quant_gran,
+ sm_scale=scale,
+ pv_accum_dtype=pv_accum_dtype,
+ smooth_k=smooth_k,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+ constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp16_cuda_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+ pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
+ smooth_k: bool = True,
+ smooth_v: bool = False,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp16_cuda(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="HND",
+ is_causal=is_causal,
+ qk_quant_gran=qk_quant_gran,
+ sm_scale=scale,
+ pv_accum_dtype=pv_accum_dtype,
+ smooth_k=smooth_k,
+ smooth_v=smooth_v,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+ constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp16_triton_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
+ smooth_k: bool = True,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp16_triton(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="HND",
+ quantization_backend=quantization_backend,
+ is_causal=is_causal,
+ sm_scale=scale,
+ smooth_k=smooth_k,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.XFORMERS,
+ constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+)
+def _xformers_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ batch_size, num_heads_q, seq_len_q, _ = query.shape
+ _, num_heads_kv, seq_len_kv, _ = key.shape
+
+ # TODO: check if `contiguous` is really needed since it may cause unnecessary slowdowns
+ if is_causal:
+ attn_mask = xops.LowerTriangularMask()
+ elif attn_mask is not None:
+ if attn_mask.ndim == 2:
+ attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+ elif attn_mask.ndim != 4:
+ raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
+ attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
+
+ # QKV need to be in [batch, seq_len, num_heads, head_dim] format for xformers
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+
+ if enable_gqa:
+ if num_heads_q % num_heads_kv != 0:
+ raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
+ num_heads_per_group = num_heads_q // num_heads_kv
+ query = query.unflatten(2, (num_heads_kv, -1))
+ key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+ value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+
+ out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)
+ if enable_gqa:
+ out = out.flatten(2, 3)
+ out = out.permute(0, 2, 1, 3)
+ return out
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 23ae05e2ab..6af7734151 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import math
from typing import Callable, List, Optional, Tuple, Union
@@ -22,13 +21,13 @@ from torch import nn
from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, is_torch_xla_available, logging
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
-from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
+from .attention_dispatch import dispatch_attention_fn
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_torch_npu_available():
- import torch_npu
+ pass
if is_xformers_available():
import xformers
@@ -39,67 +38,15 @@ else:
if is_torch_xla_available():
# flash attention pallas kernel is introduced in the torch_xla 2.3 release.
if is_torch_xla_version(">", "2.2"):
- from torch_xla.experimental.custom_kernel import flash_attention
from torch_xla.runtime import is_spmd
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
-@maybe_allow_in_graph
-class Attention(nn.Module):
+class AttnProcessor:
r"""
- A cross attention layer.
-
- Parameters:
- query_dim (`int`):
- The number of channels in the query.
- cross_attention_dim (`int`, *optional*):
- The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
- heads (`int`, *optional*, defaults to 8):
- The number of heads to use for multi-head attention.
- kv_heads (`int`, *optional*, defaults to `None`):
- The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
- `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
- Query Attention (MQA) otherwise GQA is used.
- dim_head (`int`, *optional*, defaults to 64):
- The number of channels in each head.
- dropout (`float`, *optional*, defaults to 0.0):
- The dropout probability to use.
- bias (`bool`, *optional*, defaults to False):
- Set to `True` for the query, key, and value linear layers to contain a bias parameter.
- upcast_attention (`bool`, *optional*, defaults to False):
- Set to `True` to upcast the attention computation to `float32`.
- upcast_softmax (`bool`, *optional*, defaults to False):
- Set to `True` to upcast the softmax computation to `float32`.
- cross_attention_norm (`str`, *optional*, defaults to `None`):
- The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
- cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
- The number of groups to use for the group norm in the cross attention.
- added_kv_proj_dim (`int`, *optional*, defaults to `None`):
- The number of channels to use for the added key and value projections. If `None`, no projection is used.
- norm_num_groups (`int`, *optional*, defaults to `None`):
- The number of groups to use for the group norm in the attention.
- spatial_norm_dim (`int`, *optional*, defaults to `None`):
- The number of channels to use for the spatial normalization.
- out_bias (`bool`, *optional*, defaults to `True`):
- Set to `True` to use a bias in the output linear layer.
- scale_qk (`bool`, *optional*, defaults to `True`):
- Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
- only_cross_attention (`bool`, *optional*, defaults to `False`):
- Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
- `added_kv_proj_dim` is not `None`.
- eps (`float`, *optional*, defaults to 1e-5):
- An additional value added to the denominator in group normalization that is used for numerical stability.
- rescale_output_factor (`float`, *optional*, defaults to 1.0):
- A factor to rescale the output by dividing it with this value.
- residual_connection (`bool`, *optional*, defaults to `False`):
- Set to `True` to add the residual connection to the output.
- _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
- Set to `True` if the attention block is loaded from a deprecated state dict.
- processor (`AttnProcessor`, *optional*, defaults to `None`):
- The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
- `AttnProcessor` otherwise.
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(
@@ -927,421 +874,7 @@ class SanaMultiscaleLinearAttention(nn.Module):
return self.processor(self, hidden_states)
-class MochiAttention(nn.Module):
- def __init__(
- self,
- query_dim: int,
- added_kv_proj_dim: int,
- processor: "MochiAttnProcessor2_0",
- heads: int = 8,
- dim_head: int = 64,
- dropout: float = 0.0,
- bias: bool = False,
- added_proj_bias: bool = True,
- out_dim: Optional[int] = None,
- out_context_dim: Optional[int] = None,
- out_bias: bool = True,
- context_pre_only: bool = False,
- eps: float = 1e-5,
- ):
- super().__init__()
- from .normalization import MochiRMSNorm
-
- self.inner_dim = out_dim if out_dim is not None else dim_head * heads
- self.out_dim = out_dim if out_dim is not None else query_dim
- self.out_context_dim = out_context_dim if out_context_dim else query_dim
- self.context_pre_only = context_pre_only
-
- self.heads = out_dim // dim_head if out_dim is not None else heads
-
- self.norm_q = MochiRMSNorm(dim_head, eps, True)
- self.norm_k = MochiRMSNorm(dim_head, eps, True)
- self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
- self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
-
- self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
- self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
- self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
-
- self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
- self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
- if self.context_pre_only is not None:
- self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
-
- self.to_out = nn.ModuleList([])
- self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
- self.to_out.append(nn.Dropout(dropout))
-
- if not self.context_pre_only:
- self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
-
- self.processor = processor
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- **kwargs,
- ):
- return self.processor(
- self,
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- **kwargs,
- )
-
-
-class MochiAttnProcessor2_0:
- """Attention processor used in Mochi."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: "MochiAttention",
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- query = query.unflatten(2, (attn.heads, -1))
- key = key.unflatten(2, (attn.heads, -1))
- value = value.unflatten(2, (attn.heads, -1))
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- encoder_query = attn.add_q_proj(encoder_hidden_states)
- encoder_key = attn.add_k_proj(encoder_hidden_states)
- encoder_value = attn.add_v_proj(encoder_hidden_states)
-
- encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
- encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
- encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
-
- if attn.norm_added_q is not None:
- encoder_query = attn.norm_added_q(encoder_query)
- if attn.norm_added_k is not None:
- encoder_key = attn.norm_added_k(encoder_key)
-
- if image_rotary_emb is not None:
-
- def apply_rotary_emb(x, freqs_cos, freqs_sin):
- x_even = x[..., 0::2].float()
- x_odd = x[..., 1::2].float()
-
- cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
- sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
-
- return torch.stack([cos, sin], dim=-1).flatten(-2)
-
- query = apply_rotary_emb(query, *image_rotary_emb)
- key = apply_rotary_emb(key, *image_rotary_emb)
-
- query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
- encoder_query, encoder_key, encoder_value = (
- encoder_query.transpose(1, 2),
- encoder_key.transpose(1, 2),
- encoder_value.transpose(1, 2),
- )
-
- sequence_length = query.size(2)
- encoder_sequence_length = encoder_query.size(2)
- total_length = sequence_length + encoder_sequence_length
-
- batch_size, heads, _, dim = query.shape
- attn_outputs = []
- for idx in range(batch_size):
- mask = attention_mask[idx][None, :]
- valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
-
- valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
- valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
- valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
-
- valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
- valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
- valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
-
- attn_output = F.scaled_dot_product_attention(
- valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
- )
- valid_sequence_length = attn_output.size(2)
- attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
- attn_outputs.append(attn_output)
-
- hidden_states = torch.cat(attn_outputs, dim=0)
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-
- hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
- (sequence_length, encoder_sequence_length), dim=1
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if hasattr(attn, "to_add_out"):
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
-
-
-class AttnProcessor:
- r"""
- Default processor for performing attention-related computations.
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
- *args,
- **kwargs,
- ) -> torch.Tensor:
- if len(args) > 0 or kwargs.get("scale", None) is not None:
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
- deprecate("scale", "1.0.0", deprecation_message)
-
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- query = attn.head_to_batch_dim(query)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class CustomDiffusionAttnProcessor(nn.Module):
- r"""
- Processor for implementing attention for the Custom Diffusion method.
-
- Args:
- train_kv (`bool`, defaults to `True`):
- Whether to newly train the key and value matrices corresponding to the text features.
- train_q_out (`bool`, defaults to `True`):
- Whether to newly train query matrices corresponding to the latent image features.
- hidden_size (`int`, *optional*, defaults to `None`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`, *optional*, defaults to `None`):
- The number of channels in the `encoder_hidden_states`.
- out_bias (`bool`, defaults to `True`):
- Whether to include the bias parameter in `train_q_out`.
- dropout (`float`, *optional*, defaults to 0.0):
- The dropout probability to use.
- """
-
- def __init__(
- self,
- train_kv: bool = True,
- train_q_out: bool = True,
- hidden_size: Optional[int] = None,
- cross_attention_dim: Optional[int] = None,
- out_bias: bool = True,
- dropout: float = 0.0,
- ):
- super().__init__()
- self.train_kv = train_kv
- self.train_q_out = train_q_out
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
-
- # `_custom_diffusion` id for easy serialization and loading.
- if self.train_kv:
- self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- if self.train_q_out:
- self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
- self.to_out_custom_diffusion = nn.ModuleList([])
- self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
- self.to_out_custom_diffusion.append(nn.Dropout(dropout))
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- batch_size, sequence_length, _ = hidden_states.shape
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- if self.train_q_out:
- query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
- else:
- query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
-
- if encoder_hidden_states is None:
- crossattn = False
- encoder_hidden_states = hidden_states
- else:
- crossattn = True
- if attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- if self.train_kv:
- key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
- value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
- key = key.to(attn.to_q.weight.dtype)
- value = value.to(attn.to_q.weight.dtype)
- else:
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- if crossattn:
- detach = torch.ones_like(key)
- detach[:, :1, :] = detach[:, :1, :] * 0.0
- key = detach * key + (1 - detach) * key.detach()
- value = detach * value + (1 - detach) * value.detach()
-
- query = attn.head_to_batch_dim(query)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- if self.train_q_out:
- # linear proj
- hidden_states = self.to_out_custom_diffusion[0](hidden_states)
- # dropout
- hidden_states = self.to_out_custom_diffusion[1](hidden_states)
- else:
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- return hidden_states
-
-
class AttnAddedKVProcessor:
- r"""
- Processor for performing attention-related computations with extra learnable key and value matrices for the text
- encoder.
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- *args,
- **kwargs,
- ) -> torch.Tensor:
- if len(args) > 0 or kwargs.get("scale", None) is not None:
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
- deprecate("scale", "1.0.0", deprecation_message)
-
- residual = hidden_states
-
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
- batch_size, sequence_length, _ = hidden_states.shape
-
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
- query = attn.head_to_batch_dim(query)
-
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
-
- if not attn.only_cross_attention:
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
- else:
- key = encoder_hidden_states_key_proj
- value = encoder_hidden_states_value_proj
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
- hidden_states = hidden_states + residual
-
- return hidden_states
-
-
-class AttnAddedKVProcessor2_0:
r"""
Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
learnable key and value matrices for the text encoder.
@@ -1349,9 +882,7 @@ class AttnAddedKVProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
+ raise ImportError("AttnAddedKVProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
@@ -1417,12 +948,12 @@ class AttnAddedKVProcessor2_0:
return hidden_states
-class JointAttnProcessor2_0:
+class JointAttnProcessor:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ raise ImportError("JointAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
@@ -1503,14 +1034,12 @@ class JointAttnProcessor2_0:
return hidden_states
-class PAGJointAttnProcessor2_0:
+class PAGJointAttnProcessor:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
+ raise ImportError("PAGJointAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
@@ -1659,13 +1188,13 @@ class PAGJointAttnProcessor2_0:
return hidden_states, encoder_hidden_states
-class PAGCFGJointAttnProcessor2_0:
+class PAGCFGJointAttnProcessor:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
- "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ "PAGCFGJointAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
@@ -1824,85 +1353,6 @@ class PAGCFGJointAttnProcessor2_0:
return hidden_states, encoder_hidden_states
-class FusedJointAttnProcessor2_0:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- *args,
- **kwargs,
- ) -> torch.FloatTensor:
- residual = hidden_states
-
- input_ndim = hidden_states.ndim
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
- context_input_ndim = encoder_hidden_states.ndim
- if context_input_ndim == 4:
- batch_size, channel, height, width = encoder_hidden_states.shape
- encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size = encoder_hidden_states.shape[0]
-
- # `sample` projections.
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- # `context` projections.
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
- split_size = encoder_qkv.shape[-1] // 3
- (
- encoder_hidden_states_query_proj,
- encoder_hidden_states_key_proj,
- encoder_hidden_states_value_proj,
- ) = torch.split(encoder_qkv, split_size, dim=-1)
-
- # attention
- query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
- key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
- value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # Split the attention outputs.
- hidden_states, encoder_hidden_states = (
- hidden_states[:, : residual.shape[1]],
- hidden_states[:, residual.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- if not attn.context_pre_only:
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
- if context_input_ndim == 4:
- encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- return hidden_states, encoder_hidden_states
-
-
class XFormersJointAttnProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
@@ -1988,1575 +1438,16 @@ class XFormersJointAttnProcessor:
return hidden_states
-class AllegroAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
- used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector.
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # Apply RoPE if needed
- if image_rotary_emb is not None and not attn.is_cross_attention:
- from .embeddings import apply_rotary_emb_allegro
-
- query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
- key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class AuraFlowAttnProcessor2_0:
- """Attention processor used typically in processing Aura Flow."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
- raise ImportError(
- "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- *args,
- **kwargs,
- ) -> torch.FloatTensor:
- batch_size = hidden_states.shape[0]
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- # `context` projections.
- if encoder_hidden_states is not None:
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- # Reshape.
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
- query = query.view(batch_size, -1, attn.heads, head_dim)
- key = key.view(batch_size, -1, attn.heads, head_dim)
- value = value.view(batch_size, -1, attn.heads, head_dim)
-
- # Apply QK norm.
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # Concatenate the projections.
- if encoder_hidden_states is not None:
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- )
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- )
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
-
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
-
- query = query.transpose(1, 2)
- key = key.transpose(1, 2)
- value = value.transpose(1, 2)
-
- # Attention.
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
- )
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # Split the attention outputs.
- if encoder_hidden_states is not None:
- hidden_states, encoder_hidden_states = (
- hidden_states[:, encoder_hidden_states.shape[1] :],
- hidden_states[:, : encoder_hidden_states.shape[1]],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- if encoder_hidden_states is not None:
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- if encoder_hidden_states is not None:
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FusedAuraFlowAttnProcessor2_0:
- """Attention processor used typically in processing Aura Flow with fused projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
- raise ImportError(
- "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- *args,
- **kwargs,
- ) -> torch.FloatTensor:
- batch_size = hidden_states.shape[0]
-
- # `sample` projections.
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- # `context` projections.
- if encoder_hidden_states is not None:
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
- split_size = encoder_qkv.shape[-1] // 3
- (
- encoder_hidden_states_query_proj,
- encoder_hidden_states_key_proj,
- encoder_hidden_states_value_proj,
- ) = torch.split(encoder_qkv, split_size, dim=-1)
-
- # Reshape.
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
- query = query.view(batch_size, -1, attn.heads, head_dim)
- key = key.view(batch_size, -1, attn.heads, head_dim)
- value = value.view(batch_size, -1, attn.heads, head_dim)
-
- # Apply QK norm.
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # Concatenate the projections.
- if encoder_hidden_states is not None:
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- )
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- )
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
-
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
-
- query = query.transpose(1, 2)
- key = key.transpose(1, 2)
- value = value.transpose(1, 2)
-
- # Attention.
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
- )
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # Split the attention outputs.
- if encoder_hidden_states is not None:
- hidden_states, encoder_hidden_states = (
- hidden_states[:, encoder_hidden_states.shape[1] :],
- hidden_states[:, : encoder_hidden_states.shape[1]],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- if encoder_hidden_states is not None:
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- if encoder_hidden_states is not None:
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FluxAttnProcessor2_0:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FluxAttnProcessor2_0_NPU:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- if query.dtype in (torch.float16, torch.bfloat16):
- hidden_states = torch_npu.npu_fusion_attention(
- query,
- key,
- value,
- attn.heads,
- input_layout="BNSD",
- pse=None,
- scale=1.0 / math.sqrt(query.shape[-1]),
- pre_tockens=65536,
- next_tockens=65536,
- keep_prob=1.0,
- sync=False,
- inner_precise=0,
- )[0]
- else:
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FusedFluxAttnProcessor2_0:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- # `context` projections.
- if encoder_hidden_states is not None:
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
- split_size = encoder_qkv.shape[-1] // 3
- (
- encoder_hidden_states_query_proj,
- encoder_hidden_states_key_proj,
- encoder_hidden_states_value_proj,
- ) = torch.split(encoder_qkv, split_size, dim=-1)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FusedFluxAttnProcessor2_0_NPU:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- # `context` projections.
- if encoder_hidden_states is not None:
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
- split_size = encoder_qkv.shape[-1] // 3
- (
- encoder_hidden_states_query_proj,
- encoder_hidden_states_key_proj,
- encoder_hidden_states_value_proj,
- ) = torch.split(encoder_qkv, split_size, dim=-1)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- if query.dtype in (torch.float16, torch.bfloat16):
- hidden_states = torch_npu.npu_fusion_attention(
- query,
- key,
- value,
- attn.heads,
- input_layout="BNSD",
- pse=None,
- scale=1.0 / math.sqrt(query.shape[-1]),
- pre_tockens=65536,
- next_tockens=65536,
- keep_prob=1.0,
- sync=False,
- inner_precise=0,
- )[0]
- else:
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
- """Flux Attention processor for IP-Adapter."""
-
- def __init__(
- self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
- ):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
-
- if not isinstance(num_tokens, (tuple, list)):
- num_tokens = [num_tokens]
-
- if not isinstance(scale, list):
- scale = [scale] * len(num_tokens)
- if len(scale) != len(num_tokens):
- raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
- self.scale = scale
-
- self.to_k_ip = nn.ModuleList(
- [
- nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
- for _ in range(len(num_tokens))
- ]
- )
- self.to_v_ip = nn.ModuleList(
- [
- nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
- for _ in range(len(num_tokens))
- ]
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ip_hidden_states: Optional[List[torch.Tensor]] = None,
- ip_adapter_masks: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- hidden_states_query_proj = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- # IP-adapter
- ip_query = hidden_states_query_proj
- ip_attn_output = torch.zeros_like(hidden_states)
-
- for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
- ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
- ):
- ip_key = to_k_ip(current_ip_hidden_states)
- ip_value = to_v_ip(current_ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- current_ip_hidden_states = F.scaled_dot_product_attention(
- ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
- current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
- batch_size, -1, attn.heads * head_dim
- )
- current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
- ip_attn_output += scale * current_ip_hidden_states
-
- return hidden_states, encoder_hidden_states, ip_attn_output
- else:
- return hidden_states
-
-
-class CogVideoXAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
- query and key vectors, but does not include spatial normalization.
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
- batch_size, sequence_length, _ = hidden_states.shape
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # Apply RoPE if needed
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
- if not attn.is_cross_attention:
- key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states, hidden_states = hidden_states.split(
- [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
- )
- return hidden_states, encoder_hidden_states
-
-
-class FusedCogVideoXAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
- query and key vectors, but does not include spatial normalization.
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # Apply RoPE if needed
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
- if not attn.is_cross_attention:
- key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states, hidden_states = hidden_states.split(
- [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
- )
- return hidden_states, encoder_hidden_states
-
-
-class XFormersAttnAddedKVProcessor:
- r"""
- Processor for implementing memory efficient attention using xFormers.
-
- Args:
- attention_op (`Callable`, *optional*, defaults to `None`):
- The base
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
- use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
- operator.
- """
-
- def __init__(self, attention_op: Optional[Callable] = None):
- self.attention_op = attention_op
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- residual = hidden_states
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
- batch_size, sequence_length, _ = hidden_states.shape
-
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
- query = attn.head_to_batch_dim(query)
-
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
-
- if not attn.only_cross_attention:
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
- else:
- key = encoder_hidden_states_key_proj
- value = encoder_hidden_states_value_proj
-
- hidden_states = xformers.ops.memory_efficient_attention(
- query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
- )
- hidden_states = hidden_states.to(query.dtype)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
- hidden_states = hidden_states + residual
-
- return hidden_states
-
-
-class XFormersAttnProcessor:
- r"""
- Processor for implementing memory efficient attention using xFormers.
-
- Args:
- attention_op (`Callable`, *optional*, defaults to `None`):
- The base
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
- use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
- operator.
- """
-
- def __init__(self, attention_op: Optional[Callable] = None):
- self.attention_op = attention_op
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
- *args,
- **kwargs,
- ) -> torch.Tensor:
- if len(args) > 0 or kwargs.get("scale", None) is not None:
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
- deprecate("scale", "1.0.0", deprecation_message)
-
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, key_tokens, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
- if attention_mask is not None:
- # expand our mask's singleton query_tokens dimension:
- # [batch*heads, 1, key_tokens] ->
- # [batch*heads, query_tokens, key_tokens]
- # so that it can be added as a bias onto the attention scores that xformers computes:
- # [batch*heads, query_tokens, key_tokens]
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
- _, query_tokens, _ = hidden_states.shape
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
-
- hidden_states = xformers.ops.memory_efficient_attention(
- query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
- )
- hidden_states = hidden_states.to(query.dtype)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class AttnProcessorNPU:
- r"""
- Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
- fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
- not significant.
-
- """
-
- def __init__(self):
- if not is_torch_npu_available():
- raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
- *args,
- **kwargs,
- ) -> torch.Tensor:
- if len(args) > 0 or kwargs.get("scale", None) is not None:
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
- deprecate("scale", "1.0.0", deprecation_message)
-
- residual = hidden_states
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
- attention_mask = attention_mask.repeat(1, 1, hidden_states.shape[1], 1)
- if attention_mask.dtype == torch.bool:
- attention_mask = torch.logical_not(attention_mask.bool())
- else:
- attention_mask = attention_mask.bool()
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- if query.dtype in (torch.float16, torch.bfloat16):
- hidden_states = torch_npu.npu_fusion_attention(
- query,
- key,
- value,
- attn.heads,
- input_layout="BNSD",
- pse=None,
- atten_mask=attention_mask,
- scale=1.0 / math.sqrt(query.shape[-1]),
- pre_tockens=65536,
- next_tockens=65536,
- keep_prob=1.0,
- sync=False,
- inner_precise=0,
- )[0]
- else:
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class AttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
- *args,
- **kwargs,
- ) -> torch.Tensor:
- if len(args) > 0 or kwargs.get("scale", None) is not None:
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
- deprecate("scale", "1.0.0", deprecation_message)
-
- residual = hidden_states
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class XLAFlashAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
- """
-
- def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
- if is_torch_xla_version("<", "2.3"):
- raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
- if is_spmd() and is_torch_xla_version("<", "2.4"):
- raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
- self.partition_spec = partition_spec
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
- *args,
- **kwargs,
- ) -> torch.Tensor:
- residual = hidden_states
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
- if attention_mask is not None:
- attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
- # Convert mask to float and replace 0s with -inf and 1s with 0
- attention_mask = (
- attention_mask.float()
- .masked_fill(attention_mask == 0, float("-inf"))
- .masked_fill(attention_mask == 1, float(0.0))
- )
-
- # Apply attention mask to key
- key = key + attention_mask
- query /= math.sqrt(query.shape[3])
- partition_spec = self.partition_spec if is_spmd() else None
- hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
- else:
- logger.warning(
- "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
- )
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class XLAFluxFlashAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
- """
-
- def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
- if is_torch_xla_version("<", "2.3"):
- raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
- if is_spmd() and is_torch_xla_version("<", "2.4"):
- raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
- self.partition_spec = partition_spec
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- query /= math.sqrt(head_dim)
- hidden_states = flash_attention(query, key, value, causal=False)
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class MochiVaeAttnProcessor2_0:
+class MochiVaeAttnProcessor:
r"""
Attention processor used in Mochi VAE.
"""
+ _attention_backend = None
+
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ raise ImportError("AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
@@ -3614,8 +1505,14 @@ class MochiVaeAttnProcessor2_0:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=attn.is_causal,
+ backend=self._attention_backend,
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -3634,7 +1531,7 @@ class MochiVaeAttnProcessor2_0:
return hidden_states
-class StableAudioAttnProcessor2_0:
+class StableAudioAttnProcessor:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
@@ -3643,7 +1540,7 @@ class StableAudioAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
- "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ "StableAudioAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def apply_partial_rotary_emb(
@@ -3767,105 +1664,7 @@ class StableAudioAttnProcessor2_0:
return hidden_states
-class HunyuanAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
- used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- from .embeddings import apply_rotary_emb
-
- residual = hidden_states
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # Apply RoPE if needed
- if image_rotary_emb is not None:
- query = apply_rotary_emb(query, image_rotary_emb)
- if not attn.is_cross_attention:
- key = apply_rotary_emb(key, image_rotary_emb)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class FusedHunyuanAttnProcessor2_0:
+class FusedHunyuanAttnProcessor:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
projection layers. This is used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on
@@ -3875,7 +1674,7 @@ class FusedHunyuanAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
- "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ "FusedHunyuanAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
@@ -3968,7 +1767,7 @@ class FusedHunyuanAttnProcessor2_0:
return hidden_states
-class PAGHunyuanAttnProcessor2_0:
+class PAGHunyuanAttnProcessor:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
@@ -3978,7 +1777,7 @@ class PAGHunyuanAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
- "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ "PAGHunyuanAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
@@ -4091,7 +1890,7 @@ class PAGHunyuanAttnProcessor2_0:
return hidden_states
-class PAGCFGHunyuanAttnProcessor2_0:
+class PAGCFGHunyuanAttnProcessor:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
@@ -4101,7 +1900,7 @@ class PAGCFGHunyuanAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
- "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ "PAGCFGHunyuanAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
@@ -4215,7 +2014,7 @@ class PAGCFGHunyuanAttnProcessor2_0:
return hidden_states
-class LuminaAttnProcessor2_0:
+class LuminaAttnProcessor:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
@@ -4223,7 +2022,7 @@ class LuminaAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ raise ImportError("AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
@@ -4311,7 +2110,7 @@ class LuminaAttnProcessor2_0:
return hidden_states
-class FusedAttnProcessor2_0:
+class FusedAttnProcessor:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
@@ -4327,7 +2126,7 @@ class FusedAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
- "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
+ "FusedAttnProcessor requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
)
def __call__(
@@ -4533,7 +2332,7 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
return hidden_states
-class CustomDiffusionAttnProcessor2_0(nn.Module):
+class CustomDiffusionAttnProcessor(nn.Module):
r"""
Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled
dot-product attention.
@@ -4855,207 +2654,7 @@ class SpatialNorm(nn.Module):
return new_f
-class IPAdapterAttnProcessor(nn.Module):
- r"""
- Attention processor for Multiple IP-Adapters.
-
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
- The context length of the image features.
- scale (`float` or List[`float`], defaults to 1.0):
- the weight scale of image prompt.
- """
-
- def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
- super().__init__()
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
-
- if not isinstance(num_tokens, (tuple, list)):
- num_tokens = [num_tokens]
- self.num_tokens = num_tokens
-
- if not isinstance(scale, list):
- scale = [scale] * len(num_tokens)
- if len(scale) != len(num_tokens):
- raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
- self.scale = scale
-
- self.to_k_ip = nn.ModuleList(
- [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
- )
- self.to_v_ip = nn.ModuleList(
- [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
- scale: float = 1.0,
- ip_adapter_masks: Optional[torch.Tensor] = None,
- ):
- residual = hidden_states
-
- # separate ip_hidden_states from encoder_hidden_states
- if encoder_hidden_states is not None:
- if isinstance(encoder_hidden_states, tuple):
- encoder_hidden_states, ip_hidden_states = encoder_hidden_states
- else:
- deprecation_message = (
- "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
- " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
- )
- deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- [encoder_hidden_states[:, end_pos:, :]],
- )
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- query = attn.head_to_batch_dim(query)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- if ip_adapter_masks is not None:
- if not isinstance(ip_adapter_masks, List):
- # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
- ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
- if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
- raise ValueError(
- f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
- f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
- f"({len(ip_hidden_states)})"
- )
- else:
- for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
- if mask is None:
- continue
- if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
- raise ValueError(
- "Each element of the ip_adapter_masks array should be a tensor with shape "
- "[1, num_images_for_ip_adapter, height, width]."
- " Please use `IPAdapterMaskProcessor` to preprocess your mask"
- )
- if mask.shape[1] != ip_state.shape[1]:
- raise ValueError(
- f"Number of masks ({mask.shape[1]}) does not match "
- f"number of ip images ({ip_state.shape[1]}) at index {index}"
- )
- if isinstance(scale, list) and not len(scale) == mask.shape[1]:
- raise ValueError(
- f"Number of masks ({mask.shape[1]}) does not match "
- f"number of scales ({len(scale)}) at index {index}"
- )
- else:
- ip_adapter_masks = [None] * len(self.scale)
-
- # for ip-adapter
- for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
- ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
- ):
- skip = False
- if isinstance(scale, list):
- if all(s == 0 for s in scale):
- skip = True
- elif scale == 0:
- skip = True
- if not skip:
- if mask is not None:
- if not isinstance(scale, list):
- scale = [scale] * mask.shape[1]
-
- current_num_images = mask.shape[1]
- for i in range(current_num_images):
- ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
- ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
-
- ip_key = attn.head_to_batch_dim(ip_key)
- ip_value = attn.head_to_batch_dim(ip_value)
-
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
- _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
- _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
-
- mask_downsample = IPAdapterMaskProcessor.downsample(
- mask[:, i, :, :],
- batch_size,
- _current_ip_hidden_states.shape[1],
- _current_ip_hidden_states.shape[2],
- )
-
- mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-
- hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
- else:
- ip_key = to_k_ip(current_ip_hidden_states)
- ip_value = to_v_ip(current_ip_hidden_states)
-
- ip_key = attn.head_to_batch_dim(ip_key)
- ip_value = attn.head_to_batch_dim(ip_value)
-
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
- current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
- current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
-
- hidden_states = hidden_states + scale * current_ip_hidden_states
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class IPAdapterAttnProcessor2_0(torch.nn.Module):
+class IPAdapterAttnProcessor(torch.nn.Module):
r"""
Attention processor for IP-Adapter for PyTorch 2.0.
@@ -5519,7 +3118,7 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
return hidden_states
-class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
+class SD3IPAdapterJointAttnProcessor(torch.nn.Module):
"""
Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
additional image-based information and timestep embeddings.
@@ -5690,7 +3289,7 @@ class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
return hidden_states
-class PAGIdentitySelfAttnProcessor2_0:
+class PAGIdentitySelfAttnProcessor:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
PAG reference: https://huggingface.co/papers/2403.17377
@@ -5789,7 +3388,7 @@ class PAGIdentitySelfAttnProcessor2_0:
return hidden_states
-class PAGCFGIdentitySelfAttnProcessor2_0:
+class PAGCFGIdentitySelfAttnProcessor:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
PAG reference: https://huggingface.co/papers/2403.17377
@@ -5892,7 +3491,7 @@ class PAGCFGIdentitySelfAttnProcessor2_0:
return hidden_states
-class SanaMultiscaleAttnProcessor2_0:
+class SanaMultiscaleAttnProcessor:
r"""
Processor for implementing multiscale quadratic attention.
"""
@@ -5961,7 +3560,7 @@ class LoRAAttnProcessor:
pass
-class LoRAAttnProcessor2_0:
+class LoRAAttnProcessor:
r"""
Processor for implementing attention with LoRA (enabled by default if you're using PyTorch 2.0).
"""
@@ -5988,18 +3587,7 @@ class LoRAAttnAddedKVProcessor:
pass
-class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
- r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
- """
-
- def __init__(self):
- deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
- deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
- super().__init__()
-
-
-class SanaLinearAttnProcessor2_0:
+class SanaLinearAttnProcessor:
r"""
Processor for implementing scaled dot-product linear attention.
"""
@@ -6051,7 +3639,7 @@ class SanaLinearAttnProcessor2_0:
return hidden_states
-class PAGCFGSanaLinearAttnProcessor2_0:
+class PAGCFGSanaLinearAttnProcessor:
r"""
Processor for implementing scaled dot-product linear attention.
"""
@@ -6106,7 +3694,7 @@ class PAGCFGSanaLinearAttnProcessor2_0:
return hidden_states
-class PAGIdentitySanaLinearAttnProcessor2_0:
+class PAGIdentitySanaLinearAttnProcessor:
r"""
Processor for implementing scaled dot-product linear attention.
"""
@@ -6163,73 +3751,368 @@ class PAGIdentitySanaLinearAttnProcessor2_0:
return hidden_states
+# Deprecated classes for backward compatibility
+
+
+class AttnProcessor:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "`AttnProcessor` is deprecated and this will be removed in a future version. Please use `AttnProcessor`"
+ )
+ deprecate("AttnProcessor", "1.0.0", deprecation_message)
+
+ return AttnProcessor(*args, **kwargs)
+
+
+class AttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "`AttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnProcessor`"
+ )
+ deprecate("AttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return AttnProcessor(*args, **kwargs)
+
+
+class AttnAddedKVProcessor:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessor`"
+ deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message)
+
+ return AttnAddedKVProcessor(*args, **kwargs)
+
+
+class AttnAddedKVProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessor`"
+ deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message)
+
+ return AttnAddedKVProcessor(*args, **kwargs)
+
+
+class AllegroAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`AllegroAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AllegroAttnProcessor`"
+ deprecate("AllegroAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return AllegroAttnProcessor(*args, **kwargs)
+
+
+class AuraFlowAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`AuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessor`"
+ deprecate("AuraFlowAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return AuraFlowAttnProcessor(*args, **kwargs)
+
+
+class MochiAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`MochiAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `MochiAttnProcessor`"
+ deprecate("MochiAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_mochi import MochiAttnProcessor
+
+ return MochiAttnProcessor(*args, **kwargs)
+
+
+class MochiVaeAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`MochiVaeAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `MochiVaeAttnProcessor`"
+ deprecate("MochiVaeAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .autoencoders.autoencoder_kl_mochi import MochiVaeAttnProcessor
+
+ return MochiVaeAttnProcessor(*args, **kwargs)
+
+
+class FluxAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+ deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxSingleAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
+ deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_allegro import FluxAttnProcessor
+
+ return FluxAttnProcessor(*args, **kwargs)
+
+
+class FusedAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FusedAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnProcessor`"
+ deprecate("FusedAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return AttnProcessor(*args, **kwargs)
+
+
+class JointAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`JointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `JointAttnProcessor`"
+ deprecate("JointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return JointAttnProcessor(*args, **kwargs)
+
+
+class PAGJointAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`PAGJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGJointAttnProcessor`"
+ deprecate("PAGJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return PAGJointAttnProcessor(*args, **kwargs)
+
+
+class PAGCFGJointAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`PAGCFGJointAttnProcessor2_0 is deprecated and this will be removed in a future version. Please use `PAGCFGJointAttnProcessor`"
+ deprecate("PAGCFGJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return PAGCFGJointAttnProcessor(*args, **kwargs)
+
+
+class FusedJointAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FusedJointAttnProcessor2_0 is deprecated and this will be removed in a future version. Please use `JointAttnProcessor`"
+ deprecate("FusedJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return JointAttnProcessor(*args, **kwargs)
+
+
+class FusedAuraFlowAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FusedAuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessor`"
+ deprecate("FusedAuraFlowAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return AuraFlowAttnProcessor(*args, **kwargs)
+
+
+class FusedFluxAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+ deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ return FluxAttnProcessor(*args, **kwargs)
+
+
+class CogVideoXAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`CogVideoXAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `CogVideoXAttnProcessor`"
+ deprecate("CogVideoXAttnProcessor2_0", "1.0.0", deprecation_message)
+ from .transformers.cogvideox_transformer_3d import CogVideoXAttnProcessor
+
+ return CogVideoXAttnProcessor(*args, **kwargs)
+
+
+class FusedCogVideoXAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FusedCogVideoXAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `CogVideoXAttnProcessor`"
+ deprecate("FusedCogVideoXAttnProcessor2_0", "1.0.0", deprecation_message)
+ from .transformers.cogvideox_transformer_3d import CogVideoXAttnProcessor
+
+ return CogVideoXAttnProcessor(*args, **kwargs)
+
+
+class XLAFlashAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`XLAFlashAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFlashAttnProcessor`"
+ deprecate("XLAFlashAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return XLAFlashAttnProcessor(*args, **kwargs)
+
+
+class XLAFluxFlashAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`XLAFluxFlashAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFluxFlashAttnProcessor`"
+ deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from transformers.transformer_flux import FluxAttnProcessorXLA
+
+ return FluxAttnProcessorXLA(*args, **kwargs)
+
+
+class StableAudioAttnProcessor2_0:
+ def __new__(self, *args, **kwargs):
+ deprecation_message = "`StableAudioAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `StableAudioAttnProcessor`"
+ deprecate("StableAudioAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return StableAudioAttnProcessor(*args, **kwargs)
+
+
+class HunyuanAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`HunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `HunyuanAttnProcessor`"
+ deprecate("HunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return HunyuanAttnProcessor(*args, **kwargs)
+
+
+class FusedHunyuanAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FusedHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedHunyuanAttnProcessor`"
+ deprecate("FusedHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return HunyuanAttnProcessor(*args, **kwargs)
+
+
+class PAGHunyuanAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`PAGHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGHunyuanAttnProcessor`"
+ deprecate("PAGHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return PAGHunyuanAttnProcessor(*args, **kwargs)
+
+
+class PAGCFGHunyuanAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`PAGCFGHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGHunyuanAttnProcessor`"
+ deprecate("PAGCFGHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return PAGCFGHunyuanAttnProcessor(*args, **kwargs)
+
+
+class LuminaAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`LuminaAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `LuminaAttnProcessor`"
+ deprecate("LuminaAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return LuminaAttnProcessor(*args, **kwargs)
+
+
+class PAGIdentitySelfAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`PAGIdentitySelfAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySelfAttnProcessor`"
+ deprecate("PAGIdentitySelfAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return PAGIdentitySelfAttnProcessor(*args, **kwargs)
+
+
+class PAGCFGIdentitySelfAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`PAGCFGIdentitySelfAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGIdentitySelfAttnProcessor`"
+ deprecate("PAGCFGIdentitySelfAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return PAGCFGIdentitySelfAttnProcessor(*args, **kwargs)
+
+
+class SanaMultiscaleAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`SanaMultiscaleAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaMultiscaleAttnProcessor`"
+ deprecate("SanaMultiscaleAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return SanaMultiscaleAttnProcessor(*args, **kwargs)
+
+
+class LoRAAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`LoRAAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `LoRAAttnProcessor`"
+ deprecate("LoRAAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return LoRAAttnProcessor(*args, **kwargs)
+
+
+class SanaLinearAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`SanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaLinearAttnProcessor`"
+ deprecate("SanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return SanaLinearAttnProcessor(*args, **kwargs)
+
+
+class PAGCFGSanaLinearAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`PAGCFGSanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGSanaLinearAttnProcessor`"
+ deprecate("PAGCFGSanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return PAGCFGSanaLinearAttnProcessor(*args, **kwargs)
+
+
+class PAGIdentitySanaLinearAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`PAGIdentitySanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySanaLinearAttnProcessor`"
+ deprecate("PAGIdentitySanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return PAGIdentitySanaLinearAttnProcessor(*args, **kwargs)
+
+
+class IPAdapterAttnProcessor(IPAdapterAttnProcessor):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "`IPAdapterAttnProcessor` is deprecated and this will be removed in a future version. Please use `IPAdapterAttnProcessor`"
+ deprecate("IPAdapterAttnProcessor", "1.0.0", deprecation_message)
+ super().__init__(*args, **kwargs)
+
+
+class IPAdapterAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`IPAdapterAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `IPAdapterAttnProcessor`"
+ deprecate("IPAdapterAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ return IPAdapterAttnProcessor(*args, **kwargs)
+
+
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
- AttnAddedKVProcessor2_0,
+ AttnAddedKVProcessor,
XFormersAttnAddedKVProcessor,
)
CROSS_ATTENTION_PROCESSORS = (
AttnProcessor,
- AttnProcessor2_0,
+ AttnProcessor,
XFormersAttnProcessor,
SlicedAttnProcessor,
IPAdapterAttnProcessor,
- IPAdapterAttnProcessor2_0,
- FluxIPAdapterJointAttnProcessor2_0,
+ IPAdapterAttnProcessor,
)
AttentionProcessor = Union[
- AttnProcessor,
- CustomDiffusionAttnProcessor,
AttnAddedKVProcessor,
- AttnAddedKVProcessor2_0,
- JointAttnProcessor2_0,
- PAGJointAttnProcessor2_0,
- PAGCFGJointAttnProcessor2_0,
- FusedJointAttnProcessor2_0,
- AllegroAttnProcessor2_0,
- AuraFlowAttnProcessor2_0,
- FusedAuraFlowAttnProcessor2_0,
- FluxAttnProcessor2_0,
- FluxAttnProcessor2_0_NPU,
- FusedFluxAttnProcessor2_0,
- FusedFluxAttnProcessor2_0_NPU,
- CogVideoXAttnProcessor2_0,
- FusedCogVideoXAttnProcessor2_0,
+ JointAttnProcessor,
+ PAGJointAttnProcessor,
+ PAGCFGJointAttnProcessor,
+ FusedJointAttnProcessor,
+ FusedAuraFlowAttnProcessor,
+ CogVideoXAttnProcessor,
+ FusedCogVideoXAttnProcessor,
XFormersAttnAddedKVProcessor,
XFormersAttnProcessor,
- XLAFlashAttnProcessor2_0,
+ XLAFlashAttnProcessor,
AttnProcessorNPU,
- AttnProcessor2_0,
- MochiVaeAttnProcessor2_0,
- MochiAttnProcessor2_0,
- StableAudioAttnProcessor2_0,
- HunyuanAttnProcessor2_0,
- FusedHunyuanAttnProcessor2_0,
- PAGHunyuanAttnProcessor2_0,
- PAGCFGHunyuanAttnProcessor2_0,
- LuminaAttnProcessor2_0,
- FusedAttnProcessor2_0,
+ AttnProcessor,
+ MochiVaeAttnProcessor,
+ StableAudioAttnProcessor,
+ FusedHunyuanAttnProcessor,
+ PAGHunyuanAttnProcessor,
+ PAGCFGHunyuanAttnProcessor,
+ LuminaAttnProcessor,
+ FusedAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
- CustomDiffusionAttnProcessor2_0,
+ CustomDiffusionAttnProcessor,
SlicedAttnProcessor,
SlicedAttnAddedKVProcessor,
- SanaLinearAttnProcessor2_0,
- PAGCFGSanaLinearAttnProcessor2_0,
- PAGIdentitySanaLinearAttnProcessor2_0,
- SanaMultiscaleLinearAttention,
- SanaMultiscaleAttnProcessor2_0,
- SanaMultiscaleAttentionProjection,
+ SanaLinearAttnProcessor,
+ PAGCFGSanaLinearAttnProcessor,
+ PAGIdentitySanaLinearAttnProcessor,
+ SanaMultiscaleAttnProcessor,
IPAdapterAttnProcessor,
- IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
- SD3IPAdapterJointAttnProcessor2_0,
- PAGIdentitySelfAttnProcessor2_0,
- PAGCFGIdentitySelfAttnProcessor2_0,
+ SD3IPAdapterJointAttnProcessor,
+ PAGIdentitySelfAttnProcessor,
+ PAGCFGIdentitySelfAttnProcessor,
LoRAAttnProcessor,
- LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
]
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 541576b13b..4db23b4e84 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -13,25 +13,19 @@
# limitations under the License.
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
+import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
-from ...utils.import_utils import is_torch_npu_available
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import (
- Attention,
- AttentionProcessor,
- FluxAttnProcessor2_0,
- FluxAttnProcessor2_0_NPU,
- FusedFluxAttnProcessor2_0,
-)
+from ..attention import AttentionMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -42,6 +36,270 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+class FluxAttnProcessor:
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
+
+ def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ encoder_projections = None
+ if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+ encoder_projections = (encoder_query, encoder_key, encoder_value)
+
+ return query, key, value, encoder_projections
+
+ def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ encoder_projections = None
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+ split_size = encoder_qkv.shape[-1] // 3
+ encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
+ encoder_projections = (encoder_query, encoder_key, encoder_value)
+
+ return query, key, value, encoder_projections
+
+ def get_qkv_projections(self, attn, hidden_states, encoder_hidden_states=None):
+ """Public method to get projections based on whether we're using fused mode or not."""
+ if attn.is_fused and hasattr(attn, "to_qkv"):
+ return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
+
+ return self._get_projections(attn, hidden_states, encoder_hidden_states)
+
+ def __call__(
+ self,
+ attn: "FluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ query, key, value, encoder_projections = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ if encoder_projections is not None:
+ encoder_query, encoder_key, encoder_value = encoder_projections
+ encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_query = attn.norm_added_q(encoder_query)
+ if attn.norm_added_k is not None:
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ # Concatenate for joint attention
+ query = torch.cat([encoder_query, query], dim=2)
+ key = torch.cat([encoder_key, key], dim=2)
+ value = torch.cat([encoder_value, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class FluxIPAdapterAttnProcessorSDPA(torch.nn.Module):
+ """Flux Attention processor for IP-Adapter."""
+
+ def __init__(
+ self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
+ ):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ if not isinstance(num_tokens, (tuple, list)):
+ num_tokens = [num_tokens]
+
+ if not isinstance(scale, list):
+ scale = [scale] * len(num_tokens)
+ if len(scale) != len(num_tokens):
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+ self.scale = scale
+
+ self.to_k_ip = nn.ModuleList(
+ [
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+ for _ in range(len(num_tokens))
+ ]
+ )
+ self.to_v_ip = nn.ModuleList(
+ [
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+ for _ in range(len(num_tokens))
+ ]
+ )
+
+ def __call__(
+ self,
+ attn: "FluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ip_hidden_states: Optional[List[torch.Tensor]] = None,
+ ip_adapter_masks: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ # `sample` projections.
+ hidden_states_query_proj = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ if encoder_hidden_states is not None:
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ # IP-adapter
+ ip_query = hidden_states_query_proj
+ ip_attn_output = torch.zeros_like(hidden_states)
+
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+ ):
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ current_ip_hidden_states = F.scaled_dot_product_attention(
+ ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+ current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+ ip_attn_output += scale * current_ip_hidden_states
+
+ return hidden_states, encoder_hidden_states, ip_attn_output
+ else:
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class FluxAttention(nn.Module, Attention):
+ _default_processor_cls = FluxAttnProcessor
+ _available_processors = [
+ FluxAttnProcessor,
+ FluxIPAdapterAttnProcessor,
+ ]
+
+
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
@@ -53,27 +311,12 @@ class FluxSingleTransformerBlock(nn.Module):
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
- if is_torch_npu_available():
- deprecation_message = (
- "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
- "should be set explicitly using the `set_attn_processor` method."
- )
- deprecate("npu_processor", "0.34.0", deprecation_message)
- processor = FluxAttnProcessor2_0_NPU()
- else:
- processor = FluxAttnProcessor2_0()
-
- self.attn = Attention(
+ self.attn = FluxAttention(
query_dim=dim,
- cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
- out_dim=dim,
+ dropout=0.0,
bias=True,
- processor=processor,
- qk_norm="rms_norm",
- eps=1e-6,
- pre_only=True,
)
def forward(
@@ -113,18 +356,15 @@ class FluxTransformerBlock(nn.Module):
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
- self.attn = Attention(
+ # Use specialized FluxAttention instead of generic Attention
+ self.attn = FluxAttention(
query_dim=dim,
cross_attention_dim=None,
- added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
- out_dim=dim,
- context_pre_only=False,
+ dropout=0.0,
bias=True,
- processor=FluxAttnProcessor2_0(),
- qk_norm=qk_norm,
- eps=eps,
+ added_kv_proj_dim=dim,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -191,7 +431,13 @@ class FluxTransformerBlock(nn.Module):
class FluxTransformer2DModel(
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+ ModelMixin,
+ ConfigMixin,
+ PeftAdapterMixin,
+ FromOriginalModelMixin,
+ FluxTransformer2DLoadersMixin,
+ CacheMixin,
+ AttentionMixin,
):
"""
The Transformer model introduced in Flux.
@@ -286,105 +532,9 @@ class FluxTransformer2DModel(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
+ # Using inherited methods from AttentionMixin
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
+ # Using inherited methods from AttentionMixin
def forward(
self,