diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 93b11c2b43..2bcbeb27e9 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -11,36 +11,822 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+import inspect
+from typing import Callable, Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
-from ..utils import deprecate, logging
+
+# Import xformers only if it's available
+try:
+ import xformers
+ import xformers.ops
+except ImportError:
+ xformers = None
+
+from ..utils import logging
+from ..utils.import_utils import (
+ is_torch_npu_available,
+ is_torch_xla_available,
+ is_xformers_available,
+)
from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
-from .attention_processor import Attention, JointAttnProcessor2_0
-from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+from .attention_processor import (
+    AttentionProcessor,
+    AttnProcessor,
+    AttnProcessor2_0,
+    SpatialNorm,
+)
+from .normalization import RMSNorm
logger = logging.get_logger(__name__)
-def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
- # "feed_forward_chunk_size" can be used to save memory
- if hidden_states.shape[chunk_dim] % chunk_size != 0:
- raise ValueError(
- f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
- )
+class AttentionMixin:
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight name.
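+
+        Example (illustrative sketch; `model` stands for any diffusers model that exposes this mixin):
+
+        ```py
+        >>> # `model` is a placeholder for a model that inherits `AttentionMixin`
+        >>> for name, processor in model.attn_processors.items():
+        ...     print(name, processor.__class__.__name__)
+        ```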
+ """
+ # set recursively
+ processors = {}
- num_chunks = hidden_states.shape[chunk_dim] // chunk_size
- ff_output = torch.cat(
- [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
- dim=chunk_dim,
- )
- return ff_output
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
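+        Example (illustrative sketch; `model` stands for any diffusers model that exposes this mixin):
+
+        ```py
+        >>> from diffusers.models.attention_processor import AttnProcessor
+
+        >>> # set a single processor for all attention layers ...
+        >>> model.set_attn_processor(AttnProcessor())
+        >>> # ... or pass a dict keyed by the processor's weight name
+        >>> model.set_attn_processor({name: AttnProcessor() for name in model.attn_processors})
+        ```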
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
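+        Example (illustrative sketch; `model` stands for any diffusers model that exposes this mixin):
+
+        ```py
+        >>> model.fuse_qkv_projections()
+        >>> # ... run inference ...
+        >>> model.unfuse_qkv_projections()
+        ```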
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, AttentionModuleMixin):
+ module.fuse_projections(fuse=True)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        This API is 🧪 experimental.
+        """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+
+class AttentionModuleMixin:
+ """
+ A mixin class that provides common methods for attention modules.
+
+ This mixin adds functionality to set different attention processors, handle attention masks, compute attention
+ scores, and manage projections.
+ """
+
+ # Default processor classes to be overridden by subclasses
+ default_processor_cls = None
+ _available_processors = []
+
+ def _get_compatible_processor(self, backend):
+ for processor_cls in self._available_processors:
+ if backend in processor_cls.compatible_backends:
+ processor = processor_cls()
+ return processor
+
+ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+ """
+ Set whether to use NPU flash attention from `torch_npu` or not.
+
+ Args:
+ use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
+ """
+ processor = self.default_processor_cls()
+
+ if use_npu_flash_attention:
+ if not is_torch_npu_available():
+ raise ImportError("torch_npu is not available")
+ processor = self._get_compatible_processor("npu")
+
+ self.set_processor(processor)
+
+ def set_use_xla_flash_attention(
+ self,
+ use_xla_flash_attention: bool,
+ partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+ is_flux=False,
+ ) -> None:
+ """
+ Set whether to use XLA flash attention from `torch_xla` or not.
+
+ Args:
+ use_xla_flash_attention (`bool`):
+ Whether to use pallas flash attention kernel from `torch_xla` or not.
+ partition_spec (`Tuple[]`, *optional*):
+ Specify the partition specification if using SPMD. Otherwise None.
+ is_flux (`bool`, *optional*, defaults to `False`):
+ Whether the model is a Flux model.
+ """
+ processor = self.default_processor_cls()
+ if use_xla_flash_attention:
+ if not is_torch_xla_available():
+ raise ImportError("torch_xla is not available")
+ processor = self._get_compatible_processor("xla")
+
+ self.set_processor(processor)
+
+ @torch.no_grad()
+ def fuse_projections(self, fuse=True):
+ """
+ Fuse the query, key, and value projections into a single projection for efficiency.
+
+ Args:
+ fuse (`bool`): Whether to fuse the projections or not.
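+
+        Example (illustrative sketch using the `Attention` module defined in this file):
+
+        ```py
+        >>> from diffusers.models.attention import Attention
+
+        >>> attn = Attention(query_dim=64, heads=4, dim_head=16)
+        >>> attn.fuse_projections()
+        >>> hasattr(attn, "to_qkv")
+        True
+        ```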
+ """
+ # Skip if already in desired state
+ if getattr(self, "fused_projections", False) == fuse:
+ return
+
+ device = self.to_q.weight.data.device
+ dtype = self.to_q.weight.data.dtype
+
+ if not self.is_cross_attention:
+ # Fuse self-attention projections
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_qkv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ self.to_qkv.bias.copy_(concatenated_bias)
+
+ else:
+ # Fuse cross-attention key-value projections
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_kv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ self.to_kv.bias.copy_(concatenated_bias)
+
+ # Handle added projections for models like SD3, Flux, etc.
+ if (
+ getattr(self, "add_q_proj", None) is not None
+ and getattr(self, "add_k_proj", None) is not None
+ and getattr(self, "add_v_proj", None) is not None
+ ):
+ concatenated_weights = torch.cat(
+ [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+ )
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_added_qkv = nn.Linear(
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+ )
+ self.to_added_qkv.weight.copy_(concatenated_weights)
+ if self.added_proj_bias:
+ concatenated_bias = torch.cat(
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+ )
+ self.to_added_qkv.bias.copy_(concatenated_bias)
+
+ self.fused_projections = fuse
+ self.processor.is_fused = fuse
+
+ def set_use_memory_efficient_attention_xformers(
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ """
+ Set whether to use memory efficient attention from `xformers` or not.
+
+ Args:
+ use_memory_efficient_attention_xformers (`bool`):
+ Whether to use memory efficient attention from `xformers` or not.
+ attention_op (`Callable`, *optional*):
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
+ `xformers`.
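+
+        Example (illustrative sketch; requires a CUDA device, an `xformers` install, and an
+        xformers-compatible processor registered in `_available_processors`):
+
+        ```py
+        >>> from diffusers.models.attention import Attention
+
+        >>> attn = Attention(query_dim=64, heads=4, dim_head=16).to("cuda")
+        >>> attn.set_use_memory_efficient_attention_xformers(True)
+        ```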
+ """
+ if use_memory_efficient_attention_xformers:
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ if xformers is not None:
+ dtype = None
+ if attention_op is not None:
+ op_fw, op_bw = attention_op
+ dtype, *_ = op_fw.SUPPORTED_DTYPES
+ q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+ _ = xformers.ops.memory_efficient_attention(q, q, q)
+ except Exception as e:
+ raise e
+
+ processor = self._get_compatible_processor("xformers")
+ else:
+ # Set default processor
+ processor = self.default_processor_cls()
+
+ if processor is not None:
+ self.set_processor(processor)
+
+ def set_attention_slice(self, slice_size: int) -> None:
+ """
+ Set the slice size for attention computation.
+
+ Args:
+ slice_size (`int`):
+ The slice size for attention computation.
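+
+        Example (illustrative sketch; falls back to the default processor if no sliced processor is
+        registered in `_available_processors`):
+
+        ```py
+        >>> from diffusers.models.attention import Attention
+
+        >>> attn = Attention(query_dim=64, heads=4, dim_head=16)
+        >>> attn.set_attention_slice(2)
+        ```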
+ """
+ if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ processor = None
+
+ # Try to get a compatible processor for sliced attention
+ if slice_size is not None:
+ processor = self._get_compatible_processor("sliced")
+
+ # If no processor was found or slice_size is None, use default processor
+ if processor is None:
+ processor = self.default_processor_cls()
+
+ self.set_processor(processor)
+
+ def set_processor(self, processor: "AttnProcessor") -> None:
+ """
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+ """
+ Get the attention processor in use.
+
+ Args:
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to return the deprecated LoRA attention processor.
+
+ Returns:
+ "AttentionProcessor": The attention processor in use.
+ """
+ if not return_deprecated_lora:
+ return self.processor
+
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+ """
+ Reshape the tensor for multi-head attention processing.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
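+
+        Example (illustrative sketch): with `heads=4`, a `(batch, seq_len, dim)` tensor becomes
+        `(batch * heads, seq_len, dim // heads)`:
+
+        ```py
+        >>> import torch
+        >>> from diffusers.models.attention import Attention
+
+        >>> attn = Attention(query_dim=64, heads=4, dim_head=16)
+        >>> attn.head_to_batch_dim(torch.randn(2, 8, 64)).shape
+        torch.Size([8, 8, 16])
+        ```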
+ """
+ head_size = self.heads
+ if tensor.ndim == 3:
+ batch_size, seq_len, dim = tensor.shape
+ extra_dim = 1
+ else:
+ batch_size, extra_dim, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+
+ if out_dim == 3:
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
+
+ return tensor
+
+ def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Compute the attention scores.
+
+ Args:
+ query (`torch.Tensor`): The query tensor.
+ key (`torch.Tensor`): The key tensor.
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
+
+ Returns:
+ `torch.Tensor`: The attention probabilities/scores.
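+
+        This is equivalent to `softmax(scale * query @ key.transpose(-1, -2) + attention_mask)`, computed with a
+        single `torch.baddbmm` call (the mask term is dropped when `attention_mask` is `None`). Illustrative sketch:
+
+        ```py
+        >>> import torch
+        >>> from diffusers.models.attention import Attention
+
+        >>> attn = Attention(query_dim=64, heads=4, dim_head=16)
+        >>> query = key = torch.randn(8, 10, 16)  # (batch * heads, seq_len, dim_head)
+        >>> attn.get_attention_scores(query, key).shape
+        torch.Size([8, 10, 10])
+        ```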
+ """
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ """
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`): The attention mask to prepare.
+ target_length (`int`): The target length of the attention mask.
+ batch_size (`int`): The batch size for repeating the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`): Output dimension.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
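+
+        Example (illustrative sketch): with `heads=4`, the mask is repeated once per head along the batch
+        dimension:
+
+        ```py
+        >>> import torch
+        >>> from diffusers.models.attention import Attention
+
+        >>> attn = Attention(query_dim=64, heads=4, dim_head=16)
+        >>> attn.prepare_attention_mask(torch.zeros(2, 1, 8), target_length=8, batch_size=2).shape
+        torch.Size([8, 1, 8])
+        ```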
+ """
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize the encoder hidden states.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+ Returns:
+ `torch.Tensor`: The normalized encoder hidden states.
+ """
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+ if isinstance(self.norm_cross, nn.LayerNorm):
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ elif isinstance(self.norm_cross, nn.GroupNorm):
+ # Group norm norms along the channels dimension and expects
+ # input to be in the shape of (N, C, *). In this case, we want
+ # to norm along the hidden dimension, so we need to move
+ # (batch_size, sequence_length, hidden_size) ->
+ # (batch_size, hidden_size, sequence_length)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ else:
+ assert False
+
+ return encoder_hidden_states
+
+
+@maybe_allow_in_graph
+class Attention(nn.Module, AttentionModuleMixin):
+    default_processor_cls = AttnProcessor2_0
+ _available_processors = []
+
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`):
+ The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8):
+ The number of heads to use for multi-head attention.
+ kv_heads (`int`, *optional*, defaults to `None`):
+ The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
+ `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
+ Query Attention (MQA) otherwise GQA is used.
+ dim_head (`int`, *optional*, defaults to 64):
+ The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the attention computation to `float32`.
+ upcast_softmax (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the softmax computation to `float32`.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the group norm in the cross attention.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ norm_num_groups (`int`, *optional*, defaults to `None`):
+ The number of groups to use for the group norm in the attention.
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the spatial normalization.
+ out_bias (`bool`, *optional*, defaults to `True`):
+ Set to `True` to use a bias in the output linear layer.
+ scale_qk (`bool`, *optional*, defaults to `True`):
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+ `added_kv_proj_dim` is not `None`.
+ eps (`float`, *optional*, defaults to 1e-5):
+ An additional value added to the denominator in group normalization that is used for numerical stability.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
+ A factor to rescale the output by dividing it with this value.
+ residual_connection (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add the residual connection to the output.
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+ Set to `True` if the attention block is loaded from a deprecated state dict.
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
+            The attention processor to use. If `None`, defaults to `AttnProcessor2_0`, which uses
+            `torch.nn.functional.scaled_dot_product_attention`.
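+
+    Example (illustrative sketch; a self-attention layer over 64-dim tokens):
+
+    ```py
+    >>> import torch
+    >>> from diffusers.models.attention import Attention
+
+    >>> attn = Attention(query_dim=64, heads=4, dim_head=16)
+    >>> hidden_states = torch.randn(2, 8, 64)
+    >>> attn(hidden_states).shape
+    torch.Size([2, 8, 64])
+    ```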
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ kv_heads: Optional[int] = None,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ cross_attention_norm_num_groups: int = 32,
+ qk_norm: Optional[str] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ norm_num_groups: Optional[int] = None,
+ spatial_norm_dim: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ _from_deprecated_attn_block: bool = False,
+ processor: Optional["AttnProcessor"] = None,
+ out_dim: int = None,
+ out_context_dim: int = None,
+ context_pre_only=None,
+ pre_only=False,
+ elementwise_affine: bool = True,
+ is_causal: bool = False,
+ ):
+ super().__init__()
+
+ # To prevent circular import.
+ from .normalization import FP32LayerNorm, LpNorm
+
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.is_cross_attention = cross_attention_dim is not None
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+ self.dropout = dropout
+ self.fused_projections = False
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+ self.is_causal = is_causal
+
+ # we make use of this private variable to know whether this class is loaded
+        # with a deprecated state dict so that we can convert it on the fly
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+ else:
+ self.group_norm = None
+
+ if spatial_norm_dim is not None:
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+ else:
+ self.spatial_norm = None
+
+ if qk_norm is None:
+ self.norm_q = None
+ self.norm_k = None
+ elif qk_norm == "layer_norm":
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ elif qk_norm == "fp32_layer_norm":
+ self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "layer_norm_across_heads":
+ # Lumina applies qk norm across all heads
+ self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
+ self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_q = RMSNorm(dim_head, eps=eps)
+ self.norm_k = RMSNorm(dim_head, eps=eps)
+ elif qk_norm == "rms_norm_across_heads":
+ # LTX applies qk norm across all heads
+ self.norm_q = RMSNorm(dim_head * heads, eps=eps)
+ self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
+ elif qk_norm == "l2":
+ self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
+ self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
+ else:
+ raise ValueError(
+ f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
+ )
+
+ if cross_attention_norm is None:
+ self.norm_cross = None
+ elif cross_attention_norm == "layer_norm":
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+ elif cross_attention_norm == "group_norm":
+ if self.added_kv_proj_dim is not None:
+ # The given `encoder_hidden_states` are initially of shape
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+ # before the projection, so we need to use `added_kv_proj_dim` as
+ # the number of channels for the group norm.
+ norm_cross_num_channels = added_kv_proj_dim
+ else:
+ norm_cross_num_channels = self.cross_attention_dim
+
+ self.norm_cross = nn.GroupNorm(
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+ )
+ else:
+ raise ValueError(
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+ )
+
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.only_cross_attention:
+ # only relevant for the `AddedKVProcessor` classes
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ else:
+ self.to_k = None
+ self.to_v = None
+
+ self.added_proj_bias = added_proj_bias
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ if self.context_pre_only is not None:
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ else:
+ self.add_q_proj = None
+ self.add_k_proj = None
+ self.add_v_proj = None
+
+ if not self.pre_only:
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
+ else:
+ self.to_out = None
+
+ if self.context_pre_only is not None and not self.context_pre_only:
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+ else:
+ self.to_add_out = None
+
+ if qk_norm is not None and added_kv_proj_dim is not None:
+ if qk_norm == "layer_norm":
+ self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ elif qk_norm == "fp32_layer_norm":
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
+ elif qk_norm == "rms_norm_across_heads":
+ # Wan applies qk norm across all heads
+ # Wan also doesn't apply a q norm
+ self.norm_added_q = None
+ self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
+ else:
+ raise ValueError(
+ f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
+ )
+ else:
+ self.norm_added_q = None
+ self.norm_added_k = None
+
+ # set attention processor
+        # We use AttnProcessor2_0 by default, which relies on torch.nn.functional.scaled_dot_product_attention
+        # for native flash/memory-efficient attention.
+        if processor is None:
+            processor = self.default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+ unused_kwargs = [
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+ ]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
@maybe_allow_in_graph
@@ -83,1169 +869,3 @@ class GatedSelfAttentionDense(nn.Module):
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
return x
-
-
-@maybe_allow_in_graph
-class JointTransformerBlock(nn.Module):
- r"""
- A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
-
- Reference: https://arxiv.org/abs/2403.03206
-
- Parameters:
- dim (`int`): The number of channels in the input and output.
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
- attention_head_dim (`int`): The number of channels in each head.
- context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
- processing of `context` conditions.
- """
-
- def __init__(
- self,
- dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- context_pre_only: bool = False,
- qk_norm: Optional[str] = None,
- use_dual_attention: bool = False,
- ):
- super().__init__()
-
- self.use_dual_attention = use_dual_attention
- self.context_pre_only = context_pre_only
- context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
-
- if use_dual_attention:
- self.norm1 = SD35AdaLayerNormZeroX(dim)
- else:
- self.norm1 = AdaLayerNormZero(dim)
-
- if context_norm_type == "ada_norm_continous":
- self.norm1_context = AdaLayerNormContinuous(
- dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
- )
- elif context_norm_type == "ada_norm_zero":
- self.norm1_context = AdaLayerNormZero(dim)
- else:
- raise ValueError(
- f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
- )
-
- if hasattr(F, "scaled_dot_product_attention"):
- processor = JointAttnProcessor2_0()
- else:
- raise ValueError(
- "The current PyTorch version does not support the `scaled_dot_product_attention` function."
- )
-
- self.attn = Attention(
- query_dim=dim,
- cross_attention_dim=None,
- added_kv_proj_dim=dim,
- dim_head=attention_head_dim,
- heads=num_attention_heads,
- out_dim=dim,
- context_pre_only=context_pre_only,
- bias=True,
- processor=processor,
- qk_norm=qk_norm,
- eps=1e-6,
- )
-
- if use_dual_attention:
- self.attn2 = Attention(
- query_dim=dim,
- cross_attention_dim=None,
- dim_head=attention_head_dim,
- heads=num_attention_heads,
- out_dim=dim,
- bias=True,
- processor=processor,
- qk_norm=qk_norm,
- eps=1e-6,
- )
- else:
- self.attn2 = None
-
- self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
- self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
-
- if not context_pre_only:
- self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
- self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
- else:
- self.norm2_context = None
- self.ff_context = None
-
- # let chunk size default to None
- self._chunk_size = None
- self._chunk_dim = 0
-
- # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
- # Sets chunk feed-forward
- self._chunk_size = chunk_size
- self._chunk_dim = dim
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor,
- temb: torch.FloatTensor,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
- joint_attention_kwargs = joint_attention_kwargs or {}
- if self.use_dual_attention:
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
- hidden_states, emb=temb
- )
- else:
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
-
- if self.context_pre_only:
- norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
- else:
- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
- encoder_hidden_states, emb=temb
- )
-
- # Attention.
- attn_output, context_attn_output = self.attn(
- hidden_states=norm_hidden_states,
- encoder_hidden_states=norm_encoder_hidden_states,
- **joint_attention_kwargs,
- )
-
- # Process attention outputs for the `hidden_states`.
- attn_output = gate_msa.unsqueeze(1) * attn_output
- hidden_states = hidden_states + attn_output
-
- if self.use_dual_attention:
- attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
- attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
- hidden_states = hidden_states + attn_output2
-
- norm_hidden_states = self.norm2(hidden_states)
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
- if self._chunk_size is not None:
- # "feed_forward_chunk_size" can be used to save memory
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
- else:
- ff_output = self.ff(norm_hidden_states)
- ff_output = gate_mlp.unsqueeze(1) * ff_output
-
- hidden_states = hidden_states + ff_output
-
- # Process attention outputs for the `encoder_hidden_states`.
- if self.context_pre_only:
- encoder_hidden_states = None
- else:
- context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
- encoder_hidden_states = encoder_hidden_states + context_attn_output
-
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
- if self._chunk_size is not None:
- # "feed_forward_chunk_size" can be used to save memory
- context_ff_output = _chunked_feed_forward(
- self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
- )
- else:
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
-
- return encoder_hidden_states, hidden_states
-
-
-@maybe_allow_in_graph
-class BasicTransformerBlock(nn.Module):
- r"""
- A basic Transformer block.
-
- Parameters:
- dim (`int`): The number of channels in the input and output.
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
- attention_head_dim (`int`): The number of channels in each head.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
- num_embeds_ada_norm (:
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
- attention_bias (:
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
- only_cross_attention (`bool`, *optional*):
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
- double_self_attention (`bool`, *optional*):
- Whether to use two self-attention layers. In this case no cross attention layers are used.
- upcast_attention (`bool`, *optional*):
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
- norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
- Whether to use learnable elementwise affine parameters for normalization.
- norm_type (`str`, *optional*, defaults to `"layer_norm"`):
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
- final_dropout (`bool` *optional*, defaults to False):
- Whether to apply a final dropout after the last feed-forward layer.
- attention_type (`str`, *optional*, defaults to `"default"`):
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
- positional_embeddings (`str`, *optional*, defaults to `None`):
- The type of positional embeddings to apply to.
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
- The maximum number of positional embeddings to apply.
- """
-
- def __init__(
- self,
- dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- dropout=0.0,
- cross_attention_dim: Optional[int] = None,
- activation_fn: str = "geglu",
- num_embeds_ada_norm: Optional[int] = None,
- attention_bias: bool = False,
- only_cross_attention: bool = False,
- double_self_attention: bool = False,
- upcast_attention: bool = False,
- norm_elementwise_affine: bool = True,
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
- norm_eps: float = 1e-5,
- final_dropout: bool = False,
- attention_type: str = "default",
- positional_embeddings: Optional[str] = None,
- num_positional_embeddings: Optional[int] = None,
- ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
- ada_norm_bias: Optional[int] = None,
- ff_inner_dim: Optional[int] = None,
- ff_bias: bool = True,
- attention_out_bias: bool = True,
- ):
- super().__init__()
- self.dim = dim
- self.num_attention_heads = num_attention_heads
- self.attention_head_dim = attention_head_dim
- self.dropout = dropout
- self.cross_attention_dim = cross_attention_dim
- self.activation_fn = activation_fn
- self.attention_bias = attention_bias
- self.double_self_attention = double_self_attention
- self.norm_elementwise_affine = norm_elementwise_affine
- self.positional_embeddings = positional_embeddings
- self.num_positional_embeddings = num_positional_embeddings
- self.only_cross_attention = only_cross_attention
-
- # We keep these boolean flags for backward-compatibility.
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
- self.use_layer_norm = norm_type == "layer_norm"
- self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
-
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
- raise ValueError(
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
- )
-
- self.norm_type = norm_type
- self.num_embeds_ada_norm = num_embeds_ada_norm
-
- if positional_embeddings and (num_positional_embeddings is None):
- raise ValueError(
- "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
- )
-
- if positional_embeddings == "sinusoidal":
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
- else:
- self.pos_embed = None
-
- # Define 3 blocks. Each block has its own normalization layer.
- # 1. Self-Attn
- if norm_type == "ada_norm":
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
- elif norm_type == "ada_norm_zero":
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
- elif norm_type == "ada_norm_continuous":
- self.norm1 = AdaLayerNormContinuous(
- dim,
- ada_norm_continous_conditioning_embedding_dim,
- norm_elementwise_affine,
- norm_eps,
- ada_norm_bias,
- "rms_norm",
- )
- else:
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
-
- self.attn1 = Attention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention,
- out_bias=attention_out_bias,
- )
-
- # 2. Cross-Attn
- if cross_attention_dim is not None or double_self_attention:
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
- # the second cross attention block.
- if norm_type == "ada_norm":
- self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
- elif norm_type == "ada_norm_continuous":
- self.norm2 = AdaLayerNormContinuous(
- dim,
- ada_norm_continous_conditioning_embedding_dim,
- norm_elementwise_affine,
- norm_eps,
- ada_norm_bias,
- "rms_norm",
- )
- else:
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
-
- self.attn2 = Attention(
- query_dim=dim,
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- upcast_attention=upcast_attention,
- out_bias=attention_out_bias,
- ) # is self-attn if encoder_hidden_states is none
- else:
- if norm_type == "ada_norm_single": # For Latte
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
- else:
- self.norm2 = None
- self.attn2 = None
-
- # 3. Feed-forward
- if norm_type == "ada_norm_continuous":
- self.norm3 = AdaLayerNormContinuous(
- dim,
- ada_norm_continous_conditioning_embedding_dim,
- norm_elementwise_affine,
- norm_eps,
- ada_norm_bias,
- "layer_norm",
- )
-
- elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
- self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
- elif norm_type == "layer_norm_i2vgen":
- self.norm3 = None
-
- self.ff = FeedForward(
- dim,
- dropout=dropout,
- activation_fn=activation_fn,
- final_dropout=final_dropout,
- inner_dim=ff_inner_dim,
- bias=ff_bias,
- )
-
- # 4. Fuser
- if attention_type == "gated" or attention_type == "gated-text-image":
- self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
-
- # 5. Scale-shift for PixArt-Alpha.
- if norm_type == "ada_norm_single":
- self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
-
- # let chunk size default to None
- self._chunk_size = None
- self._chunk_dim = 0
-
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
- # Sets chunk feed-forward
- self._chunk_size = chunk_size
- self._chunk_dim = dim
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.Tensor] = None,
- timestep: Optional[torch.LongTensor] = None,
- cross_attention_kwargs: Dict[str, Any] = None,
- class_labels: Optional[torch.LongTensor] = None,
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
- ) -> torch.Tensor:
- if cross_attention_kwargs is not None:
- if cross_attention_kwargs.get("scale", None) is not None:
- logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
-
- # Notice that normalization is always applied before the real computation in the following blocks.
- # 0. Self-Attention
- batch_size = hidden_states.shape[0]
-
- if self.norm_type == "ada_norm":
- norm_hidden_states = self.norm1(hidden_states, timestep)
- elif self.norm_type == "ada_norm_zero":
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
- )
- elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
- norm_hidden_states = self.norm1(hidden_states)
- elif self.norm_type == "ada_norm_continuous":
- norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
- elif self.norm_type == "ada_norm_single":
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
- ).chunk(6, dim=1)
- norm_hidden_states = self.norm1(hidden_states)
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
- else:
- raise ValueError("Incorrect norm used")
-
- if self.pos_embed is not None:
- norm_hidden_states = self.pos_embed(norm_hidden_states)
-
- # 1. Prepare GLIGEN inputs
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
- gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
-
- attn_output = self.attn1(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
- attention_mask=attention_mask,
- **cross_attention_kwargs,
- )
-
- if self.norm_type == "ada_norm_zero":
- attn_output = gate_msa.unsqueeze(1) * attn_output
- elif self.norm_type == "ada_norm_single":
- attn_output = gate_msa * attn_output
-
- hidden_states = attn_output + hidden_states
- if hidden_states.ndim == 4:
- hidden_states = hidden_states.squeeze(1)
-
- # 1.2 GLIGEN Control
- if gligen_kwargs is not None:
- hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
-
- # 3. Cross-Attention
- if self.attn2 is not None:
- if self.norm_type == "ada_norm":
- norm_hidden_states = self.norm2(hidden_states, timestep)
- elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
- norm_hidden_states = self.norm2(hidden_states)
- elif self.norm_type == "ada_norm_single":
- # For PixArt norm2 isn't applied here:
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
- norm_hidden_states = hidden_states
- elif self.norm_type == "ada_norm_continuous":
- norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
- else:
- raise ValueError("Incorrect norm")
-
- if self.pos_embed is not None and self.norm_type != "ada_norm_single":
- norm_hidden_states = self.pos_embed(norm_hidden_states)
-
- attn_output = self.attn2(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=encoder_attention_mask,
- **cross_attention_kwargs,
- )
- hidden_states = attn_output + hidden_states
-
- # 4. Feed-forward
- # i2vgen doesn't have this norm 🤷♂️
- if self.norm_type == "ada_norm_continuous":
- norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
- elif not self.norm_type == "ada_norm_single":
- norm_hidden_states = self.norm3(hidden_states)
-
- if self.norm_type == "ada_norm_zero":
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
- if self.norm_type == "ada_norm_single":
- norm_hidden_states = self.norm2(hidden_states)
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
-
- if self._chunk_size is not None:
- # "feed_forward_chunk_size" can be used to save memory
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
- else:
- ff_output = self.ff(norm_hidden_states)
-
- if self.norm_type == "ada_norm_zero":
- ff_output = gate_mlp.unsqueeze(1) * ff_output
- elif self.norm_type == "ada_norm_single":
- ff_output = gate_mlp * ff_output
-
- hidden_states = ff_output + hidden_states
- if hidden_states.ndim == 4:
- hidden_states = hidden_states.squeeze(1)
-
- return hidden_states
-
-
-class LuminaFeedForward(nn.Module):
- r"""
- A feed-forward layer.
-
- Parameters:
- hidden_size (`int`):
- The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
- hidden representations.
- intermediate_size (`int`): The intermediate dimension of the feedforward layer.
- multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
- of this value.
- ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
- dimension. Defaults to None.
- """
-
- def __init__(
- self,
- dim: int,
- inner_dim: int,
- multiple_of: Optional[int] = 256,
- ffn_dim_multiplier: Optional[float] = None,
- ):
- super().__init__()
- # custom hidden_size factor multiplier
- if ffn_dim_multiplier is not None:
- inner_dim = int(ffn_dim_multiplier * inner_dim)
- inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
-
- self.linear_1 = nn.Linear(
- dim,
- inner_dim,
- bias=False,
- )
- self.linear_2 = nn.Linear(
- inner_dim,
- dim,
- bias=False,
- )
- self.linear_3 = nn.Linear(
- dim,
- inner_dim,
- bias=False,
- )
- self.silu = FP32SiLU()
-
- def forward(self, x):
- return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
-
-
-@maybe_allow_in_graph
-class TemporalBasicTransformerBlock(nn.Module):
- r"""
- A basic Transformer block for video like data.
-
- Parameters:
- dim (`int`): The number of channels in the input and output.
- time_mix_inner_dim (`int`): The number of channels for temporal attention.
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
- attention_head_dim (`int`): The number of channels in each head.
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
- """
-
- def __init__(
- self,
- dim: int,
- time_mix_inner_dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- cross_attention_dim: Optional[int] = None,
- ):
- super().__init__()
- self.is_res = dim == time_mix_inner_dim
-
- self.norm_in = nn.LayerNorm(dim)
-
- # Define 3 blocks. Each block has its own normalization layer.
- # 1. Self-Attn
- self.ff_in = FeedForward(
- dim,
- dim_out=time_mix_inner_dim,
- activation_fn="geglu",
- )
-
- self.norm1 = nn.LayerNorm(time_mix_inner_dim)
- self.attn1 = Attention(
- query_dim=time_mix_inner_dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- cross_attention_dim=None,
- )
-
- # 2. Cross-Attn
- if cross_attention_dim is not None:
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
- # the second cross attention block.
- self.norm2 = nn.LayerNorm(time_mix_inner_dim)
- self.attn2 = Attention(
- query_dim=time_mix_inner_dim,
- cross_attention_dim=cross_attention_dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- ) # is self-attn if encoder_hidden_states is none
- else:
- self.norm2 = None
- self.attn2 = None
-
- # 3. Feed-forward
- self.norm3 = nn.LayerNorm(time_mix_inner_dim)
- self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
-
- # let chunk size default to None
- self._chunk_size = None
- self._chunk_dim = None
-
- def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
- # Sets chunk feed-forward
- self._chunk_size = chunk_size
- # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
- self._chunk_dim = 1
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- num_frames: int,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- # Notice that normalization is always applied before the real computation in the following blocks.
- # 0. Self-Attention
- batch_size = hidden_states.shape[0]
-
- batch_frames, seq_length, channels = hidden_states.shape
- batch_size = batch_frames // num_frames
-
- hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
- hidden_states = hidden_states.permute(0, 2, 1, 3)
- hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
-
- residual = hidden_states
- hidden_states = self.norm_in(hidden_states)
-
- if self._chunk_size is not None:
- hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
- else:
- hidden_states = self.ff_in(hidden_states)
-
- if self.is_res:
- hidden_states = hidden_states + residual
-
- norm_hidden_states = self.norm1(hidden_states)
- attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
- hidden_states = attn_output + hidden_states
-
- # 3. Cross-Attention
- if self.attn2 is not None:
- norm_hidden_states = self.norm2(hidden_states)
- attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
- hidden_states = attn_output + hidden_states
-
- # 4. Feed-forward
- norm_hidden_states = self.norm3(hidden_states)
-
- if self._chunk_size is not None:
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
- else:
- ff_output = self.ff(norm_hidden_states)
-
- if self.is_res:
- hidden_states = ff_output + hidden_states
- else:
- hidden_states = ff_output
-
- hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
- hidden_states = hidden_states.permute(0, 2, 1, 3)
- hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
-
- return hidden_states
-
-
-class SkipFFTransformerBlock(nn.Module):
- def __init__(
- self,
- dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- kv_input_dim: int,
- kv_input_dim_proj_use_bias: bool,
- dropout=0.0,
- cross_attention_dim: Optional[int] = None,
- attention_bias: bool = False,
- attention_out_bias: bool = True,
- ):
- super().__init__()
- if kv_input_dim != dim:
- self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
- else:
- self.kv_mapper = None
-
- self.norm1 = RMSNorm(dim, 1e-06)
-
- self.attn1 = Attention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim,
- out_bias=attention_out_bias,
- )
-
- self.norm2 = RMSNorm(dim, 1e-06)
-
- self.attn2 = Attention(
- query_dim=dim,
- cross_attention_dim=cross_attention_dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- out_bias=attention_out_bias,
- )
-
- def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
-
- if self.kv_mapper is not None:
- encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
-
- norm_hidden_states = self.norm1(hidden_states)
-
- attn_output = self.attn1(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- **cross_attention_kwargs,
- )
-
- hidden_states = attn_output + hidden_states
-
- norm_hidden_states = self.norm2(hidden_states)
-
- attn_output = self.attn2(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- **cross_attention_kwargs,
- )
-
- hidden_states = attn_output + hidden_states
-
- return hidden_states
-
-
-@maybe_allow_in_graph
-class FreeNoiseTransformerBlock(nn.Module):
- r"""
- A FreeNoise Transformer block.
-
- Parameters:
- dim (`int`):
- The number of channels in the input and output.
- num_attention_heads (`int`):
- The number of heads to use for multi-head attention.
- attention_head_dim (`int`):
- The number of channels in each head.
- dropout (`float`, *optional*, defaults to 0.0):
- The dropout probability to use.
- cross_attention_dim (`int`, *optional*):
- The size of the encoder_hidden_states vector for cross attention.
- activation_fn (`str`, *optional*, defaults to `"geglu"`):
- Activation function to be used in feed-forward.
- num_embeds_ada_norm (`int`, *optional*):
- The number of diffusion steps used during training. See `Transformer2DModel`.
- attention_bias (`bool`, defaults to `False`):
- Configure if the attentions should contain a bias parameter.
- only_cross_attention (`bool`, defaults to `False`):
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
- double_self_attention (`bool`, defaults to `False`):
- Whether to use two self-attention layers. In this case no cross attention layers are used.
- upcast_attention (`bool`, defaults to `False`):
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
- norm_elementwise_affine (`bool`, defaults to `True`):
- Whether to use learnable elementwise affine parameters for normalization.
- norm_type (`str`, defaults to `"layer_norm"`):
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
-        final_dropout (`bool`, defaults to `False`):
- Whether to apply a final dropout after the last feed-forward layer.
- attention_type (`str`, defaults to `"default"`):
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
- positional_embeddings (`str`, *optional*):
-            The type of positional embeddings to apply.
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
- The maximum number of positional embeddings to apply.
- ff_inner_dim (`int`, *optional*):
- Hidden dimension of feed-forward MLP.
- ff_bias (`bool`, defaults to `True`):
- Whether or not to use bias in feed-forward MLP.
- attention_out_bias (`bool`, defaults to `True`):
-            Whether or not to use bias in the attention output projection layer.
- context_length (`int`, defaults to `16`):
- The maximum number of frames that the FreeNoise block processes at once.
- context_stride (`int`, defaults to `4`):
- The number of frames to be skipped before starting to process a new batch of `context_length` frames.
- weighting_scheme (`str`, defaults to `"pyramid"`):
-            The weighting scheme used for averaging the processed latent frames across overlapping windows. As
-            described in Equation 9 of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the
-            default setting.
- """
-
- def __init__(
- self,
- dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- dropout: float = 0.0,
- cross_attention_dim: Optional[int] = None,
- activation_fn: str = "geglu",
- num_embeds_ada_norm: Optional[int] = None,
- attention_bias: bool = False,
- only_cross_attention: bool = False,
- double_self_attention: bool = False,
- upcast_attention: bool = False,
- norm_elementwise_affine: bool = True,
- norm_type: str = "layer_norm",
- norm_eps: float = 1e-5,
- final_dropout: bool = False,
- positional_embeddings: Optional[str] = None,
- num_positional_embeddings: Optional[int] = None,
- ff_inner_dim: Optional[int] = None,
- ff_bias: bool = True,
- attention_out_bias: bool = True,
- context_length: int = 16,
- context_stride: int = 4,
- weighting_scheme: str = "pyramid",
- ):
- super().__init__()
- self.dim = dim
- self.num_attention_heads = num_attention_heads
- self.attention_head_dim = attention_head_dim
- self.dropout = dropout
- self.cross_attention_dim = cross_attention_dim
- self.activation_fn = activation_fn
- self.attention_bias = attention_bias
- self.double_self_attention = double_self_attention
- self.norm_elementwise_affine = norm_elementwise_affine
- self.positional_embeddings = positional_embeddings
- self.num_positional_embeddings = num_positional_embeddings
- self.only_cross_attention = only_cross_attention
-
- self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
-
- # We keep these boolean flags for backward-compatibility.
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
- self.use_layer_norm = norm_type == "layer_norm"
- self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
-
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
- raise ValueError(
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
- )
-
- self.norm_type = norm_type
- self.num_embeds_ada_norm = num_embeds_ada_norm
-
- if positional_embeddings and (num_positional_embeddings is None):
- raise ValueError(
-                "If `positional_embeddings` is set, `num_positional_embeddings` must also be defined."
- )
-
- if positional_embeddings == "sinusoidal":
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
- else:
- self.pos_embed = None
-
- # Define 3 blocks. Each block has its own normalization layer.
- # 1. Self-Attn
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
-
- self.attn1 = Attention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention,
- out_bias=attention_out_bias,
- )
-
- # 2. Cross-Attn
- if cross_attention_dim is not None or double_self_attention:
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
-
- self.attn2 = Attention(
- query_dim=dim,
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- upcast_attention=upcast_attention,
- out_bias=attention_out_bias,
- ) # is self-attn if encoder_hidden_states is none
-
- # 3. Feed-forward
- self.ff = FeedForward(
- dim,
- dropout=dropout,
- activation_fn=activation_fn,
- final_dropout=final_dropout,
- inner_dim=ff_inner_dim,
- bias=ff_bias,
- )
-
- self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
-
- # let chunk size default to None
- self._chunk_size = None
- self._chunk_dim = 0
-
- def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
- frame_indices = []
- for i in range(0, num_frames - self.context_length + 1, self.context_stride):
- window_start = i
- window_end = min(num_frames, i + self.context_length)
- frame_indices.append((window_start, window_end))
- return frame_indices
-
- def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
- if weighting_scheme == "flat":
- weights = [1.0] * num_frames
-
- elif weighting_scheme == "pyramid":
- if num_frames % 2 == 0:
- # num_frames = 4 => [1, 2, 2, 1]
- mid = num_frames // 2
- weights = list(range(1, mid + 1))
- weights = weights + weights[::-1]
- else:
- # num_frames = 5 => [1, 2, 3, 2, 1]
- mid = (num_frames + 1) // 2
- weights = list(range(1, mid))
- weights = weights + [mid] + weights[::-1]
-
- elif weighting_scheme == "delayed_reverse_sawtooth":
- if num_frames % 2 == 0:
- # num_frames = 4 => [0.01, 2, 2, 1]
- mid = num_frames // 2
- weights = [0.01] * (mid - 1) + [mid]
- weights = weights + list(range(mid, 0, -1))
- else:
- # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
- mid = (num_frames + 1) // 2
-                weights = [0.01] * (mid - 1)
- weights = weights + list(range(mid, 0, -1))
- else:
- raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
-
- return weights
-
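# Illustrative sketch, not from the patched file: a standalone restatement of the windowing above,
# so the expected outputs of _get_frame_indices and _get_frame_weights are easy to verify.
# context_length=16 and context_stride=4 are the class defaults.
def example_frame_indices(num_frames, context_length=16, context_stride=4):
    return [
        (i, min(num_frames, i + context_length))
        for i in range(0, num_frames - context_length + 1, context_stride)
    ]

assert example_frame_indices(25) == [(0, 16), (4, 20), (8, 24)]  # forward() appends (9, 25) for the tail
# "pyramid" weights for an even window length rise to the middle and fall back down
mid = 16 // 2
assert list(range(1, mid + 1)) + list(range(mid, 0, -1)) == [1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1]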
- def set_free_noise_properties(
- self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
- ) -> None:
- self.context_length = context_length
- self.context_stride = context_stride
- self.weighting_scheme = weighting_scheme
-
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
- # Sets chunk feed-forward
- self._chunk_size = chunk_size
- self._chunk_dim = dim
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.Tensor] = None,
- cross_attention_kwargs: Dict[str, Any] = None,
- *args,
- **kwargs,
- ) -> torch.Tensor:
- if cross_attention_kwargs is not None:
- if cross_attention_kwargs.get("scale", None) is not None:
- logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
-
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
-
- # hidden_states: [B x H x W, F, C]
- device = hidden_states.device
- dtype = hidden_states.dtype
-
- num_frames = hidden_states.size(1)
- frame_indices = self._get_frame_indices(num_frames)
- frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
- frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
- is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
-
- # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
-        # For example, with num_frames=25, context_length=16 and context_stride=4, the initial windows are
-        # [(0, 16), (4, 20), (8, 24)], and (9, 25) is appended below so that the final frames are also covered.
- if not is_last_frame_batch_complete:
- if num_frames < self.context_length:
-                raise ValueError(f"Expected {num_frames=} to be greater than or equal to {self.context_length=}")
- last_frame_batch_length = num_frames - frame_indices[-1][1]
- frame_indices.append((num_frames - self.context_length, num_frames))
-
- num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
- accumulated_values = torch.zeros_like(hidden_states)
-
- for i, (frame_start, frame_end) in enumerate(frame_indices):
-            # The slicing here handles cases where (frame_end - frame_start) is smaller than `context_length`,
-            # e.g. frame_indices=[(0, 16), (16, 20)] when the user provided a video whose number of frames is
-            # not a multiple of `context_length`.
- weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
- weights *= frame_weights
-
- hidden_states_chunk = hidden_states[:, frame_start:frame_end]
-
- # Notice that normalization is always applied before the real computation in the following blocks.
- # 1. Self-Attention
- norm_hidden_states = self.norm1(hidden_states_chunk)
-
- if self.pos_embed is not None:
- norm_hidden_states = self.pos_embed(norm_hidden_states)
-
- attn_output = self.attn1(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
- attention_mask=attention_mask,
- **cross_attention_kwargs,
- )
-
- hidden_states_chunk = attn_output + hidden_states_chunk
- if hidden_states_chunk.ndim == 4:
- hidden_states_chunk = hidden_states_chunk.squeeze(1)
-
- # 2. Cross-Attention
- if self.attn2 is not None:
- norm_hidden_states = self.norm2(hidden_states_chunk)
-
- if self.pos_embed is not None and self.norm_type != "ada_norm_single":
- norm_hidden_states = self.pos_embed(norm_hidden_states)
-
- attn_output = self.attn2(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=encoder_attention_mask,
- **cross_attention_kwargs,
- )
- hidden_states_chunk = attn_output + hidden_states_chunk
-
- if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
- accumulated_values[:, -last_frame_batch_length:] += (
- hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
- )
-                num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
- else:
- accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
- num_times_accumulated[:, frame_start:frame_end] += weights
-
- # TODO(aryan): Maybe this could be done in a better way.
- #
- # Previously, this was:
- # hidden_states = torch.where(
- # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
- # )
- #
- # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
- # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
-        # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
- # looked into this deeply because other memory optimizations led to more pronounced reductions.
- hidden_states = torch.cat(
- [
- torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
- for accumulated_split, num_times_split in zip(
- accumulated_values.split(self.context_length, dim=1),
- num_times_accumulated.split(self.context_length, dim=1),
- )
- ],
- dim=1,
- ).to(dtype)
-
- # 3. Feed-forward
- norm_hidden_states = self.norm3(hidden_states)
-
- if self._chunk_size is not None:
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
- else:
- ff_output = self.ff(norm_hidden_states)
-
- hidden_states = ff_output + hidden_states
- if hidden_states.ndim == 4:
- hidden_states = hidden_states.squeeze(1)
-
- return hidden_states
-
-
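# Illustrative sketch, not from the patched file: the per-frame blending performed by the loop in
# forward() above. Each window's output is accumulated with its weight and every frame is divided
# by the total weight it received, i.e. the weighted average described by Eq. 9 of FreeNoise.
# The window layout and weights below are toy assumptions.
import torch

num_frames, channels = 8, 4
windows = [(0, 4), (2, 6), (4, 8)]                     # overlapping windows covering all 8 frames
window_weights = torch.tensor([1.0, 2.0, 2.0, 1.0])    # "pyramid" weights for a length-4 window

accumulated = torch.zeros(num_frames, channels)
counts = torch.zeros(num_frames, 1)
for start, end in windows:
    chunk_out = torch.randn(end - start, channels)     # stands in for the attention/feed-forward output
    w = window_weights[: end - start].unsqueeze(-1)
    accumulated[start:end] += chunk_out * w
    counts[start:end] += w

blended = torch.where(counts > 0, accumulated / counts, accumulated)
assert blended.shape == (num_frames, channels)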
-class FeedForward(nn.Module):
- r"""
- A feed-forward layer.
-
- Parameters:
- dim (`int`): The number of channels in the input.
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
-        final_dropout (`bool`, *optional*, defaults to False): Apply a final dropout.
- bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
- """
-
- def __init__(
- self,
- dim: int,
- dim_out: Optional[int] = None,
- mult: int = 4,
- dropout: float = 0.0,
- activation_fn: str = "geglu",
- final_dropout: bool = False,
- inner_dim=None,
- bias: bool = True,
- ):
- super().__init__()
- if inner_dim is None:
- inner_dim = int(dim * mult)
- dim_out = dim_out if dim_out is not None else dim
-
- if activation_fn == "gelu":
- act_fn = GELU(dim, inner_dim, bias=bias)
-        elif activation_fn == "gelu-approximate":
- act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
- elif activation_fn == "geglu":
- act_fn = GEGLU(dim, inner_dim, bias=bias)
- elif activation_fn == "geglu-approximate":
- act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
- elif activation_fn == "swiglu":
- act_fn = SwiGLU(dim, inner_dim, bias=bias)
- elif activation_fn == "linear-silu":
- act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
-
- self.net = nn.ModuleList([])
- # project in
- self.net.append(act_fn)
- # project dropout
- self.net.append(nn.Dropout(dropout))
- # project out
- self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
- if final_dropout:
- self.net.append(nn.Dropout(dropout))
-
- def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
- if len(args) > 0 or kwargs.get("scale", None) is not None:
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
- deprecate("scale", "1.0.0", deprecation_message)
- for module in self.net:
- hidden_states = module(hidden_states)
- return hidden_states
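# Illustrative usage sketch, not from the patched file, assuming FeedForward remains importable
# from diffusers.models.attention: project in through a gated GELU, apply dropout, project out.
import torch
from diffusers.models.attention import FeedForward

ff = FeedForward(dim=320, mult=4, activation_fn="geglu")  # inner_dim defaults to 4 * 320
hidden = torch.randn(1, 64, 320)                          # [batch, seq_len, dim]
out = ff(hidden)
assert out.shape == hidden.shape                          # dim_out defaults to dim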
diff --git a/src/diffusers/models/attention_modules.py b/src/diffusers/models/attention_modules.py
deleted file mode 100644
index 96b8438fb5..0000000000
--- a/src/diffusers/models/attention_modules.py
+++ /dev/null
@@ -1,247 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import inspect
-from typing import Optional, Tuple, Union
-
-import torch
-from torch import nn
-
-from ..utils import logging
-from ..utils.torch_utils import maybe_allow_in_graph
-from .attention_processor import (
- AttentionModuleMixin,
- FusedJointAttnProcessorSDPA,
- JointAttnProcessorSDPA,
- SanaLinearAttnProcessorSDPA,
-)
-from .normalization import get_normalization
-
-
-logger = logging.get_logger(__name__)
-
-
-@maybe_allow_in_graph
-class SanaAttention(nn.Module, AttentionModuleMixin):
- """
- Attention implementation specialized for Sana models.
-
- This module implements lightweight multi-scale linear attention as used in Sana.
-
- Args:
- in_channels (`int`): Number of input channels.
- out_channels (`int`): Number of output channels.
- num_attention_heads (`int`, *optional*): Number of attention heads.
- attention_head_dim (`int`, defaults to 8): Dimension of each attention head.
- mult (`float`, defaults to 1.0): Multiplier for inner dimension.
- norm_type (`str`, defaults to "batch_norm"): Type of normalization.
- kernel_sizes (`Tuple[int, ...]`, defaults to (5,)): Kernel sizes for multi-scale attention.
- """
-
- # Set Sana-specific processor classes
- default_processor_class = SanaLinearAttnProcessorSDPA
- fused_processor_class = None # Sana doesn't have a fused processor yet
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- num_attention_heads: Optional[int] = None,
- attention_head_dim: int = 8,
- mult: float = 1.0,
- norm_type: str = "batch_norm",
- kernel_sizes: Tuple[int, ...] = (5,),
- eps: float = 1e-15,
- residual_connection: bool = False,
- ):
- super().__init__()
-
- # Core parameters
- self.eps = eps
- self.attention_head_dim = attention_head_dim
- self.norm_type = norm_type
- self.residual_connection = residual_connection
-
- # Calculate dimensions
- num_attention_heads = (
- int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
- )
- inner_dim = num_attention_heads * attention_head_dim
- self.inner_dim = inner_dim
- self.heads = num_attention_heads
-
- # Query, key, value projections
- self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
- self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
- self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
-
- # Multi-scale attention
- self.to_qkv_multiscale = nn.ModuleList()
- for kernel_size in kernel_sizes:
- self.to_qkv_multiscale.append(
- SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
- )
-
- # Output layers
- self.nonlinearity = nn.ReLU()
- self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
- self.norm_out = get_normalization(norm_type, num_features=out_channels)
-
- # Set default processor
- self.fused_projections = False
- self.set_processor(self.default_processor_class())
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- **kwargs,
- ) -> torch.Tensor:
- """Process linear attention for Sana model inputs."""
- return self.processor(self, hidden_states)
-
-
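# Illustrative sketch, not from the patched file: the softmax-free "linear attention" idea behind
# SanaAttention. Instead of softmax(Q K^T) V, which is quadratic in sequence length, a kernelized
# form computes phi(Q) (phi(K)^T V) so the cost grows linearly with sequence length. This is a
# generic sketch of that trick (phi = ReLU), not the exact Sana kernel.
import torch

batch, heads, seq_len, head_dim = 1, 2, 128, 8
q = torch.relu(torch.randn(batch, heads, seq_len, head_dim))
k = torch.relu(torch.randn(batch, heads, seq_len, head_dim))
v = torch.randn(batch, heads, seq_len, head_dim)

kv = torch.einsum("bhnd,bhne->bhde", k, v)                   # (d x d) summary, independent of seq_len
z = 1.0 / (torch.einsum("bhnd,bhd->bhn", q, k.sum(dim=2)) + 1e-6)
out = torch.einsum("bhnd,bhde,bhn->bhne", q, kv, z)          # [batch, heads, seq_len, head_dim]
assert out.shape == v.shape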
-class SanaMultiscaleAttentionProjection(nn.Module):
- """Projection layer for Sana multi-scale attention."""
-
- def __init__(
- self,
- in_channels: int,
- num_attention_heads: int,
- kernel_size: int,
- ) -> None:
- super().__init__()
-
- channels = 3 * in_channels
- self.proj_in = nn.Conv2d(
- channels,
- channels,
- kernel_size,
- padding=kernel_size // 2,
- groups=channels,
- bias=False,
- )
- self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- hidden_states = self.proj_in(hidden_states)
- hidden_states = self.proj_out(hidden_states)
- return hidden_states
-
-
-@maybe_allow_in_graph
-class SD3Attention(nn.Module, AttentionModuleMixin):
- """
- Attention implementation specialized for SD3 models.
-
- This module implements the joint attention mechanism used in SD3,
- with native support for context pre-processing.
-
- Args:
- query_dim (`int`): Number of channels in query.
- cross_attention_dim (`int`, *optional*): Number of channels in encoder states.
- heads (`int`, defaults to 8): Number of attention heads.
- dim_head (`int`, defaults to 64): Dimension of each attention head.
- dropout (`float`, defaults to 0.0): Dropout probability.
- bias (`bool`, defaults to False): Whether to use bias in linear projections.
- added_kv_proj_dim (`int`, *optional*): Dimension for added key/value projections.
- """
-
- # Set SD3-specific processor classes
- default_processor_class = JointAttnProcessorSDPA
- fused_processor_class = FusedJointAttnProcessorSDPA
-
- def __init__(
- self,
- query_dim: int,
- cross_attention_dim: Optional[int] = None,
- heads: int = 8,
- dim_head: int = 64,
- dropout: float = 0.0,
- bias: bool = False,
- added_kv_proj_dim: Optional[int] = None,
- context_pre_only: bool = False,
- ):
- super().__init__()
-
- # Core parameters
- self.inner_dim = dim_head * heads
- self.query_dim = query_dim
- self.heads = heads
- self.scale = dim_head**-0.5
- self.use_bias = bias
- self.scale_qk = True
- self.context_pre_only = context_pre_only
-
- # Cross-attention setup
- self.is_cross_attention = cross_attention_dim is not None
- self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
-
- # Projections for self-attention
- self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
- self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
- self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
-
- # Added projections for context processing
- self.added_kv_proj_dim = added_kv_proj_dim
- if added_kv_proj_dim is not None:
- self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
- self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
- self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
- self.added_proj_bias = bias
-
- # Output projection
- self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, query_dim, bias=bias), nn.Dropout(dropout)])
-
- # Context output projection
- if added_kv_proj_dim is not None and not context_pre_only:
- self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias)
- else:
- self.to_add_out = None
-
- # Set default processor and fusion state
- self.fused_projections = False
- self.set_processor(self.default_processor_class())
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- **kwargs,
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
- """Process joint attention for SD3 model inputs."""
- # Filter parameters to only those expected by the processor
- processor_params = inspect.signature(self.processor.__call__).parameters.keys()
- quiet_params = {"ip_adapter_masks", "ip_hidden_states"}
-
- # Check for unexpected parameters
- unexpected_params = [k for k, _ in kwargs.items() if k not in processor_params and k not in quiet_params]
- if unexpected_params:
- logger.warning(
- f"Parameters {unexpected_params} are not expected by {self.processor.__class__.__name__} and will be ignored."
- )
-
- # Filter to only expected parameters
- filtered_kwargs = {k: v for k, v in kwargs.items() if k in processor_params}
-
- # Process with appropriate processor
- return self.processor(
- self,
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- **filtered_kwargs,
- )
-
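# Illustrative sketch, not from the patched file: the signature-based filtering used in
# SD3Attention.forward above. Unknown keyword arguments are reported and dropped instead of
# letting the processor call fail with a TypeError. The names below are hypothetical.
import inspect

def example_processor(hidden_states, attention_mask=None):
    return hidden_states

accepted = inspect.signature(example_processor).parameters.keys()
kwargs = {"attention_mask": None, "unexpected_flag": True}
unexpected = [k for k in kwargs if k not in accepted]          # -> ["unexpected_flag"], would be logged
filtered = {k: v for k, v in kwargs.items() if k in accepted}
result = example_processor("dummy hidden states", **filtered)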
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 80aed8d123..647ad41a6b 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -22,7 +22,7 @@ from torch import nn
from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, is_torch_xla_available, logging
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
-from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
+from ..utils.torch_utils import is_torch_version
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -46,596 +46,6 @@ else:
XLA_AVAILABLE = False
-class AttentionModuleMixin:
- """
- A mixin class that provides common methods for attention modules.
-
- This mixin adds functionality to set different attention processors, handle attention masks, compute attention
- scores, and manage projections.
- """
-
- # Default processor classes to be overridden by subclasses
- default_processor_cls = None
- _available_processors = []
-
- def _get_compatible_processor(self, backend):
- for processor_cls in self._available_processors:
- if backend in processor_cls.compatible_backends:
- processor = processor_cls()
- return processor
-
- def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
- """
- Set whether to use NPU flash attention from `torch_npu` or not.
-
- Args:
- use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
- """
- processor = self.default_processor_cls()
-
- if use_npu_flash_attention:
- processor = self._get_compatible_processor("npu")
-
- self.set_processor(processor)
-
- def set_use_xla_flash_attention(
- self,
- use_xla_flash_attention: bool,
- partition_spec: Optional[Tuple[Optional[str], ...]] = None,
- is_flux=False,
- ) -> None:
- """
- Set whether to use XLA flash attention from `torch_xla` or not.
-
- Args:
- use_xla_flash_attention (`bool`):
- Whether to use pallas flash attention kernel from `torch_xla` or not.
- partition_spec (`Tuple[]`, *optional*):
- Specify the partition specification if using SPMD. Otherwise None.
- is_flux (`bool`, *optional*, defaults to `False`):
- Whether the model is a Flux model.
- """
- processor = self.default_processor_cls()
- if use_xla_flash_attention:
- if not is_torch_xla_available():
-                raise ImportError("torch_xla is not available")
-            elif is_torch_xla_version("<", "2.3"):
-                raise ImportError("flash attention pallas kernel is supported from torch_xla version 2.3")
-            elif is_spmd() and is_torch_xla_version("<", "2.4"):
-                raise ImportError("flash attention pallas kernel using SPMD is supported from torch_xla version 2.4")
- else:
- processor = self._get_compatible_processor("xla")
-
- self.set_processor(processor)
-
- @torch.no_grad()
- def fuse_projections(self, fuse=True):
- """
- Fuse the query, key, and value projections into a single projection for efficiency.
-
- Args:
- fuse (`bool`): Whether to fuse the projections or not.
- """
- # Skip if already in desired state
- if getattr(self, "fused_projections", False) == fuse:
- return
-
- device = self.to_q.weight.data.device
- dtype = self.to_q.weight.data.dtype
-
- if not self.is_cross_attention:
- # Fuse self-attention projections
- concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
- in_features = concatenated_weights.shape[1]
- out_features = concatenated_weights.shape[0]
-
- self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
- self.to_qkv.weight.copy_(concatenated_weights)
- if self.use_bias:
- concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
- self.to_qkv.bias.copy_(concatenated_bias)
-
- else:
- # Fuse cross-attention key-value projections
- concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
- in_features = concatenated_weights.shape[1]
- out_features = concatenated_weights.shape[0]
-
- self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
- self.to_kv.weight.copy_(concatenated_weights)
- if self.use_bias:
- concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
- self.to_kv.bias.copy_(concatenated_bias)
-
- # Handle added projections for models like SD3, Flux, etc.
- if (
- getattr(self, "add_q_proj", None) is not None
- and getattr(self, "add_k_proj", None) is not None
- and getattr(self, "add_v_proj", None) is not None
- ):
- concatenated_weights = torch.cat(
- [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
- )
- in_features = concatenated_weights.shape[1]
- out_features = concatenated_weights.shape[0]
-
- self.to_added_qkv = nn.Linear(
- in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
- )
- self.to_added_qkv.weight.copy_(concatenated_weights)
- if self.added_proj_bias:
- concatenated_bias = torch.cat(
- [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
- )
- self.to_added_qkv.bias.copy_(concatenated_bias)
-
- self.fused_projections = fuse
- self.processor.is_fused = fuse
-
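# Illustrative sketch, not from the patched file: why concatenating the Q/K/V weights as in
# fuse_projections above is safe. One fused matmul followed by a split reproduces the three
# separate projections.
import torch
from torch import nn

dim = 8
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))

x = torch.randn(2, 5, dim)
q_fused, k_fused, v_fused = to_qkv(x).chunk(3, dim=-1)
assert torch.allclose(q_fused, to_q(x), atol=1e-6)
assert torch.allclose(k_fused, to_k(x), atol=1e-6)
assert torch.allclose(v_fused, to_v(x), atol=1e-6)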
- def set_use_memory_efficient_attention_xformers(
- self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
- ) -> None:
- """
- Set whether to use memory efficient attention from `xformers` or not.
-
- Args:
- use_memory_efficient_attention_xformers (`bool`):
- Whether to use memory efficient attention from `xformers` or not.
- attention_op (`Callable`, *optional*):
- The attention operation to use. Defaults to `None` which uses the default attention operation from
- `xformers`.
- """
- is_custom_diffusion = hasattr(self, "processor") and isinstance(
- self.processor,
- (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessorSDPA),
- )
- is_added_kv_processor = hasattr(self, "processor") and isinstance(
- self.processor,
- (
- AttnAddedKVProcessor,
- AttnAddedKVProcessorSDPA,
- SlicedAttnAddedKVProcessor,
- XFormersAttnAddedKVProcessor,
- ),
- )
- is_ip_adapter = hasattr(self, "processor") and isinstance(
- self.processor,
- (IPAdapterAttnProcessor, IPAdapterAttnProcessorSDPA, IPAdapterXFormersAttnProcessor),
- )
- is_joint_processor = hasattr(self, "processor") and isinstance(
- self.processor,
- (
- JointAttnProcessorSDPA,
- XFormersJointAttnProcessor,
- ),
- )
-
- if use_memory_efficient_attention_xformers:
- if is_added_kv_processor and is_custom_diffusion:
- raise NotImplementedError(
- f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
- )
- if not is_xformers_available():
- raise ModuleNotFoundError(
- (
- "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
- " xformers"
- ),
- name="xformers",
- )
- elif not torch.cuda.is_available():
- raise ValueError(
- "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
- " only available for GPU "
- )
- else:
- try:
- # Make sure we can run the memory efficient attention
- dtype = None
- if attention_op is not None:
- op_fw, op_bw = attention_op
- dtype, *_ = op_fw.SUPPORTED_DTYPES
- q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
- _ = xformers.ops.memory_efficient_attention(q, q, q)
- except Exception as e:
- raise e
-
- if is_custom_diffusion:
- processor = CustomDiffusionXFormersAttnProcessor(
- train_kv=self.processor.train_kv,
- train_q_out=self.processor.train_q_out,
- hidden_size=self.processor.hidden_size,
- cross_attention_dim=self.processor.cross_attention_dim,
- attention_op=attention_op,
- )
- processor.load_state_dict(self.processor.state_dict())
- if hasattr(self.processor, "to_k_custom_diffusion"):
- processor.to(self.processor.to_k_custom_diffusion.weight.device)
- elif is_added_kv_processor:
- # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
- # which uses this type of cross attention ONLY because the attention mask of format
- # [0, ..., -10.000, ..., 0, ...,] is not supported
- # throw warning
- logger.info(
- "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
- )
- processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
- elif is_ip_adapter:
- processor = IPAdapterXFormersAttnProcessor(
- hidden_size=self.processor.hidden_size,
- cross_attention_dim=self.processor.cross_attention_dim,
- num_tokens=self.processor.num_tokens,
- scale=self.processor.scale,
- attention_op=attention_op,
- )
- processor.load_state_dict(self.processor.state_dict())
- if hasattr(self.processor, "to_k_ip"):
- processor.to(
- device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
- )
- elif is_joint_processor:
- processor = XFormersJointAttnProcessor(attention_op=attention_op)
- else:
- processor = XFormersAttnProcessor(attention_op=attention_op)
- else:
- if is_custom_diffusion:
- attn_processor_class = (
- CustomDiffusionAttnProcessorSDPA
- if hasattr(F, "scaled_dot_product_attention")
- else CustomDiffusionAttnProcessor
- )
- processor = attn_processor_class(
- train_kv=self.processor.train_kv,
- train_q_out=self.processor.train_q_out,
- hidden_size=self.processor.hidden_size,
- cross_attention_dim=self.processor.cross_attention_dim,
- )
- processor.load_state_dict(self.processor.state_dict())
- if hasattr(self.processor, "to_k_custom_diffusion"):
- processor.to(self.processor.to_k_custom_diffusion.weight.device)
- elif is_ip_adapter:
- processor = IPAdapterAttnProcessorSDPA(
- hidden_size=self.processor.hidden_size,
- cross_attention_dim=self.processor.cross_attention_dim,
- num_tokens=self.processor.num_tokens,
- scale=self.processor.scale,
- )
- processor.load_state_dict(self.processor.state_dict())
- if hasattr(self.processor, "to_k_ip"):
- processor.to(
- device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
- )
- else:
- # set attention processor
- # We use the AttnProcessorSDPA by default when torch 2.x is used which uses
- # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
- processor = (
- AttnProcessorSDPA()
- if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
- else AttnProcessor()
- )
-
- self.set_processor(processor)
-
- def set_attention_slice(self, slice_size: int) -> None:
- """
- Set the slice size for attention computation.
-
- Args:
- slice_size (`int`):
- The slice size for attention computation.
- """
- if slice_size is not None and slice_size > self.sliceable_head_dim:
- raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
-
- if slice_size is not None and self.added_kv_proj_dim is not None:
- processor = SlicedAttnAddedKVProcessor(slice_size)
- elif slice_size is not None:
- processor = SlicedAttnProcessor(slice_size)
- elif self.added_kv_proj_dim is not None:
- processor = AttnAddedKVProcessor()
- else:
- # set attention processor
- # We use the AttnProcessorSDPA by default when torch 2.x is used which uses
- # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
- processor = (
- AttnProcessorSDPA()
- if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
- else AttnProcessor()
- )
-
- self.set_processor(processor)
-
- def set_processor(self, processor: "AttnProcessor") -> None:
- """
- Set the attention processor to use.
-
- Args:
- processor (`AttnProcessor`):
- The attention processor to use.
- """
- # if current processor is in `self._modules` and if passed `processor` is not, we need to
- # pop `processor` from `self._modules`
- if (
- hasattr(self, "processor")
- and isinstance(self.processor, torch.nn.Module)
- and not isinstance(processor, torch.nn.Module)
- ):
- logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
- self._modules.pop("processor")
-
- self.processor = processor
-
- def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
- """
- Get the attention processor in use.
-
- Args:
- return_deprecated_lora (`bool`, *optional*, defaults to `False`):
- Set to `True` to return the deprecated LoRA attention processor.
-
- Returns:
- "AttentionProcessor": The attention processor in use.
- """
- if not return_deprecated_lora:
- return self.processor
-
- def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
- """
- Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
-
- Args:
- tensor (`torch.Tensor`): The tensor to reshape.
-
- Returns:
- `torch.Tensor`: The reshaped tensor.
- """
- head_size = self.heads
- batch_size, seq_len, dim = tensor.shape
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
- return tensor
-
- def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
- """
- Reshape the tensor for multi-head attention processing.
-
- Args:
- tensor (`torch.Tensor`): The tensor to reshape.
- out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
-
- Returns:
- `torch.Tensor`: The reshaped tensor.
- """
- head_size = self.heads
- if tensor.ndim == 3:
- batch_size, seq_len, dim = tensor.shape
- extra_dim = 1
- else:
- batch_size, extra_dim, seq_len, dim = tensor.shape
- tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
- tensor = tensor.permute(0, 2, 1, 3)
-
- if out_dim == 3:
- tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
-
- return tensor
-
- def get_attention_scores(
- self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
- ) -> torch.Tensor:
- """
- Compute the attention scores.
-
- Args:
- query (`torch.Tensor`): The query tensor.
- key (`torch.Tensor`): The key tensor.
- attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
-
- Returns:
- `torch.Tensor`: The attention probabilities/scores.
- """
- dtype = query.dtype
- if self.upcast_attention:
- query = query.float()
- key = key.float()
-
- if attention_mask is None:
- baddbmm_input = torch.empty(
- query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
- )
- beta = 0
- else:
- baddbmm_input = attention_mask
- beta = 1
-
- attention_scores = torch.baddbmm(
- baddbmm_input,
- query,
- key.transpose(-1, -2),
- beta=beta,
- alpha=self.scale,
- )
- del baddbmm_input
-
- if self.upcast_softmax:
- attention_scores = attention_scores.float()
-
- attention_probs = attention_scores.softmax(dim=-1)
- del attention_scores
-
- attention_probs = attention_probs.to(dtype)
-
- return attention_probs
-
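# Illustrative sketch, not from the patched file: get_attention_scores above is a fused form of
# softmax(Q K^T * scale + mask). torch.baddbmm applies the additive mask (beta) and the scaling
# factor (alpha) in a single kernel.
import torch

batch_heads, q_len, k_len, dim = 2, 4, 6, 8
scale = dim**-0.5
query = torch.randn(batch_heads, q_len, dim)
key = torch.randn(batch_heads, k_len, dim)
mask = torch.zeros(batch_heads, q_len, k_len)

scores = torch.baddbmm(mask, query, key.transpose(-1, -2), beta=1, alpha=scale)
reference = query @ key.transpose(-1, -2) * scale + mask
assert torch.allclose(scores, reference, atol=1e-5)
attention_probs = scores.softmax(dim=-1)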
- def prepare_attention_mask(
- self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
- ) -> torch.Tensor:
- """
- Prepare the attention mask for the attention computation.
-
- Args:
- attention_mask (`torch.Tensor`): The attention mask to prepare.
- target_length (`int`): The target length of the attention mask.
- batch_size (`int`): The batch size for repeating the attention mask.
- out_dim (`int`, *optional*, defaults to `3`): Output dimension.
-
- Returns:
- `torch.Tensor`: The prepared attention mask.
- """
- head_size = self.heads
- if attention_mask is None:
- return attention_mask
-
- current_length: int = attention_mask.shape[-1]
- if current_length != target_length:
- if attention_mask.device.type == "mps":
- # HACK: MPS: Does not support padding by greater than dimension of input tensor.
- # Instead, we can manually construct the padding tensor.
- padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
- padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
- attention_mask = torch.cat([attention_mask, padding], dim=2)
- else:
- # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
- # we want to instead pad by (0, remaining_length), where remaining_length is:
- # remaining_length: int = target_length - current_length
- # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
-
- if out_dim == 3:
- if attention_mask.shape[0] < batch_size * head_size:
- attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
- elif out_dim == 4:
- attention_mask = attention_mask.unsqueeze(1)
- attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
-
- return attention_mask
-
- def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
- """
- Normalize the encoder hidden states.
-
- Args:
- encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
-
- Returns:
- `torch.Tensor`: The normalized encoder hidden states.
- """
- assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
- if isinstance(self.norm_cross, nn.LayerNorm):
- encoder_hidden_states = self.norm_cross(encoder_hidden_states)
- elif isinstance(self.norm_cross, nn.GroupNorm):
- # Group norm norms along the channels dimension and expects
- # input to be in the shape of (N, C, *). In this case, we want
- # to norm along the hidden dimension, so we need to move
- # (batch_size, sequence_length, hidden_size) ->
- # (batch_size, hidden_size, sequence_length)
- encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
- encoder_hidden_states = self.norm_cross(encoder_hidden_states)
- encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
- else:
- assert False
-
- return encoder_hidden_states
-
-
-class AttnProcessorSDPA:
- r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: "Attention",
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
- *args,
- **kwargs,
- ) -> torch.Tensor:
- if len(args) > 0 or kwargs.get("scale", None) is not None:
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
- deprecate("scale", "1.0.0", deprecation_message)
-
- residual = hidden_states
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
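# Illustrative sketch, not from the patched file: the head split/merge pattern the processor above
# wraps around F.scaled_dot_product_attention. The shapes are assumptions for demonstration.
import torch
import torch.nn.functional as F

batch_size, seq_len, heads, head_dim = 2, 10, 4, 16
inner_dim = heads * head_dim
q = k = v = torch.randn(batch_size, seq_len, inner_dim)

# [B, S, H*D] -> [B, H, S, D] for SDPA, then back to [B, S, H*D] for the output projection
q, k, v = (t.view(batch_size, -1, heads, head_dim).transpose(1, 2) for t in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
assert out.shape == (batch_size, seq_len, inner_dim)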
-@maybe_allow_in_graph
class Attention(nn.Module, AttentionModuleMixin):
default_processor_class = AttnProcessorSDPA
_available_processors = []
@@ -893,11 +303,7 @@ class Attention(nn.Module, AttentionModuleMixin):
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
if processor is None:
- processor = (
- AttnProcessorSDPA()
- if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
- else AttnProcessor()
- )
+ processor = self.default_processor_class()
self.set_processor(processor)
def forward(
@@ -947,97 +353,99 @@ class Attention(nn.Module, AttentionModuleMixin):
)
-class SanaMultiscaleAttentionProjection(nn.Module):
- def __init__(
+class AttnProcessorSDPA:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
self,
- in_channels: int,
- num_attention_heads: int,
- kernel_size: int,
- ) -> None:
- super().__init__()
+ attn: "Attention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
- channels = 3 * in_channels
- self.proj_in = nn.Conv2d(
- channels,
- channels,
- kernel_size,
- padding=kernel_size // 2,
- groups=channels,
- bias=False,
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
- self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- hidden_states = self.proj_in(hidden_states)
- hidden_states = self.proj_out(hidden_states)
- return hidden_states
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-class SanaMultiscaleLinearAttention(nn.Module):
- r"""Lightweight multi-scale linear attention"""
+ query = attn.to_q(hidden_states)
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- num_attention_heads: Optional[int] = None,
- attention_head_dim: int = 8,
- mult: float = 1.0,
- norm_type: str = "batch_norm",
- kernel_sizes: Tuple[int, ...] = (5,),
- eps: float = 1e-15,
- residual_connection: bool = False,
- ):
- super().__init__()
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- # To prevent circular import
- from .normalization import get_normalization
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
- self.eps = eps
- self.attention_head_dim = attention_head_dim
- self.norm_type = norm_type
- self.residual_connection = residual_connection
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
- num_attention_heads = (
- int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
- inner_dim = num_attention_heads * attention_head_dim
- self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
- self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
- self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
- self.to_qkv_multiscale = nn.ModuleList()
- for kernel_size in kernel_sizes:
- self.to_qkv_multiscale.append(
- SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
- )
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
- self.nonlinearity = nn.ReLU()
- self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
- self.norm_out = get_normalization(norm_type, num_features=out_channels)
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
- self.processor = SanaMultiscaleAttnProcessorSDPA()
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
- def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
- value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1) # Adds padding
- scores = torch.matmul(value, key.transpose(-1, -2))
- hidden_states = torch.matmul(scores, query)
+ hidden_states = hidden_states / attn.rescale_output_factor
- hidden_states = hidden_states.to(dtype=torch.float32)
- hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
return hidden_states
- def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
- scores = torch.matmul(key.transpose(-1, -2), query)
- scores = scores.to(dtype=torch.float32)
- scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
- hidden_states = torch.matmul(value, scores.to(value.dtype))
- return hidden_states
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- return self.processor(self, hidden_states)
-
class CustomDiffusionAttnProcessor(nn.Module):
r"""
@@ -5304,98 +4712,104 @@ class StableAudioAttnProcessor2_0:
def __new__(self, *args, **kwargs):
deprecation_message = "`StableAudioAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `StableAudioAttnProcessorSDPA`"
deprecate("StableAudioAttnProcessor2_0", "1.0.0", deprecation_message)
+
return StableAudioAttnProcessorSDPA(*args, **kwargs)
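# Illustrative sketch, not from the patched file: the pattern used by these deprecation shims.
# Overriding __new__ to return an instance of the replacement class means constructing the old
# name transparently yields the new processor. The class names here are hypothetical stand-ins.
class _NewProcessor:
    pass

class _OldProcessor:
    def __new__(cls, *args, **kwargs):
        # A real shim would also emit a deprecation warning before delegating.
        return _NewProcessor(*args, **kwargs)

assert isinstance(_OldProcessor(), _NewProcessor)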
-class HunyuanAttnProcessor2_0(HunyuanAttnProcessorSDPA):
+class HunyuanAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`HunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `HunyuanAttnProcessorSDPA`"
deprecate("HunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
+
return HunyuanAttnProcessorSDPA(*args, **kwargs)
-class FusedHunyuanAttnProcessor2_0(FusedHunyuanAttnProcessorSDPA):
- def __init__(self, *args, **kwargs):
+class FusedHunyuanAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
deprecation_message = "`FusedHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedHunyuanAttnProcessorSDPA`"
deprecate("FusedHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
+
+        return FusedHunyuanAttnProcessorSDPA(*args, **kwargs)
-class PAGHunyuanAttnProcessor2_0(PAGHunyuanAttnProcessorSDPA):
- def __init__(self, *args, **kwargs):
+class PAGHunyuanAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
deprecation_message = "`PAGHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGHunyuanAttnProcessorSDPA`"
deprecate("PAGHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
+
+ return PAGHunyuanAttnProcessorSDPA(*args, **kwargs)
-class PAGCFGHunyuanAttnProcessor2_0(PAGCFGHunyuanAttnProcessorSDPA):
- def __init__(self, *args, **kwargs):
+class PAGCFGHunyuanAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
deprecation_message = "`PAGCFGHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGHunyuanAttnProcessorSDPA`"
deprecate("PAGCFGHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
+
+ return PAGCFGHunyuanAttnProcessorSDPA(*args, **kwargs)
-class LuminaAttnProcessor2_0(LuminaAttnProcessorSDPA):
- def __init__(self, *args, **kwargs):
+class LuminaAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
deprecation_message = "`LuminaAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `LuminaAttnProcessorSDPA`"
deprecate("LuminaAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
+
+ return LuminaAttnProcessorSDPA(*args, **kwargs)
-class FusedAttnProcessor2_0(FusedAttnProcessorSDPA):
- def __init__(self, *args, **kwargs):
- deprecation_message = "`FusedAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedAttnProcessorSDPA`"
- deprecate("FusedAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
-
-
-class PAGIdentitySelfAttnProcessor2_0(PAGIdentitySelfAttnProcessorSDPA):
- def __init__(self, *args, **kwargs):
+class PAGIdentitySelfAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
deprecation_message = "`PAGIdentitySelfAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySelfAttnProcessorSDPA`"
deprecate("PAGIdentitySelfAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
+
+ return PAGIdentitySelfAttnProcessorSDPA(*args, **kwargs)
-class PAGCFGIdentitySelfAttnProcessor2_0(PAGCFGIdentitySelfAttnProcessorSDPA):
- def __init__(self, *args, **kwargs):
+class PAGCFGIdentitySelfAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
deprecation_message = "`PAGCFGIdentitySelfAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGIdentitySelfAttnProcessorSDPA`"
deprecate("PAGCFGIdentitySelfAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
+
+ return PAGCFGIdentitySelfAttnProcessorSDPA(*args, **kwargs)
-class SanaMultiscaleAttnProcessor2_0(SanaMultiscaleAttnProcessorSDPA):
- def __init__(self, *args, **kwargs):
+class SanaMultiscaleAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
deprecation_message = "`SanaMultiscaleAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaMultiscaleAttnProcessorSDPA`"
deprecate("SanaMultiscaleAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
+
+ return SanaMultiscaleAttnProcessorSDPA(*args, **kwargs)
-class LoRAAttnProcessor2_0(LoRAAttnProcessorSDPA):
- def __init__(self, *args, **kwargs):
+class LoRAAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
deprecation_message = "`LoRAAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `LoRAAttnProcessorSDPA`"
deprecate("LoRAAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
+
+ return LoRAAttnProcessorSDPA(*args, **kwargs)
-class SanaLinearAttnProcessor2_0(SanaLinearAttnProcessorSDPA):
- def __init__(self, *args, **kwargs):
+class SanaLinearAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
deprecation_message = "`SanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaLinearAttnProcessorSDPA`"
deprecate("SanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
+
+ return SanaLinearAttnProcessorSDPA(*args, **kwargs)
-class PAGCFGSanaLinearAttnProcessor2_0(PAGCFGSanaLinearAttnProcessorSDPA):
- def __init__(self, *args, **kwargs):
+class PAGCFGSanaLinearAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
deprecation_message = "`PAGCFGSanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGSanaLinearAttnProcessorSDPA`"
deprecate("PAGCFGSanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
+
+ return PAGCFGSanaLinearAttnProcessorSDPA(*args, **kwargs)
-class PAGIdentitySanaLinearAttnProcessor2_0(PAGIdentitySanaLinearAttnProcessorSDPA):
- def __init__(self, *args, **kwargs):
+class PAGIdentitySanaLinearAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
deprecation_message = "`PAGIdentitySanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySanaLinearAttnProcessorSDPA`"
deprecate("PAGIdentitySanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
+
+ return PAGIdentitySanaLinearAttnProcessorSDPA(*args, **kwargs)
class IPAdapterAttnProcessor(IPAdapterAttnProcessorSDPA):
@@ -5405,11 +4819,12 @@ class IPAdapterAttnProcessor(IPAdapterAttnProcessorSDPA):
super().__init__(*args, **kwargs)
-class IPAdapterAttnProcessor2_0(IPAdapterAttnProcessorSDPA):
- def __init__(self, *args, **kwargs) -> None:
+class IPAdapterAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
deprecation_message = "`IPAdapterAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `IPAdapterAttnProcessorSDPA`"
deprecate("IPAdapterAttnProcessor2_0", "1.0.0", deprecation_message)
- super().__init__(*args, **kwargs)
+
+ return IPAdapterAttnProcessorSDPA(*args, **kwargs)
ADDED_KV_ATTENTION_PROCESSORS = (
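
Aside, not part of the patch: a minimal, self-contained sketch of the `__new__`-based deprecation pattern the conversions above apply, with hypothetical `OldProcessor`/`NewProcessorSDPA` names standing in for the renamed classes. Because `__new__` returns an instance of a different class, Python never calls `OldProcessor.__init__`, so callers transparently receive the replacement implementation.

    import warnings

    class NewProcessorSDPA:
        def __init__(self, scale: float = 1.0):
            self.scale = scale

    class OldProcessor:
        """Deprecated alias kept only as a constructor shim."""

        def __new__(cls, *args, **kwargs):
            warnings.warn(
                "`OldProcessor` is deprecated; use `NewProcessorSDPA` instead.",
                FutureWarning,
            )
            # The returned object is not an instance of `cls`, so OldProcessor.__init__
            # is skipped and the caller gets the new processor directly.
            return NewProcessorSDPA(*args, **kwargs)

    proc = OldProcessor(scale=0.5)
    assert isinstance(proc, NewProcessorSDPA) and proc.scale == 0.5
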
diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py
index 9146aa5c7c..9c1f76ba41 100644
--- a/src/diffusers/models/autoencoders/autoencoder_dc.py
+++ b/src/diffusers/models/autoencoders/autoencoder_dc.py
@@ -62,6 +62,98 @@ class ResBlock(nn.Module):
return hidden_states + residual
+class SanaMultiscaleAttentionProjection(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_attention_heads: int,
+ kernel_size: int,
+ ) -> None:
+ super().__init__()
+
+ channels = 3 * in_channels
+ self.proj_in = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size,
+ padding=kernel_size // 2,
+ groups=channels,
+ bias=False,
+ )
+ self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+ return hidden_states
+
+
+class SanaMultiscaleLinearAttention(nn.Module):
+ r"""Lightweight multi-scale linear attention"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_attention_heads: Optional[int] = None,
+ attention_head_dim: int = 8,
+ mult: float = 1.0,
+ norm_type: str = "batch_norm",
+ kernel_sizes: Tuple[int, ...] = (5,),
+ eps: float = 1e-15,
+ residual_connection: bool = False,
+ ):
+ super().__init__()
+
+ # To prevent circular import
+ from ..normalization import get_normalization
+
+ self.eps = eps
+ self.attention_head_dim = attention_head_dim
+ self.norm_type = norm_type
+ self.residual_connection = residual_connection
+
+ num_attention_heads = (
+ int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
+ )
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
+ self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
+ self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
+
+ self.to_qkv_multiscale = nn.ModuleList()
+ for kernel_size in kernel_sizes:
+ self.to_qkv_multiscale.append(
+ SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
+ )
+
+ self.nonlinearity = nn.ReLU()
+ self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
+ self.norm_out = get_normalization(norm_type, num_features=out_channels)
+
+ self.processor = SanaMultiscaleAttnProcessorSDPA()
+
+ def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1) # Adds padding
+ scores = torch.matmul(value, key.transpose(-1, -2))
+ hidden_states = torch.matmul(scores, query)
+
+ hidden_states = hidden_states.to(dtype=torch.float32)
+ hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
+ return hidden_states
+
+ def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+ scores = torch.matmul(key.transpose(-1, -2), query)
+ scores = scores.to(dtype=torch.float32)
+ scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
+ hidden_states = torch.matmul(value, scores.to(value.dtype))
+ return hidden_states
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.processor(self, hidden_states)
+
+
class EfficientViTBlock(nn.Module):
def __init__(
self,
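
Aside, not part of the patch: a standalone sketch of the padded-value trick used by `apply_linear_attention` above, assuming the `[batch, heads, head_dim, tokens]` layout and the module's `eps` default. Padding `value` with a row of ones lets the same pair of matmuls accumulate both the weighted values and the normalizer; the last row is then divided out and dropped.

    import torch
    import torch.nn.functional as F

    def linear_attention(query, key, value, eps=1e-15):
        # Append a constant row of ones so the matmuls also compute the normalizer.
        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)  # [B, H, D+1, N]
        scores = torch.matmul(value, key.transpose(-1, -2))           # [B, H, D+1, D]
        hidden_states = torch.matmul(scores, query)                   # [B, H, D+1, N]
        hidden_states = hidden_states.to(torch.float32)
        # The last row holds the accumulated normalizer; divide it out and drop it.
        return hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + eps)

    # Toy shapes: batch=1, heads=2, head_dim=8, tokens=16 (ReLU mimics the module's nonlinearity).
    q = k = v = torch.relu(torch.randn(1, 2, 8, 16))
    print(linear_attention(q, k, v).shape)  # torch.Size([1, 2, 8, 16])
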
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
index a76277366c..7da90205b1 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
@@ -21,7 +21,8 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
-from ..attention_processor import Attention, SpatialNorm
+from ..attention import Attention
+from ..attention_processor import SpatialNorm
from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from ..downsampling import Downsample2D
from ..modeling_outputs import AutoencoderKLOutput
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
index a32f4bfd76..4086095440 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
@@ -24,7 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
-from ..attention_processor import Attention
+from ..attention import Attention
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
index edf270f66e..1e5eb7805d 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
@@ -23,7 +23,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
-from ..attention_processor import Attention, MochiVaeAttnProcessor2_0
+from ..attention import Attention
+from ..attention_processor import MochiVaeAttnProcessor2_0
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index adc1716069..d050c5cf2d 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -22,7 +22,8 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ..attention import Attention
+from ..attention_processor import AttentionProcessor, FusedJointAttnProcessor2_0
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index b1e14ca6a7..9324b85bd5 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -21,7 +21,7 @@ from torch import nn
from ..utils import deprecate
from .activations import FP32SiLU, get_activation
-from .attention_processor import Attention
+from .attention import Attention
def get_timestep_embedding(
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index 8781424c61..e78f77ebda 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -23,11 +23,9 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import Attention, AttentionMixin
from ..attention_processor import (
- Attention,
- AttentionProcessor,
AuraFlowAttnProcessor2_0,
- FusedAuraFlowAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
@@ -267,7 +265,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
@@ -357,105 +355,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedAuraFlowAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
+ # Using methods from AttentionMixin
def forward(
self,
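
Aside, not part of the patch: the property and setter deleted above are exactly what `AttentionMixin` now centralizes. A reduced sketch of that recursive traversal, assuming (as the removed code does) that any submodule exposing `get_processor`/`set_processor` counts as an attention layer:

    import torch

    def collect_attn_processors(model: torch.nn.Module) -> dict:
        processors = {}

        def visit(name: str, module: torch.nn.Module):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                visit(f"{name}.{sub_name}", child)

        for name, module in model.named_children():
            visit(name, module)
        return processors

    def set_attn_processors(model: torch.nn.Module, processor) -> None:
        # A single instance is broadcast to every attention layer; a dict keyed by
        # "<module path>.processor" assigns one processor per layer.
        def visit(name: str, module: torch.nn.Module):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                visit(f"{name}.{sub_name}", child)

        for name, module in model.named_children():
            visit(name, module)
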
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index 5e0965572d..0e7da957c5 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -22,12 +22,9 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention
+from ..attention import AttentionMixin
from ..attention_processor import (
AttentionModuleMixin,
- AttentionProcessor,
- CogVideoXAttnProcessor2_0,
- FusedCogVideoXAttnProcessor2_0,
)
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
@@ -103,7 +100,7 @@ class BaseCogVideoXAttnProcessor:
def __call__(
self,
- attn: Attention,
+ attn: CogVideoXAttention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -260,7 +257,7 @@ class CogVideoXBlock(nn.Module):
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
- self.attn1 = Attention(
+ self.attn1 = CogVideoXAttention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
@@ -268,7 +265,6 @@ class CogVideoXBlock(nn.Module):
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
- processor=CogVideoXAttnProcessor2_0(),
)
# 2. Feed Forward
@@ -325,7 +321,7 @@ class CogVideoXBlock(nn.Module):
return hidden_states, encoder_hidden_states
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin, AttentionMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
@@ -499,105 +495,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
+ # Using inherited methods from AttentionMixin
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
+ # Using inherited methods from AttentionMixin
def forward(
self,
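
Aside, not part of the patch: the `attn1 = CogVideoXAttention(...)` change above drops the explicit `processor=CogVideoXAttnProcessor2_0()` argument and relies on the usual module/processor split. A toy illustration of that split (all names below are hypothetical, not diffusers classes): the module owns the projections, the processor owns the computation, and `forward()` just delegates, which is what makes `set_processor()`/`get_processor()` work.

    import torch
    import torch.nn as nn

    class ToyProcessor:
        def __call__(self, attn: "ToyAttention", hidden_states: torch.Tensor) -> torch.Tensor:
            query = attn.to_q(hidden_states)
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            weights = torch.softmax(query @ key.transpose(-1, -2) / query.shape[-1] ** 0.5, dim=-1)
            return weights @ value

    class ToyAttention(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.to_q = nn.Linear(dim, dim, bias=False)
            self.to_k = nn.Linear(dim, dim, bias=False)
            self.to_v = nn.Linear(dim, dim, bias=False)
            self.processor = ToyProcessor()

        def set_processor(self, processor):
            self.processor = processor

        def get_processor(self):
            return self.processor

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Delegation point: swapping the processor changes the math, not the weights.
            return self.processor(self, hidden_states)

    attn = ToyAttention(8)
    print(attn(torch.randn(1, 4, 8)).shape)  # torch.Size([1, 4, 8])
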
diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py
index 9e5f5e5db3..8fdad47838 100644
--- a/src/diffusers/models/transformers/consisid_transformer_3d.py
+++ b/src/diffusers/models/transformers/consisid_transformer_3d.py
@@ -22,8 +22,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention
-from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
+from ..attention import Attention, AttentionMixin
+from ..attention_processor import CogVideoXAttnProcessor2_0
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -349,7 +349,7 @@ class ConsisIDBlock(nn.Module):
return hidden_states, encoder_hidden_states
-class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, AttentionMixin):
"""
A Transformer model for video-like data in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID).
@@ -621,65 +621,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
]
)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
+ # Using methods from AttentionMixin
def forward(
self,
diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py
index e4f2a4e8c8..844eb26fa4 100644
--- a/src/diffusers/models/transformers/dit_transformer_2d.py
+++ b/src/diffusers/models/transformers/dit_transformer_2d.py
@@ -19,16 +19,17 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
-from ..attention import BasicTransformerBlock
+from ..attention import AttentionMixin
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
+from .modeling_common import BasicTransformerBlock
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class DiTTransformer2DModel(ModelMixin, ConfigMixin):
+class DiTTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
r"""
A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).
diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
index dc40ebcac2..78e929b485 100644
--- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py
+++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Optional
import torch
from torch import nn
@@ -19,7 +19,8 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
+from ..attention import Attention, AttentionMixin
+from ..attention_processor import HunyuanAttnProcessor2_0
from ..embeddings import (
HunyuanCombinedTimestepTextSizeStyleEmbedding,
PatchEmbed,
@@ -200,7 +201,7 @@ class HunyuanDiTBlock(nn.Module):
return hidden_states
-class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
+class HunyuanDiT2DModel(ModelMixin, ConfigMixin, AttentionMixin):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
@@ -318,105 +319,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedHunyuanAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
-
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
+ # Using methods from AttentionMixin
def set_default_attn_processor(self):
"""
diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py
index 102a27b472..f2f610054e 100644
--- a/src/diffusers/models/transformers/latte_transformer_3d.py
+++ b/src/diffusers/models/transformers/latte_transformer_3d.py
@@ -19,15 +19,16 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
-from ..attention import BasicTransformerBlock
+from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
+from .modeling_common import BasicTransformerBlock
-class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
+class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin, AttentionMixin):
_supports_gradient_checkpointing = True
"""
diff --git a/src/diffusers/models/transformers/modeling_common.py b/src/diffusers/models/transformers/modeling_common.py
index 93b11c2b43..a2ab97769e 100644
--- a/src/diffusers/models/transformers/modeling_common.py
+++ b/src/diffusers/models/transformers/modeling_common.py
@@ -17,12 +17,12 @@ import torch
import torch.nn.functional as F
from torch import nn
-from ..utils import deprecate, logging
-from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
-from .attention_processor import Attention, JointAttnProcessor2_0
-from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+from ...utils import deprecate, logging
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
+from ..attention_processor import Attention, JointAttnProcessor2_0
+from ..embeddings import SinusoidalPositionalEmbedding
+from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
logger = logging.get_logger(__name__)
diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
index 50a0e3e67d..741697147e 100644
--- a/src/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -11,25 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
-from ..attention import BasicTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
+from ..attention import AttentionMixin
+from ..attention_processor import AttnProcessor
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
+from .modeling_common import BasicTransformerBlock
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
+class PixArtTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
r"""
A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
https://arxiv.org/abs/2403.04692).
@@ -184,65 +185,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
in_features=self.config.caption_channels, hidden_size=self.inner_dim
)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
+    # Using inherited methods from AttentionMixin
def set_default_attn_processor(self):
"""
@@ -252,45 +195,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
"""
self.set_attn_processor(AttnProcessor())
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
+ # Using inherited methods from AttentionMixin
def forward(
self,
diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py
index c0b00da671..7ef83f2f30 100644
--- a/src/diffusers/models/transformers/prior_transformer.py
+++ b/src/diffusers/models/transformers/prior_transformer.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import Dict, Optional, Union
+from typing import Optional, Union
import torch
import torch.nn.functional as F
@@ -8,16 +8,16 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput
-from ..attention import BasicTransformerBlock
+from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
+from .modeling_common import BasicTransformerBlock
@dataclass
@@ -33,7 +33,7 @@ class PriorTransformerOutput(BaseOutput):
predicted_image_embedding: torch.Tensor
-class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin, AttentionMixin):
"""
A Prior Transformer model.
@@ -166,65 +166,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
+ # Using inherited methods from AttentionMixin
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py
index f844130fab..eb93f0f521 100644
--- a/src/diffusers/models/transformers/sana_transformer.py
+++ b/src/diffusers/models/transformers/sana_transformer.py
@@ -21,10 +21,9 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import Attention, AttentionMixin
from ..attention_processor import (
- Attention,
AttentionModuleMixin,
- AttentionProcessor,
SanaLinearAttnProcessor2_0,
)
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
@@ -388,7 +387,7 @@ class SanaTransformerBlock(nn.Module):
return hidden_states
-class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
r"""
A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
@@ -513,65 +512,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
+ # Using methods from AttentionMixin
def forward(
self,
diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py
index d81b6447ad..c37a972dc9 100644
--- a/src/diffusers/models/transformers/stable_audio_transformer.py
+++ b/src/diffusers/models/transformers/stable_audio_transformer.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Optional, Union
import numpy as np
import torch
@@ -21,10 +21,8 @@ import torch.nn as nn
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.attention import FeedForward
+from ...models.attention import Attention, AttentionMixin, FeedForward
from ...models.attention_processor import (
- Attention,
- AttentionProcessor,
StableAudioAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
@@ -187,7 +185,7 @@ class StableAudioDiTBlock(nn.Module):
return hidden_states
-class StableAudioDiTModel(ModelMixin, ConfigMixin):
+class StableAudioDiTModel(ModelMixin, ConfigMixin, AttentionMixin):
"""
The Diffusion Transformer model introduced in Stable Audio.
@@ -279,65 +277,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
+ # Using methods from AttentionMixin
# Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio
def set_default_attn_processor(self):
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index bef841d8ce..52a9f3b75a 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -19,11 +19,12 @@ from torch import nn
from ...configuration_utils import LegacyConfigMixin, register_to_config
from ...utils import deprecate, logging
-from ..attention import BasicTransformerBlock
+from ..attention import AttentionMixin
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import LegacyModelMixin
from ..normalization import AdaLayerNormSingle
+from .modeling_common import BasicTransformerBlock
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -36,7 +37,7 @@ class Transformer2DModelOutput(Transformer2DModelOutput):
super().__init__(*args, **kwargs)
-class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
+class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin, AttentionMixin):
"""
A 2D Transformer model for image-like data.
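
Aside, not part of the patch: a quick, weight-free check that a refactored class exposes the mixin API. Module paths follow this branch's layout (`AttentionMixin` living in `diffusers.models.attention`) and only hold once the patch lands.

    from diffusers.models.attention import AttentionMixin
    from diffusers.models.transformers.transformer_2d import Transformer2DModel

    # The mixin sits in the MRO, so the property and setter are reachable without
    # each transformer carrying its own copy.
    assert AttentionMixin in Transformer2DModel.__mro__
    assert isinstance(getattr(Transformer2DModel, "attn_processors"), property)
    assert callable(Transformer2DModel.set_attn_processor)
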
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index da7133791f..02af3304fb 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -13,16 +13,14 @@
# limitations under the License.
-from typing import Dict, Union
+from typing import Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.attention import FeedForward
+from ...models.attention import Attention, AttentionMixin, FeedForward
from ...models.attention_processor import (
- Attention,
- AttentionProcessor,
CogVideoXAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
@@ -130,7 +128,7 @@ class CogView3PlusTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
-class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
+class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
r"""
The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
Diffusion](https://huggingface.co/papers/2403.05121).
@@ -229,65 +227,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
+    # attn_processors and set_attn_processor are inherited from AttentionMixin
def forward(
self,
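The dict form of set_attn_processor is easiest to drive from the keys that attn_processors already reports, which also guarantees the length check inside the setter passes. A minimal helper sketch; the function name and the snapshot-and-restore idea are assumptions for illustration, not part of this diff:

def swap_attn_processors(model, processor_cls):
    # `model` can be any class that mixes in AttentionMixin (e.g. CogView3PlusTransformer2DModel).
    # Snapshot the current processors keyed by dotted path, install fresh instances of
    # `processor_cls` on every attention layer, and return the snapshot for a later restore
    # via model.set_attn_processor(snapshot).
    snapshot = model.attn_processors
    model.set_attn_processor({name: processor_cls() for name in snapshot})
    return snapshot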
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 7abd0d6d10..8111686349 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -24,12 +24,13 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import FeedForward
-from ...models.attention_processor import AttentionModuleMixin, AttentionProcessor
+from ...models.attention_processor import AttentionModuleMixin
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_torch_xla_version
from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -592,7 +593,7 @@ class FluxTransformerBlock(nn.Module):
class FluxTransformer2DModel(
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+    ModelMixin,
+    ConfigMixin,
+    PeftAdapterMixin,
+    FromOriginalModelMixin,
+    FluxTransformer2DLoadersMixin,
+    CacheMixin,
+    AttentionMixin,
):
"""
The Transformer model introduced in Flux.
@@ -687,97 +688,9 @@ class FluxTransformer2DModel(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
+    # attn_processors, set_attn_processor, and the fuse/unfuse QKV projection helpers come from AttentionMixin
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, AttentionModuleMixin):
- module.fuse_projections(fuse=True)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
def forward(
self,
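Since this diff removes FluxTransformer2DModel's local fuse/unfuse copies, the snippet below assumes AttentionMixin now provides fuse_qkv_projections and unfuse_qkv_projections with the same behaviour; `model` is assumed to be an already constructed FluxTransformer2DModel:

# `model` is assumed to be a FluxTransformer2DModel, e.g. loaded with
# FluxTransformer2DModel.from_pretrained(<repo_id>, subfolder="transformer").
model.fuse_qkv_projections()      # fuses the query/key/value projection weights of each attention module
# ... run inference with fused projections here ...
model.unfuse_qkv_projections()    # restores the attention processors captured before fusing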
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index 39ffb038ce..8b4149ac4d 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -23,7 +23,7 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention_processor import Attention, AttentionProcessor
+from ..attention import Attention, AttentionMixin
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepTextProjEmbeddings,
@@ -819,7 +819,7 @@ class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
-class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class HunyuanVideoTransformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
@@ -962,65 +962,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
+    # attn_processors and set_attn_processor are inherited from AttentionMixin
def forward(
self,
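Passing a single processor instead of a dict installs that same instance on every Attention layer, which the inherited setter still supports. A sketch, assuming the HunyuanVideoAttnProcessor2_0 import path is unchanged by this refactor:

from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoAttnProcessor2_0

# `model` is assumed to be a HunyuanVideoTransformer3DModel instance.
model.set_attn_processor(HunyuanVideoAttnProcessor2_0())  # same processor instance shared by all layers
print(f"{len(model.attn_processors)} attention layers updated")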
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 3c2d6904bd..e0a06597f7 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -20,15 +20,13 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...models.attention import FeedForward, JointTransformerBlock
from ...models.attention_processor import (
- Attention,
AttentionModuleMixin,
- AttentionProcessor,
- FusedJointAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -280,7 +278,7 @@ class SD3SingleTransformerBlock(nn.Module):
class SD3Transformer2DModel(
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin, AttentionMixin
):
"""
The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
@@ -416,105 +414,9 @@ class SD3Transformer2DModel(
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
+    # attn_processors, set_attn_processor, and the fuse/unfuse QKV projection helpers come from AttentionMixin
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedJointAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
def forward(
self,
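Each key reported by the inherited property is the dotted path of an attention module plus a ".processor" suffix, so it can be mapped back to the owning module. A small check sketch (requires Python 3.9+ for str.removesuffix; `model` is assumed to be an SD3Transformer2DModel or any other AttentionMixin model):

for key in model.attn_processors:
    # Strip the ".processor" suffix to recover the module path, then verify the module
    # actually exposes the set_processor hook the mixin relies on.
    attn_module = model.get_submodule(key.removesuffix(".processor"))
    assert hasattr(attn_module, "set_processor"), f"unexpected key: {key}"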
diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py
index b69b04fdcb..1288c904f6 100644
--- a/src/diffusers/models/transformers/transformer_temporal.py
+++ b/src/diffusers/models/transformers/transformer_temporal.py
@@ -19,10 +19,11 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
-from .modeling_common import BasicTransformerBlock, TemporalBasicTransformerBlock
+from ..attention import AttentionMixin
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..resnet import AlphaBlender
+from .modeling_common import BasicTransformerBlock, TemporalBasicTransformerBlock
@dataclass
@@ -38,7 +39,7 @@ class TransformerTemporalModelOutput(BaseOutput):
sample: torch.Tensor
-class TransformerTemporalModel(ModelMixin, ConfigMixin):
+class TransformerTemporalModel(ModelMixin, ConfigMixin, AttentionMixin):
"""
A Transformer model for video-like data.
@@ -202,7 +203,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
return TransformerTemporalModelOutput(sample=output)
-class TransformerSpatioTemporalModel(nn.Module):
+class TransformerSpatioTemporalModel(nn.Module, AttentionMixin):
"""
A Transformer model for video-like data.
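AttentionMixin only walks named_children() and the get_processor()/set_processor() hooks, so it also composes with TransformerSpatioTemporalModel even though that class subclasses nn.Module rather than ModelMixin. Illustrative only, with `model` assumed to be such an instance:

# The mixin's recursive traversal works on any torch.nn.Module; no ModelMixin machinery is required.
print(sorted(model.attn_processors))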
diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py
index e082d524e7..dd1b9eae6e 100644
--- a/src/diffusers/models/unets/unet_2d_blocks.py
+++ b/src/diffusers/models/unets/unet_2d_blocks.py
@@ -21,7 +21,8 @@ from torch import nn
from ...utils import deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..activations import get_activation
-from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
+from ..attention import Attention
+from ..attention_processor import AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from ..normalization import AdaGroupNorm
from ..resnet import (
Downsample2D,
diff --git a/src/diffusers/models/unets/unet_kandinsky3.py b/src/diffusers/models/unets/unet_kandinsky3.py
index 73bf0020b4..ee01ae933e 100644
--- a/src/diffusers/models/unets/unet_kandinsky3.py
+++ b/src/diffusers/models/unets/unet_kandinsky3.py
@@ -21,7 +21,8 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
-from ..attention_processor import Attention, AttentionProcessor, AttnProcessor
+from ..attention import Attention
+from ..attention_processor import AttentionProcessor, AttnProcessor
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py
index f57754435f..b90bd1898a 100644
--- a/src/diffusers/models/unets/unet_stable_cascade.py
+++ b/src/diffusers/models/unets/unet_stable_cascade.py
@@ -23,7 +23,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput
-from ..attention_processor import Attention
+from ..attention import Attention
from ..modeling_utils import ModelMixin
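For downstream code, the canonical import of the Attention module class after this refactor matches what these UNet files now use. Whether the old attention_processor path keeps re-exporting it is not shown in this diff, so the fallback below is an assumption:

try:
    from diffusers.models.attention import Attention  # path used throughout this refactor
except ImportError:
    # Fallback for diffusers versions where Attention still lives in attention_processor.
    from diffusers.models.attention_processor import Attention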