diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 93b11c2b43..2bcbeb27e9 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -11,36 +11,822 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn -from ..utils import deprecate, logging + +# Import xformers only if it's available +try: + import xformers + import xformers.ops +except ImportError: + xformers = None + +from ..utils import logging +from ..utils.import_utils import ( + is_torch_npu_available, + is_torch_xla_available, + is_xformers_available, +) from ..utils.torch_utils import maybe_allow_in_graph -from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU -from .attention_processor import Attention, JointAttnProcessor2_0 -from .embeddings import SinusoidalPositionalEmbedding -from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX +from .attention_processor import ( + Attention, + AttentionProcessor, + AttnProcessor, +) +from .normalization import RMSNorm logger = logging.get_logger(__name__) -def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): - # "feed_forward_chunk_size" can be used to save memory - if hidden_states.shape[chunk_dim] % chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." - ) +class AttentionMixin: + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} - num_chunks = hidden_states.shape[chunk_dim] // chunk_size - ff_output = torch.cat( - [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], - dim=chunk_dim, - ) - return ff_output + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, AttentionModuleMixin): + module.fuse_projections(fuse=True) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + +class AttentionModuleMixin: + """ + A mixin class that provides common methods for attention modules. + + This mixin adds functionality to set different attention processors, handle attention masks, compute attention + scores, and manage projections. + """ + + # Default processor classes to be overridden by subclasses + default_processor_cls = None + _available_processors = [] + + def _get_compatible_processor(self, backend): + for processor_cls in self._available_processors: + if backend in processor_cls.compatible_backends: + processor = processor_cls() + return processor + + def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: + """ + Set whether to use NPU flash attention from `torch_npu` or not. + + Args: + use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not. + """ + processor = self.default_processor_cls() + + if use_npu_flash_attention: + if not is_torch_npu_available(): + raise ImportError("torch_npu is not available") + processor = self._get_compatible_processor("npu") + + self.set_processor(processor) + + def set_use_xla_flash_attention( + self, + use_xla_flash_attention: bool, + partition_spec: Optional[Tuple[Optional[str], ...]] = None, + is_flux=False, + ) -> None: + """ + Set whether to use XLA flash attention from `torch_xla` or not. + + Args: + use_xla_flash_attention (`bool`): + Whether to use pallas flash attention kernel from `torch_xla` or not. + partition_spec (`Tuple[]`, *optional*): + Specify the partition specification if using SPMD. Otherwise None. + is_flux (`bool`, *optional*, defaults to `False`): + Whether the model is a Flux model. + """ + processor = self.default_processor_cls() + if use_xla_flash_attention: + if not is_torch_xla_available(): + raise ImportError("torch_xla is not available") + processor = self._get_compatible_processor("xla") + + self.set_processor(processor) + + @torch.no_grad() + def fuse_projections(self, fuse=True): + """ + Fuse the query, key, and value projections into a single projection for efficiency. + + Args: + fuse (`bool`): Whether to fuse the projections or not. + """ + # Skip if already in desired state + if getattr(self, "fused_projections", False) == fuse: + return + + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not self.is_cross_attention: + # Fuse self-attention projections + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + else: + # Fuse cross-attention key-value projections + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + + # Handle added projections for models like SD3, Flux, etc. + if ( + getattr(self, "add_q_proj", None) is not None + and getattr(self, "add_k_proj", None) is not None + and getattr(self, "add_v_proj", None) is not None + ): + concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype + ) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) + + self.fused_projections = fuse + self.processor.is_fused = fuse + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + """ + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. + """ + if use_memory_efficient_attention_xformers: + if not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + if xformers is not None: + dtype = None + if attention_op is not None: + op_fw, op_bw = attention_op + dtype, *_ = op_fw.SUPPORTED_DTYPES + q = torch.randn((1, 2, 40), device="cuda", dtype=dtype) + _ = xformers.ops.memory_efficient_attention(q, q, q) + except Exception as e: + raise e + + processor = self._get_compatible_processor("xformers") + else: + # Set default processor + processor = self.default_processor_cls() + + if processor is not None: + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + """ + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + processor = None + + # Try to get a compatible processor for sliced attention + if slice_size is not None: + processor = self._get_compatible_processor("sliced") + + # If no processor was found or slice_size is None, use default processor + if processor is None: + processor = self.default_processor_cls() + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + """ + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + """ + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + """ + Reshape the tensor for multi-head attention processing. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + """ + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): The attention mask to prepare. + target_length (`int`): The target length of the attention mask. + batch_size (`int`): The batch size for repeating the attention mask. + out_dim (`int`, *optional*, defaults to `3`): Output dimension. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + """ + Normalize the encoder hidden states. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + +@maybe_allow_in_graph +class Attention(nn.Module, AttentionModuleMixin): + default_processor_class = AttnProcessorSDPA + _available_processors = [] + + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessorSDPA` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + out_context_dim: int = None, + context_pre_only=None, + pre_only=False, + elementwise_affine: bool = True, + is_causal: bool = False, + ): + super().__init__() + + # To prevent circular import. + from .normalization import FP32LayerNorm, LpNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.is_causal = is_causal + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "layer_norm_across_heads": + # Lumina applies qk norm across all heads + self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) + self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim_head * heads, eps=eps) + self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "l2": + self.norm_q = LpNorm(p=2, dim=-1, eps=eps) + self.norm_k = LpNorm(p=2, dim=-1, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'." + ) + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "layer_norm": + self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # Wan applies qk norm across all heads + # Wan also doesn't apply a q norm + self.norm_added_q = None + self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" + ) + else: + self.norm_added_q = None + self.norm_added_k = None + + # set attention processor + # We use the AttnProcessorSDPA by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = self.default_processor_class() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) @maybe_allow_in_graph @@ -83,1169 +869,3 @@ class GatedSelfAttentionDense(nn.Module): x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) return x - - -@maybe_allow_in_graph -class JointTransformerBlock(nn.Module): - r""" - A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. - - Reference: https://arxiv.org/abs/2403.03206 - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the - processing of `context` conditions. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - context_pre_only: bool = False, - qk_norm: Optional[str] = None, - use_dual_attention: bool = False, - ): - super().__init__() - - self.use_dual_attention = use_dual_attention - self.context_pre_only = context_pre_only - context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" - - if use_dual_attention: - self.norm1 = SD35AdaLayerNormZeroX(dim) - else: - self.norm1 = AdaLayerNormZero(dim) - - if context_norm_type == "ada_norm_continous": - self.norm1_context = AdaLayerNormContinuous( - dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" - ) - elif context_norm_type == "ada_norm_zero": - self.norm1_context = AdaLayerNormZero(dim) - else: - raise ValueError( - f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" - ) - - if hasattr(F, "scaled_dot_product_attention"): - processor = JointAttnProcessor2_0() - else: - raise ValueError( - "The current PyTorch version does not support the `scaled_dot_product_attention` function." - ) - - self.attn = Attention( - query_dim=dim, - cross_attention_dim=None, - added_kv_proj_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - context_pre_only=context_pre_only, - bias=True, - processor=processor, - qk_norm=qk_norm, - eps=1e-6, - ) - - if use_dual_attention: - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=None, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - bias=True, - processor=processor, - qk_norm=qk_norm, - eps=1e-6, - ) - else: - self.attn2 = None - - self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - if not context_pre_only: - self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - else: - self.norm2_context = None - self.ff_context = None - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor, - temb: torch.FloatTensor, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - ): - joint_attention_kwargs = joint_attention_kwargs or {} - if self.use_dual_attention: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( - hidden_states, emb=temb - ) - else: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) - - if self.context_pre_only: - norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) - else: - norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( - encoder_hidden_states, emb=temb - ) - - # Attention. - attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - **joint_attention_kwargs, - ) - - # Process attention outputs for the `hidden_states`. - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = hidden_states + attn_output - - if self.use_dual_attention: - attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs) - attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 - hidden_states = hidden_states + attn_output2 - - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - else: - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = hidden_states + ff_output - - # Process attention outputs for the `encoder_hidden_states`. - if self.context_pre_only: - encoder_hidden_states = None - else: - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - context_ff_output = _chunked_feed_forward( - self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size - ) - else: - context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output - - return encoder_hidden_states, hidden_states - - -@maybe_allow_in_graph -class BasicTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - upcast_attention (`bool`, *optional*): - Whether to upcast the attention computation to float32. This is useful for mixed precision training. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_type (`str`, *optional*, defaults to `"layer_norm"`): - The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. - final_dropout (`bool` *optional*, defaults to False): - Whether to apply a final dropout after the last feed-forward layer. - attention_type (`str`, *optional*, defaults to `"default"`): - The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. - positional_embeddings (`str`, *optional*, defaults to `None`): - The type of positional embeddings to apply to. - num_positional_embeddings (`int`, *optional*, defaults to `None`): - The maximum number of positional embeddings to apply. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' - norm_eps: float = 1e-5, - final_dropout: bool = False, - attention_type: str = "default", - positional_embeddings: Optional[str] = None, - num_positional_embeddings: Optional[int] = None, - ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, - ada_norm_bias: Optional[int] = None, - ff_inner_dim: Optional[int] = None, - ff_bias: bool = True, - attention_out_bias: bool = True, - ): - super().__init__() - self.dim = dim - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - self.dropout = dropout - self.cross_attention_dim = cross_attention_dim - self.activation_fn = activation_fn - self.attention_bias = attention_bias - self.double_self_attention = double_self_attention - self.norm_elementwise_affine = norm_elementwise_affine - self.positional_embeddings = positional_embeddings - self.num_positional_embeddings = num_positional_embeddings - self.only_cross_attention = only_cross_attention - - # We keep these boolean flags for backward-compatibility. - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - self.use_ada_layer_norm_single = norm_type == "ada_norm_single" - self.use_layer_norm = norm_type == "layer_norm" - self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - self.norm_type = norm_type - self.num_embeds_ada_norm = num_embeds_ada_norm - - if positional_embeddings and (num_positional_embeddings is None): - raise ValueError( - "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." - ) - - if positional_embeddings == "sinusoidal": - self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) - else: - self.pos_embed = None - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - if norm_type == "ada_norm": - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif norm_type == "ada_norm_zero": - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - elif norm_type == "ada_norm_continuous": - self.norm1 = AdaLayerNormContinuous( - dim, - ada_norm_continous_conditioning_embedding_dim, - norm_elementwise_affine, - norm_eps, - ada_norm_bias, - "rms_norm", - ) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, - ) - - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - if norm_type == "ada_norm": - self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif norm_type == "ada_norm_continuous": - self.norm2 = AdaLayerNormContinuous( - dim, - ada_norm_continous_conditioning_embedding_dim, - norm_elementwise_affine, - norm_eps, - ada_norm_bias, - "rms_norm", - ) - else: - self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, - ) # is self-attn if encoder_hidden_states is none - else: - if norm_type == "ada_norm_single": # For Latte - self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - else: - self.norm2 = None - self.attn2 = None - - # 3. Feed-forward - if norm_type == "ada_norm_continuous": - self.norm3 = AdaLayerNormContinuous( - dim, - ada_norm_continous_conditioning_embedding_dim, - norm_elementwise_affine, - norm_eps, - ada_norm_bias, - "layer_norm", - ) - - elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: - self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - elif norm_type == "layer_norm_i2vgen": - self.norm3 = None - - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - inner_dim=ff_inner_dim, - bias=ff_bias, - ) - - # 4. Fuser - if attention_type == "gated" or attention_type == "gated-text-image": - self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) - - # 5. Scale-shift for PixArt-Alpha. - if norm_type == "ada_norm_single": - self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - # Notice that normalization is always applied before the real computation in the following blocks. - # 0. Self-Attention - batch_size = hidden_states.shape[0] - - if self.norm_type == "ada_norm": - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.norm_type == "ada_norm_zero": - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm1(hidden_states) - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) - elif self.norm_type == "ada_norm_single": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) - ).chunk(6, dim=1) - norm_hidden_states = self.norm1(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - else: - raise ValueError("Incorrect norm used") - - if self.pos_embed is not None: - norm_hidden_states = self.pos_embed(norm_hidden_states) - - # 1. Prepare GLIGEN inputs - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - gligen_kwargs = cross_attention_kwargs.pop("gligen", None) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - if self.norm_type == "ada_norm_zero": - attn_output = gate_msa.unsqueeze(1) * attn_output - elif self.norm_type == "ada_norm_single": - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - # 1.2 GLIGEN Control - if gligen_kwargs is not None: - hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) - - # 3. Cross-Attention - if self.attn2 is not None: - if self.norm_type == "ada_norm": - norm_hidden_states = self.norm2(hidden_states, timestep) - elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm2(hidden_states) - elif self.norm_type == "ada_norm_single": - # For PixArt norm2 isn't applied here: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 - norm_hidden_states = hidden_states - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) - else: - raise ValueError("Incorrect norm") - - if self.pos_embed is not None and self.norm_type != "ada_norm_single": - norm_hidden_states = self.pos_embed(norm_hidden_states) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 4. Feed-forward - # i2vgen doesn't have this norm 🤷‍♂️ - if self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) - elif not self.norm_type == "ada_norm_single": - norm_hidden_states = self.norm3(hidden_states) - - if self.norm_type == "ada_norm_zero": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - if self.norm_type == "ada_norm_single": - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - else: - ff_output = self.ff(norm_hidden_states) - - if self.norm_type == "ada_norm_zero": - ff_output = gate_mlp.unsqueeze(1) * ff_output - elif self.norm_type == "ada_norm_single": - ff_output = gate_mlp * ff_output - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - return hidden_states - - -class LuminaFeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - hidden_size (`int`): - The dimensionality of the hidden layers in the model. This parameter determines the width of the model's - hidden representations. - intermediate_size (`int`): The intermediate dimension of the feedforward layer. - multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple - of this value. - ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden - dimension. Defaults to None. - """ - - def __init__( - self, - dim: int, - inner_dim: int, - multiple_of: Optional[int] = 256, - ffn_dim_multiplier: Optional[float] = None, - ): - super().__init__() - # custom hidden_size factor multiplier - if ffn_dim_multiplier is not None: - inner_dim = int(ffn_dim_multiplier * inner_dim) - inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) - - self.linear_1 = nn.Linear( - dim, - inner_dim, - bias=False, - ) - self.linear_2 = nn.Linear( - inner_dim, - dim, - bias=False, - ) - self.linear_3 = nn.Linear( - dim, - inner_dim, - bias=False, - ) - self.silu = FP32SiLU() - - def forward(self, x): - return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x)) - - -@maybe_allow_in_graph -class TemporalBasicTransformerBlock(nn.Module): - r""" - A basic Transformer block for video like data. - - Parameters: - dim (`int`): The number of channels in the input and output. - time_mix_inner_dim (`int`): The number of channels for temporal attention. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - """ - - def __init__( - self, - dim: int, - time_mix_inner_dim: int, - num_attention_heads: int, - attention_head_dim: int, - cross_attention_dim: Optional[int] = None, - ): - super().__init__() - self.is_res = dim == time_mix_inner_dim - - self.norm_in = nn.LayerNorm(dim) - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - self.ff_in = FeedForward( - dim, - dim_out=time_mix_inner_dim, - activation_fn="geglu", - ) - - self.norm1 = nn.LayerNorm(time_mix_inner_dim) - self.attn1 = Attention( - query_dim=time_mix_inner_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - cross_attention_dim=None, - ) - - # 2. Cross-Attn - if cross_attention_dim is not None: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = nn.LayerNorm(time_mix_inner_dim) - self.attn2 = Attention( - query_dim=time_mix_inner_dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None - - # 3. Feed-forward - self.norm3 = nn.LayerNorm(time_mix_inner_dim) - self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = None - - def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): - # Sets chunk feed-forward - self._chunk_size = chunk_size - # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off - self._chunk_dim = 1 - - def forward( - self, - hidden_states: torch.Tensor, - num_frames: int, - encoder_hidden_states: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # Notice that normalization is always applied before the real computation in the following blocks. - # 0. Self-Attention - batch_size = hidden_states.shape[0] - - batch_frames, seq_length, channels = hidden_states.shape - batch_size = batch_frames // num_frames - - hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) - hidden_states = hidden_states.permute(0, 2, 1, 3) - hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) - - residual = hidden_states - hidden_states = self.norm_in(hidden_states) - - if self._chunk_size is not None: - hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) - else: - hidden_states = self.ff_in(hidden_states) - - if self.is_res: - hidden_states = hidden_states + residual - - norm_hidden_states = self.norm1(hidden_states) - attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) - hidden_states = attn_output + hidden_states - - # 3. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = self.norm2(hidden_states) - attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) - hidden_states = attn_output + hidden_states - - # 4. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self._chunk_size is not None: - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - else: - ff_output = self.ff(norm_hidden_states) - - if self.is_res: - hidden_states = ff_output + hidden_states - else: - hidden_states = ff_output - - hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) - hidden_states = hidden_states.permute(0, 2, 1, 3) - hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) - - return hidden_states - - -class SkipFFTransformerBlock(nn.Module): - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - kv_input_dim: int, - kv_input_dim_proj_use_bias: bool, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - attention_out_bias: bool = True, - ): - super().__init__() - if kv_input_dim != dim: - self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) - else: - self.kv_mapper = None - - self.norm1 = RMSNorm(dim, 1e-06) - - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim, - out_bias=attention_out_bias, - ) - - self.norm2 = RMSNorm(dim, 1e-06) - - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - out_bias=attention_out_bias, - ) - - def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - - if self.kv_mapper is not None: - encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) - - norm_hidden_states = self.norm1(hidden_states) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - **cross_attention_kwargs, - ) - - hidden_states = attn_output + hidden_states - - norm_hidden_states = self.norm2(hidden_states) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - **cross_attention_kwargs, - ) - - hidden_states = attn_output + hidden_states - - return hidden_states - - -@maybe_allow_in_graph -class FreeNoiseTransformerBlock(nn.Module): - r""" - A FreeNoise Transformer block. - - Parameters: - dim (`int`): - The number of channels in the input and output. - num_attention_heads (`int`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`): - The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): - The dropout probability to use. - cross_attention_dim (`int`, *optional*): - The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): - Activation function to be used in feed-forward. - num_embeds_ada_norm (`int`, *optional*): - The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (`bool`, defaults to `False`): - Configure if the attentions should contain a bias parameter. - only_cross_attention (`bool`, defaults to `False`): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, defaults to `False`): - Whether to use two self-attention layers. In this case no cross attention layers are used. - upcast_attention (`bool`, defaults to `False`): - Whether to upcast the attention computation to float32. This is useful for mixed precision training. - norm_elementwise_affine (`bool`, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_type (`str`, defaults to `"layer_norm"`): - The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. - final_dropout (`bool` defaults to `False`): - Whether to apply a final dropout after the last feed-forward layer. - attention_type (`str`, defaults to `"default"`): - The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. - positional_embeddings (`str`, *optional*): - The type of positional embeddings to apply to. - num_positional_embeddings (`int`, *optional*, defaults to `None`): - The maximum number of positional embeddings to apply. - ff_inner_dim (`int`, *optional*): - Hidden dimension of feed-forward MLP. - ff_bias (`bool`, defaults to `True`): - Whether or not to use bias in feed-forward MLP. - attention_out_bias (`bool`, defaults to `True`): - Whether or not to use bias in attention output project layer. - context_length (`int`, defaults to `16`): - The maximum number of frames that the FreeNoise block processes at once. - context_stride (`int`, defaults to `4`): - The number of frames to be skipped before starting to process a new batch of `context_length` frames. - weighting_scheme (`str`, defaults to `"pyramid"`): - The weighting scheme to use for weighting averaging of processed latent frames. As described in the - Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting - used. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout: float = 0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - norm_eps: float = 1e-5, - final_dropout: bool = False, - positional_embeddings: Optional[str] = None, - num_positional_embeddings: Optional[int] = None, - ff_inner_dim: Optional[int] = None, - ff_bias: bool = True, - attention_out_bias: bool = True, - context_length: int = 16, - context_stride: int = 4, - weighting_scheme: str = "pyramid", - ): - super().__init__() - self.dim = dim - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - self.dropout = dropout - self.cross_attention_dim = cross_attention_dim - self.activation_fn = activation_fn - self.attention_bias = attention_bias - self.double_self_attention = double_self_attention - self.norm_elementwise_affine = norm_elementwise_affine - self.positional_embeddings = positional_embeddings - self.num_positional_embeddings = num_positional_embeddings - self.only_cross_attention = only_cross_attention - - self.set_free_noise_properties(context_length, context_stride, weighting_scheme) - - # We keep these boolean flags for backward-compatibility. - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - self.use_ada_layer_norm_single = norm_type == "ada_norm_single" - self.use_layer_norm = norm_type == "layer_norm" - self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - self.norm_type = norm_type - self.num_embeds_ada_norm = num_embeds_ada_norm - - if positional_embeddings and (num_positional_embeddings is None): - raise ValueError( - "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." - ) - - if positional_embeddings == "sinusoidal": - self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) - else: - self.pos_embed = None - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, - ) - - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, - ) # is self-attn if encoder_hidden_states is none - - # 3. Feed-forward - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - inner_dim=ff_inner_dim, - bias=ff_bias, - ) - - self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: - frame_indices = [] - for i in range(0, num_frames - self.context_length + 1, self.context_stride): - window_start = i - window_end = min(num_frames, i + self.context_length) - frame_indices.append((window_start, window_end)) - return frame_indices - - def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: - if weighting_scheme == "flat": - weights = [1.0] * num_frames - - elif weighting_scheme == "pyramid": - if num_frames % 2 == 0: - # num_frames = 4 => [1, 2, 2, 1] - mid = num_frames // 2 - weights = list(range(1, mid + 1)) - weights = weights + weights[::-1] - else: - # num_frames = 5 => [1, 2, 3, 2, 1] - mid = (num_frames + 1) // 2 - weights = list(range(1, mid)) - weights = weights + [mid] + weights[::-1] - - elif weighting_scheme == "delayed_reverse_sawtooth": - if num_frames % 2 == 0: - # num_frames = 4 => [0.01, 2, 2, 1] - mid = num_frames // 2 - weights = [0.01] * (mid - 1) + [mid] - weights = weights + list(range(mid, 0, -1)) - else: - # num_frames = 5 => [0.01, 0.01, 3, 2, 1] - mid = (num_frames + 1) // 2 - weights = [0.01] * mid - weights = weights + list(range(mid, 0, -1)) - else: - raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}") - - return weights - - def set_free_noise_properties( - self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid" - ) -> None: - self.context_length = context_length - self.context_stride = context_stride - self.weighting_scheme = weighting_scheme - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None: - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - *args, - **kwargs, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - - # hidden_states: [B x H x W, F, C] - device = hidden_states.device - dtype = hidden_states.dtype - - num_frames = hidden_states.size(1) - frame_indices = self._get_frame_indices(num_frames) - frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme) - frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1) - is_last_frame_batch_complete = frame_indices[-1][1] == num_frames - - # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length - # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges: - # [(0, 16), (4, 20), (8, 24), (10, 26)] - if not is_last_frame_batch_complete: - if num_frames < self.context_length: - raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}") - last_frame_batch_length = num_frames - frame_indices[-1][1] - frame_indices.append((num_frames - self.context_length, num_frames)) - - num_times_accumulated = torch.zeros((1, num_frames, 1), device=device) - accumulated_values = torch.zeros_like(hidden_states) - - for i, (frame_start, frame_end) in enumerate(frame_indices): - # The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle - # cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or - # essentially a non-multiple of `context_length`. - weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end]) - weights *= frame_weights - - hidden_states_chunk = hidden_states[:, frame_start:frame_end] - - # Notice that normalization is always applied before the real computation in the following blocks. - # 1. Self-Attention - norm_hidden_states = self.norm1(hidden_states_chunk) - - if self.pos_embed is not None: - norm_hidden_states = self.pos_embed(norm_hidden_states) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - hidden_states_chunk = attn_output + hidden_states_chunk - if hidden_states_chunk.ndim == 4: - hidden_states_chunk = hidden_states_chunk.squeeze(1) - - # 2. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = self.norm2(hidden_states_chunk) - - if self.pos_embed is not None and self.norm_type != "ada_norm_single": - norm_hidden_states = self.pos_embed(norm_hidden_states) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states_chunk = attn_output + hidden_states_chunk - - if i == len(frame_indices) - 1 and not is_last_frame_batch_complete: - accumulated_values[:, -last_frame_batch_length:] += ( - hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:] - ) - num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length] - else: - accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights - num_times_accumulated[:, frame_start:frame_end] += weights - - # TODO(aryan): Maybe this could be done in a better way. - # - # Previously, this was: - # hidden_states = torch.where( - # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values - # ) - # - # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory - # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes - # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly - # looked into this deeply because other memory optimizations led to more pronounced reductions. - hidden_states = torch.cat( - [ - torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split) - for accumulated_split, num_times_split in zip( - accumulated_values.split(self.context_length, dim=1), - num_times_accumulated.split(self.context_length, dim=1), - ) - ], - dim=1, - ).to(dtype) - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self._chunk_size is not None: - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - else: - ff_output = self.ff(norm_hidden_states) - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - return hidden_states - - -class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - inner_dim=None, - bias: bool = True, - ): - super().__init__() - if inner_dim is None: - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim, bias=bias) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim, bias=bias) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim, bias=bias) - elif activation_fn == "swiglu": - act_fn = SwiGLU(dim, inner_dim, bias=bias) - elif activation_fn == "linear-silu": - act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states diff --git a/src/diffusers/models/attention_modules.py b/src/diffusers/models/attention_modules.py deleted file mode 100644 index 96b8438fb5..0000000000 --- a/src/diffusers/models/attention_modules.py +++ /dev/null @@ -1,247 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import inspect -from typing import Optional, Tuple, Union - -import torch -from torch import nn - -from ..utils import logging -from ..utils.torch_utils import maybe_allow_in_graph -from .attention_processor import ( - AttentionModuleMixin, - FusedJointAttnProcessorSDPA, - JointAttnProcessorSDPA, - SanaLinearAttnProcessorSDPA, -) -from .normalization import get_normalization - - -logger = logging.get_logger(__name__) - - -@maybe_allow_in_graph -class SanaAttention(nn.Module, AttentionModuleMixin): - """ - Attention implementation specialized for Sana models. - - This module implements lightweight multi-scale linear attention as used in Sana. - - Args: - in_channels (`int`): Number of input channels. - out_channels (`int`): Number of output channels. - num_attention_heads (`int`, *optional*): Number of attention heads. - attention_head_dim (`int`, defaults to 8): Dimension of each attention head. - mult (`float`, defaults to 1.0): Multiplier for inner dimension. - norm_type (`str`, defaults to "batch_norm"): Type of normalization. - kernel_sizes (`Tuple[int, ...]`, defaults to (5,)): Kernel sizes for multi-scale attention. - """ - - # Set Sana-specific processor classes - default_processor_class = SanaLinearAttnProcessorSDPA - fused_processor_class = None # Sana doesn't have a fused processor yet - - def __init__( - self, - in_channels: int, - out_channels: int, - num_attention_heads: Optional[int] = None, - attention_head_dim: int = 8, - mult: float = 1.0, - norm_type: str = "batch_norm", - kernel_sizes: Tuple[int, ...] = (5,), - eps: float = 1e-15, - residual_connection: bool = False, - ): - super().__init__() - - # Core parameters - self.eps = eps - self.attention_head_dim = attention_head_dim - self.norm_type = norm_type - self.residual_connection = residual_connection - - # Calculate dimensions - num_attention_heads = ( - int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads - ) - inner_dim = num_attention_heads * attention_head_dim - self.inner_dim = inner_dim - self.heads = num_attention_heads - - # Query, key, value projections - self.to_q = nn.Linear(in_channels, inner_dim, bias=False) - self.to_k = nn.Linear(in_channels, inner_dim, bias=False) - self.to_v = nn.Linear(in_channels, inner_dim, bias=False) - - # Multi-scale attention - self.to_qkv_multiscale = nn.ModuleList() - for kernel_size in kernel_sizes: - self.to_qkv_multiscale.append( - SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size) - ) - - # Output layers - self.nonlinearity = nn.ReLU() - self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False) - self.norm_out = get_normalization(norm_type, num_features=out_channels) - - # Set default processor - self.fused_projections = False - self.set_processor(self.default_processor_class()) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """Process linear attention for Sana model inputs.""" - return self.processor(self, hidden_states) - - -class SanaMultiscaleAttentionProjection(nn.Module): - """Projection layer for Sana multi-scale attention.""" - - def __init__( - self, - in_channels: int, - num_attention_heads: int, - kernel_size: int, - ) -> None: - super().__init__() - - channels = 3 * in_channels - self.proj_in = nn.Conv2d( - channels, - channels, - kernel_size, - padding=kernel_size // 2, - groups=channels, - bias=False, - ) - self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.proj_in(hidden_states) - hidden_states = self.proj_out(hidden_states) - return hidden_states - - -@maybe_allow_in_graph -class SD3Attention(nn.Module, AttentionModuleMixin): - """ - Attention implementation specialized for SD3 models. - - This module implements the joint attention mechanism used in SD3, - with native support for context pre-processing. - - Args: - query_dim (`int`): Number of channels in query. - cross_attention_dim (`int`, *optional*): Number of channels in encoder states. - heads (`int`, defaults to 8): Number of attention heads. - dim_head (`int`, defaults to 64): Dimension of each attention head. - dropout (`float`, defaults to 0.0): Dropout probability. - bias (`bool`, defaults to False): Whether to use bias in linear projections. - added_kv_proj_dim (`int`, *optional*): Dimension for added key/value projections. - """ - - # Set SD3-specific processor classes - default_processor_class = JointAttnProcessorSDPA - fused_processor_class = FusedJointAttnProcessorSDPA - - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - added_kv_proj_dim: Optional[int] = None, - context_pre_only: bool = False, - ): - super().__init__() - - # Core parameters - self.inner_dim = dim_head * heads - self.query_dim = query_dim - self.heads = heads - self.scale = dim_head**-0.5 - self.use_bias = bias - self.scale_qk = True - self.context_pre_only = context_pre_only - - # Cross-attention setup - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - - # Projections for self-attention - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) - - # Added projections for context processing - self.added_kv_proj_dim = added_kv_proj_dim - if added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) - self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) - self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) - self.added_proj_bias = bias - - # Output projection - self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, query_dim, bias=bias), nn.Dropout(dropout)]) - - # Context output projection - if added_kv_proj_dim is not None and not context_pre_only: - self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias) - else: - self.to_add_out = None - - # Set default processor and fusion state - self.fused_projections = False - self.set_processor(self.default_processor_class()) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Process joint attention for SD3 model inputs.""" - # Filter parameters to only those expected by the processor - processor_params = inspect.signature(self.processor.__call__).parameters.keys() - quiet_params = {"ip_adapter_masks", "ip_hidden_states"} - - # Check for unexpected parameters - unexpected_params = [k for k, _ in kwargs.items() if k not in processor_params and k not in quiet_params] - if unexpected_params: - logger.warning( - f"Parameters {unexpected_params} are not expected by {self.processor.__class__.__name__} and will be ignored." - ) - - # Filter to only expected parameters - filtered_kwargs = {k: v for k, v in kwargs.items() if k in processor_params} - - # Process with appropriate processor - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **filtered_kwargs, - ) - diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 80aed8d123..647ad41a6b 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -22,7 +22,7 @@ from torch import nn from ..image_processor import IPAdapterMaskProcessor from ..utils import deprecate, is_torch_xla_available, logging from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available -from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph +from ..utils.torch_utils import is_torch_version logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -46,596 +46,6 @@ else: XLA_AVAILABLE = False -class AttentionModuleMixin: - """ - A mixin class that provides common methods for attention modules. - - This mixin adds functionality to set different attention processors, handle attention masks, compute attention - scores, and manage projections. - """ - - # Default processor classes to be overridden by subclasses - default_processor_cls = None - _available_processors = [] - - def _get_compatible_processor(self, backend): - for processor_cls in self._available_processors: - if backend in processor_cls.compatible_backends: - processor = processor_cls() - return processor - - def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: - """ - Set whether to use NPU flash attention from `torch_npu` or not. - - Args: - use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not. - """ - processor = self.default_processor_cls() - - if use_npu_flash_attention: - processor = self._get_compatible_processor("npu") - - self.set_processor(processor) - - def set_use_xla_flash_attention( - self, - use_xla_flash_attention: bool, - partition_spec: Optional[Tuple[Optional[str], ...]] = None, - is_flux=False, - ) -> None: - """ - Set whether to use XLA flash attention from `torch_xla` or not. - - Args: - use_xla_flash_attention (`bool`): - Whether to use pallas flash attention kernel from `torch_xla` or not. - partition_spec (`Tuple[]`, *optional*): - Specify the partition specification if using SPMD. Otherwise None. - is_flux (`bool`, *optional*, defaults to `False`): - Whether the model is a Flux model. - """ - processor = self.default_processor_cls() - if use_xla_flash_attention: - if not is_torch_xla_available(): - raise "torch_xla is not available" - elif is_torch_xla_version("<", "2.3"): - raise "flash attention pallas kernel is supported from torch_xla version 2.3" - elif is_spmd() and is_torch_xla_version("<", "2.4"): - raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4" - else: - processor = self._get_compatible_processor("xla") - - self.set_processor(processor) - - @torch.no_grad() - def fuse_projections(self, fuse=True): - """ - Fuse the query, key, and value projections into a single projection for efficiency. - - Args: - fuse (`bool`): Whether to fuse the projections or not. - """ - # Skip if already in desired state - if getattr(self, "fused_projections", False) == fuse: - return - - device = self.to_q.weight.data.device - dtype = self.to_q.weight.data.dtype - - if not self.is_cross_attention: - # Fuse self-attention projections - concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) - in_features = concatenated_weights.shape[1] - out_features = concatenated_weights.shape[0] - - self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) - self.to_qkv.weight.copy_(concatenated_weights) - if self.use_bias: - concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) - self.to_qkv.bias.copy_(concatenated_bias) - - else: - # Fuse cross-attention key-value projections - concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) - in_features = concatenated_weights.shape[1] - out_features = concatenated_weights.shape[0] - - self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) - self.to_kv.weight.copy_(concatenated_weights) - if self.use_bias: - concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) - self.to_kv.bias.copy_(concatenated_bias) - - # Handle added projections for models like SD3, Flux, etc. - if ( - getattr(self, "add_q_proj", None) is not None - and getattr(self, "add_k_proj", None) is not None - and getattr(self, "add_v_proj", None) is not None - ): - concatenated_weights = torch.cat( - [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] - ) - in_features = concatenated_weights.shape[1] - out_features = concatenated_weights.shape[0] - - self.to_added_qkv = nn.Linear( - in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype - ) - self.to_added_qkv.weight.copy_(concatenated_weights) - if self.added_proj_bias: - concatenated_bias = torch.cat( - [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] - ) - self.to_added_qkv.bias.copy_(concatenated_bias) - - self.fused_projections = fuse - self.processor.is_fused = fuse - - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None - ) -> None: - """ - Set whether to use memory efficient attention from `xformers` or not. - - Args: - use_memory_efficient_attention_xformers (`bool`): - Whether to use memory efficient attention from `xformers` or not. - attention_op (`Callable`, *optional*): - The attention operation to use. Defaults to `None` which uses the default attention operation from - `xformers`. - """ - is_custom_diffusion = hasattr(self, "processor") and isinstance( - self.processor, - (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessorSDPA), - ) - is_added_kv_processor = hasattr(self, "processor") and isinstance( - self.processor, - ( - AttnAddedKVProcessor, - AttnAddedKVProcessorSDPA, - SlicedAttnAddedKVProcessor, - XFormersAttnAddedKVProcessor, - ), - ) - is_ip_adapter = hasattr(self, "processor") and isinstance( - self.processor, - (IPAdapterAttnProcessor, IPAdapterAttnProcessorSDPA, IPAdapterXFormersAttnProcessor), - ) - is_joint_processor = hasattr(self, "processor") and isinstance( - self.processor, - ( - JointAttnProcessorSDPA, - XFormersJointAttnProcessor, - ), - ) - - if use_memory_efficient_attention_xformers: - if is_added_kv_processor and is_custom_diffusion: - raise NotImplementedError( - f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}" - ) - if not is_xformers_available(): - raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - dtype = None - if attention_op is not None: - op_fw, op_bw = attention_op - dtype, *_ = op_fw.SUPPORTED_DTYPES - q = torch.randn((1, 2, 40), device="cuda", dtype=dtype) - _ = xformers.ops.memory_efficient_attention(q, q, q) - except Exception as e: - raise e - - if is_custom_diffusion: - processor = CustomDiffusionXFormersAttnProcessor( - train_kv=self.processor.train_kv, - train_q_out=self.processor.train_q_out, - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - attention_op=attention_op, - ) - processor.load_state_dict(self.processor.state_dict()) - if hasattr(self.processor, "to_k_custom_diffusion"): - processor.to(self.processor.to_k_custom_diffusion.weight.device) - elif is_added_kv_processor: - # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP - # which uses this type of cross attention ONLY because the attention mask of format - # [0, ..., -10.000, ..., 0, ...,] is not supported - # throw warning - logger.info( - "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." - ) - processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) - elif is_ip_adapter: - processor = IPAdapterXFormersAttnProcessor( - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - num_tokens=self.processor.num_tokens, - scale=self.processor.scale, - attention_op=attention_op, - ) - processor.load_state_dict(self.processor.state_dict()) - if hasattr(self.processor, "to_k_ip"): - processor.to( - device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype - ) - elif is_joint_processor: - processor = XFormersJointAttnProcessor(attention_op=attention_op) - else: - processor = XFormersAttnProcessor(attention_op=attention_op) - else: - if is_custom_diffusion: - attn_processor_class = ( - CustomDiffusionAttnProcessorSDPA - if hasattr(F, "scaled_dot_product_attention") - else CustomDiffusionAttnProcessor - ) - processor = attn_processor_class( - train_kv=self.processor.train_kv, - train_q_out=self.processor.train_q_out, - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - ) - processor.load_state_dict(self.processor.state_dict()) - if hasattr(self.processor, "to_k_custom_diffusion"): - processor.to(self.processor.to_k_custom_diffusion.weight.device) - elif is_ip_adapter: - processor = IPAdapterAttnProcessorSDPA( - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - num_tokens=self.processor.num_tokens, - scale=self.processor.scale, - ) - processor.load_state_dict(self.processor.state_dict()) - if hasattr(self.processor, "to_k_ip"): - processor.to( - device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype - ) - else: - # set attention processor - # We use the AttnProcessorSDPA by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - processor = ( - AttnProcessorSDPA() - if hasattr(F, "scaled_dot_product_attention") and self.scale_qk - else AttnProcessor() - ) - - self.set_processor(processor) - - def set_attention_slice(self, slice_size: int) -> None: - """ - Set the slice size for attention computation. - - Args: - slice_size (`int`): - The slice size for attention computation. - """ - if slice_size is not None and slice_size > self.sliceable_head_dim: - raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") - - if slice_size is not None and self.added_kv_proj_dim is not None: - processor = SlicedAttnAddedKVProcessor(slice_size) - elif slice_size is not None: - processor = SlicedAttnProcessor(slice_size) - elif self.added_kv_proj_dim is not None: - processor = AttnAddedKVProcessor() - else: - # set attention processor - # We use the AttnProcessorSDPA by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - processor = ( - AttnProcessorSDPA() - if hasattr(F, "scaled_dot_product_attention") and self.scale_qk - else AttnProcessor() - ) - - self.set_processor(processor) - - def set_processor(self, processor: "AttnProcessor") -> None: - """ - Set the attention processor to use. - - Args: - processor (`AttnProcessor`): - The attention processor to use. - """ - # if current processor is in `self._modules` and if passed `processor` is not, we need to - # pop `processor` from `self._modules` - if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) - ): - logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") - self._modules.pop("processor") - - self.processor = processor - - def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": - """ - Get the attention processor in use. - - Args: - return_deprecated_lora (`bool`, *optional*, defaults to `False`): - Set to `True` to return the deprecated LoRA attention processor. - - Returns: - "AttentionProcessor": The attention processor in use. - """ - if not return_deprecated_lora: - return self.processor - - def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: - """ - Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. - - Args: - tensor (`torch.Tensor`): The tensor to reshape. - - Returns: - `torch.Tensor`: The reshaped tensor. - """ - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: - """ - Reshape the tensor for multi-head attention processing. - - Args: - tensor (`torch.Tensor`): The tensor to reshape. - out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. - - Returns: - `torch.Tensor`: The reshaped tensor. - """ - head_size = self.heads - if tensor.ndim == 3: - batch_size, seq_len, dim = tensor.shape - extra_dim = 1 - else: - batch_size, extra_dim, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3) - - if out_dim == 3: - tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) - - return tensor - - def get_attention_scores( - self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """ - Compute the attention scores. - - Args: - query (`torch.Tensor`): The query tensor. - key (`torch.Tensor`): The key tensor. - attention_mask (`torch.Tensor`, *optional*): The attention mask to use. - - Returns: - `torch.Tensor`: The attention probabilities/scores. - """ - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - if attention_mask is None: - baddbmm_input = torch.empty( - query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device - ) - beta = 0 - else: - baddbmm_input = attention_mask - beta = 1 - - attention_scores = torch.baddbmm( - baddbmm_input, - query, - key.transpose(-1, -2), - beta=beta, - alpha=self.scale, - ) - del baddbmm_input - - if self.upcast_softmax: - attention_scores = attention_scores.float() - - attention_probs = attention_scores.softmax(dim=-1) - del attention_scores - - attention_probs = attention_probs.to(dtype) - - return attention_probs - - def prepare_attention_mask( - self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 - ) -> torch.Tensor: - """ - Prepare the attention mask for the attention computation. - - Args: - attention_mask (`torch.Tensor`): The attention mask to prepare. - target_length (`int`): The target length of the attention mask. - batch_size (`int`): The batch size for repeating the attention mask. - out_dim (`int`, *optional*, defaults to `3`): Output dimension. - - Returns: - `torch.Tensor`: The prepared attention mask. - """ - head_size = self.heads - if attention_mask is None: - return attention_mask - - current_length: int = attention_mask.shape[-1] - if current_length != target_length: - if attention_mask.device.type == "mps": - # HACK: MPS: Does not support padding by greater than dimension of input tensor. - # Instead, we can manually construct the padding tensor. - padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) - padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) - attention_mask = torch.cat([attention_mask, padding], dim=2) - else: - # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: - # we want to instead pad by (0, remaining_length), where remaining_length is: - # remaining_length: int = target_length - current_length - # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) - elif out_dim == 4: - attention_mask = attention_mask.unsqueeze(1) - attention_mask = attention_mask.repeat_interleave(head_size, dim=1) - - return attention_mask - - def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: - """ - Normalize the encoder hidden states. - - Args: - encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. - - Returns: - `torch.Tensor`: The normalized encoder hidden states. - """ - assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" - if isinstance(self.norm_cross, nn.LayerNorm): - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - elif isinstance(self.norm_cross, nn.GroupNorm): - # Group norm norms along the channels dimension and expects - # input to be in the shape of (N, C, *). In this case, we want - # to norm along the hidden dimension, so we need to move - # (batch_size, sequence_length, hidden_size) -> - # (batch_size, hidden_size, sequence_length) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - else: - assert False - - return encoder_hidden_states - - -class AttnProcessorSDPA: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: "Attention", - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -@maybe_allow_in_graph class Attention(nn.Module, AttentionModuleMixin): default_processor_class = AttnProcessorSDPA _available_processors = [] @@ -893,11 +303,7 @@ class Attention(nn.Module, AttentionModuleMixin): # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 if processor is None: - processor = ( - AttnProcessorSDPA() - if hasattr(F, "scaled_dot_product_attention") and self.scale_qk - else AttnProcessor() - ) + processor = self.default_processor_class() self.set_processor(processor) def forward( @@ -947,97 +353,99 @@ class Attention(nn.Module, AttentionModuleMixin): ) -class SanaMultiscaleAttentionProjection(nn.Module): - def __init__( +class AttnProcessorSDPA: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( self, - in_channels: int, - num_attention_heads: int, - kernel_size: int, - ) -> None: - super().__init__() + attn: "Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) - channels = 3 * in_channels - self.proj_in = nn.Conv2d( - channels, - channels, - kernel_size, - padding=kernel_size // 2, - groups=channels, - bias=False, + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) - self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.proj_in(hidden_states) - hidden_states = self.proj_out(hidden_states) - return hidden_states + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) -class SanaMultiscaleLinearAttention(nn.Module): - r"""Lightweight multi-scale linear attention""" + query = attn.to_q(hidden_states) - def __init__( - self, - in_channels: int, - out_channels: int, - num_attention_heads: Optional[int] = None, - attention_head_dim: int = 8, - mult: float = 1.0, - norm_type: str = "batch_norm", - kernel_sizes: Tuple[int, ...] = (5,), - eps: float = 1e-15, - residual_connection: bool = False, - ): - super().__init__() + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - # To prevent circular import - from .normalization import get_normalization + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) - self.eps = eps - self.attention_head_dim = attention_head_dim - self.norm_type = norm_type - self.residual_connection = residual_connection + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads - num_attention_heads = ( - int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - inner_dim = num_attention_heads * attention_head_dim - self.to_q = nn.Linear(in_channels, inner_dim, bias=False) - self.to_k = nn.Linear(in_channels, inner_dim, bias=False) - self.to_v = nn.Linear(in_channels, inner_dim, bias=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) - self.to_qkv_multiscale = nn.ModuleList() - for kernel_size in kernel_sizes: - self.to_qkv_multiscale.append( - SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size) - ) + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) - self.nonlinearity = nn.ReLU() - self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False) - self.norm_out = get_normalization(norm_type, num_features=out_channels) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - self.processor = SanaMultiscaleAttnProcessorSDPA() + if attn.residual_connection: + hidden_states = hidden_states + residual - def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1) # Adds padding - scores = torch.matmul(value, key.transpose(-1, -2)) - hidden_states = torch.matmul(scores, query) + hidden_states = hidden_states / attn.rescale_output_factor - hidden_states = hidden_states.to(dtype=torch.float32) - hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps) return hidden_states - def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - scores = torch.matmul(key.transpose(-1, -2), query) - scores = scores.to(dtype=torch.float32) - scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps) - hidden_states = torch.matmul(value, scores.to(value.dtype)) - return hidden_states - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return self.processor(self, hidden_states) - class CustomDiffusionAttnProcessor(nn.Module): r""" @@ -5304,98 +4712,104 @@ class StableAudioAttnProcessor2_0: def __new__(self, *args, **kwargs): deprecation_message = "`StableAudioAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `StableAudioAttnProcessorSDPA`" deprecate("StableAudioAttnProcessor2_0", "1.0.0", deprecation_message) + return StableAudioAttnProcessorSDPA(*args, **kwargs) -class HunyuanAttnProcessor2_0(HunyuanAttnProcessorSDPA): +class HunyuanAttnProcessor2_0: def __new__(cls, *args, **kwargs): deprecation_message = "`HunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `HunyuanAttnProcessorSDPA`" deprecate("HunyuanAttnProcessor2_0", "1.0.0", deprecation_message) + return HunyuanAttnProcessorSDPA(*args, **kwargs) -class FusedHunyuanAttnProcessor2_0(FusedHunyuanAttnProcessorSDPA): - def __init__(self, *args, **kwargs): +class FusedHunyuanAttnProcessor2_0: + def __new__(cls, *args, **kwargs): deprecation_message = "`FusedHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedHunyuanAttnProcessorSDPA`" deprecate("FusedHunyuanAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + + return HunyuanAttnProcessorSDPA(*args, **kwargs) -class PAGHunyuanAttnProcessor2_0(PAGHunyuanAttnProcessorSDPA): - def __init__(self, *args, **kwargs): +class PAGHunyuanAttnProcessor2_0: + def __new__(cls, *args, **kwargs): deprecation_message = "`PAGHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGHunyuanAttnProcessorSDPA`" deprecate("PAGHunyuanAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + + return PAGHunyuanAttnProcessorSDPA(*args, **kwargs) -class PAGCFGHunyuanAttnProcessor2_0(PAGCFGHunyuanAttnProcessorSDPA): - def __init__(self, *args, **kwargs): +class PAGCFGHunyuanAttnProcessor2_0: + def __new__(cls, *args, **kwargs): deprecation_message = "`PAGCFGHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGHunyuanAttnProcessorSDPA`" deprecate("PAGCFGHunyuanAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + + return PAGCFGHunyuanAttnProcessorSDPA(*args, **kwargs) -class LuminaAttnProcessor2_0(LuminaAttnProcessorSDPA): - def __init__(self, *args, **kwargs): +class LuminaAttnProcessor2_0: + def __new__(cls, *args, **kwargs): deprecation_message = "`LuminaAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `LuminaAttnProcessorSDPA`" deprecate("LuminaAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + + return LuminaAttnProcessorSDPA(*args, **kwargs) -class FusedAttnProcessor2_0(FusedAttnProcessorSDPA): - def __init__(self, *args, **kwargs): - deprecation_message = "`FusedAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedAttnProcessorSDPA`" - deprecate("FusedAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) - - -class PAGIdentitySelfAttnProcessor2_0(PAGIdentitySelfAttnProcessorSDPA): - def __init__(self, *args, **kwargs): +class PAGIdentitySelfAttnProcessor2_0: + def __new__(cls, *args, **kwargs): deprecation_message = "`PAGIdentitySelfAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySelfAttnProcessorSDPA`" deprecate("PAGIdentitySelfAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + + return PAGIdentitySelfAttnProcessorSDPA(*args, **kwargs) -class PAGCFGIdentitySelfAttnProcessor2_0(PAGCFGIdentitySelfAttnProcessorSDPA): - def __init__(self, *args, **kwargs): +class PAGCFGIdentitySelfAttnProcessor2_0: + def __new__(cls, *args, **kwargs): deprecation_message = "`PAGCFGIdentitySelfAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGIdentitySelfAttnProcessorSDPA`" deprecate("PAGCFGIdentitySelfAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + + return PAGCFGIdentitySelfAttnProcessorSDPA(*args, **kwargs) -class SanaMultiscaleAttnProcessor2_0(SanaMultiscaleAttnProcessorSDPA): - def __init__(self, *args, **kwargs): +class SanaMultiscaleAttnProcessor2_0: + def __new__(cls, *args, **kwargs): deprecation_message = "`SanaMultiscaleAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaMultiscaleAttnProcessorSDPA`" deprecate("SanaMultiscaleAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + + return SanaMultiscaleAttnProcessorSDPA(*args, **kwargs) -class LoRAAttnProcessor2_0(LoRAAttnProcessorSDPA): - def __init__(self, *args, **kwargs): +class LoRAAttnProcessor2_0: + def __new__(cls, *args, **kwargs): deprecation_message = "`LoRAAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `LoRAAttnProcessorSDPA`" deprecate("LoRAAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + + return LoRAAttnProcessorSDPA(*args, **kwargs) -class SanaLinearAttnProcessor2_0(SanaLinearAttnProcessorSDPA): - def __init__(self, *args, **kwargs): +class SanaLinearAttnProcessor2_0: + def __new__(cls, *args, **kwargs): deprecation_message = "`SanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaLinearAttnProcessorSDPA`" deprecate("SanaLinearAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + + return SanaLinearAttnProcessorSDPA(*args, **kwargs) -class PAGCFGSanaLinearAttnProcessor2_0(PAGCFGSanaLinearAttnProcessorSDPA): - def __init__(self, *args, **kwargs): +class PAGCFGSanaLinearAttnProcessor2_0: + def __new__(cls, *args, **kwargs): deprecation_message = "`PAGCFGSanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGSanaLinearAttnProcessorSDPA`" deprecate("PAGCFGSanaLinearAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + + return PAGCFGSanaLinearAttnProcessorSDPA(*args, **kwargs) -class PAGIdentitySanaLinearAttnProcessor2_0(PAGIdentitySanaLinearAttnProcessorSDPA): - def __init__(self, *args, **kwargs): +class PAGIdentitySanaLinearAttnProcessor2_0: + def __new__(cls, *args, **kwargs): deprecation_message = "`PAGIdentitySanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySanaLinearAttnProcessorSDPA`" deprecate("PAGIdentitySanaLinearAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + + return PAGIdentitySanaLinearAttnProcessorSDPA(*args, **kwargs) class IPAdapterAttnProcessor(IPAdapterAttnProcessorSDPA): @@ -5405,11 +4819,12 @@ class IPAdapterAttnProcessor(IPAdapterAttnProcessorSDPA): super().__init__(*args, **kwargs) -class IPAdapterAttnProcessor2_0(IPAdapterAttnProcessorSDPA): - def __init__(self, *args, **kwargs) -> None: +class IPAdapterAttnProcessor2_0: + def __new__(cls, *args, **kwargs): deprecation_message = "`IPAdapterAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `IPAdapterAttnProcessorSDPA`" deprecate("IPAdapterAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + + return IPAdapterAttnProcessorSDPA(*args, **kwargs) ADDED_KV_ATTENTION_PROCESSORS = ( diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py index 9146aa5c7c..9c1f76ba41 100644 --- a/src/diffusers/models/autoencoders/autoencoder_dc.py +++ b/src/diffusers/models/autoencoders/autoencoder_dc.py @@ -62,6 +62,98 @@ class ResBlock(nn.Module): return hidden_states + residual +class SanaMultiscaleAttentionProjection(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: int, + kernel_size: int, + ) -> None: + super().__init__() + + channels = 3 * in_channels + self.proj_in = nn.Conv2d( + channels, + channels, + kernel_size, + padding=kernel_size // 2, + groups=channels, + bias=False, + ) + self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_out(hidden_states) + return hidden_states + + +class SanaMultiscaleLinearAttention(nn.Module): + r"""Lightweight multi-scale linear attention""" + + def __init__( + self, + in_channels: int, + out_channels: int, + num_attention_heads: Optional[int] = None, + attention_head_dim: int = 8, + mult: float = 1.0, + norm_type: str = "batch_norm", + kernel_sizes: Tuple[int, ...] = (5,), + eps: float = 1e-15, + residual_connection: bool = False, + ): + super().__init__() + + # To prevent circular import + from ..normalization import get_normalization + + self.eps = eps + self.attention_head_dim = attention_head_dim + self.norm_type = norm_type + self.residual_connection = residual_connection + + num_attention_heads = ( + int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads + ) + inner_dim = num_attention_heads * attention_head_dim + + self.to_q = nn.Linear(in_channels, inner_dim, bias=False) + self.to_k = nn.Linear(in_channels, inner_dim, bias=False) + self.to_v = nn.Linear(in_channels, inner_dim, bias=False) + + self.to_qkv_multiscale = nn.ModuleList() + for kernel_size in kernel_sizes: + self.to_qkv_multiscale.append( + SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size) + ) + + self.nonlinearity = nn.ReLU() + self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False) + self.norm_out = get_normalization(norm_type, num_features=out_channels) + + self.processor = SanaMultiscaleAttnProcessorSDPA() + + def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1) # Adds padding + scores = torch.matmul(value, key.transpose(-1, -2)) + hidden_states = torch.matmul(scores, query) + + hidden_states = hidden_states.to(dtype=torch.float32) + hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps) + return hidden_states + + def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + scores = torch.matmul(key.transpose(-1, -2), query) + scores = scores.to(dtype=torch.float32) + scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps) + hidden_states = torch.matmul(value, scores.to(value.dtype)) + return hidden_states + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.processor(self, hidden_states) + + class EfficientViTBlock(nn.Module): def __init__( self, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index a76277366c..7da90205b1 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -21,7 +21,8 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils.accelerate_utils import apply_forward_hook -from ..attention_processor import Attention, SpatialNorm +from ..attention import Attention +from ..attention_processor import SpatialNorm from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution from ..downsampling import Downsample2D from ..modeling_outputs import AutoencoderKLOutput diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index a32f4bfd76..4086095440 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -24,7 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation -from ..attention_processor import Attention +from ..attention import Attention from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from .vae import DecoderOutput, DiagonalGaussianDistribution diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py index edf270f66e..1e5eb7805d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -23,7 +23,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation -from ..attention_processor import Attention, MochiVaeAttnProcessor2_0 +from ..attention import Attention +from ..attention_processor import MochiVaeAttnProcessor2_0 from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py index adc1716069..d050c5cf2d 100644 --- a/src/diffusers/models/controlnets/controlnet_sd3.py +++ b/src/diffusers/models/controlnets/controlnet_sd3.py @@ -22,7 +22,8 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 +from ..attention import Attention +from ..attention_processor import AttentionProcessor, FusedJointAttnProcessor2_0 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b1e14ca6a7..9324b85bd5 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -21,7 +21,7 @@ from torch import nn from ..utils import deprecate from .activations import FP32SiLU, get_activation -from .attention_processor import Attention +from .attention import Attention def get_timestep_embedding( diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 8781424c61..e78f77ebda 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -23,11 +23,9 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import Attention, AttentionMixin from ..attention_processor import ( - Attention, - AttentionProcessor, AuraFlowAttnProcessor2_0, - FusedAuraFlowAttnProcessor2_0, ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput @@ -267,7 +265,7 @@ class AuraFlowJointTransformerBlock(nn.Module): return encoder_hidden_states, hidden_states -class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin): r""" A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). @@ -357,105 +355,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedAuraFlowAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) + # Using methods from AttentionMixin def forward( self, diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 5e0965572d..0e7da957c5 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -22,12 +22,9 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention +from ..attention import AttentionMixin from ..attention_processor import ( AttentionModuleMixin, - AttentionProcessor, - CogVideoXAttnProcessor2_0, - FusedCogVideoXAttnProcessor2_0, ) from ..cache_utils import CacheMixin from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps @@ -103,7 +100,7 @@ class BaseCogVideoXAttnProcessor: def __call__( self, - attn: Attention, + attn: CogVideoXAttention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -260,7 +257,7 @@ class CogVideoXBlock(nn.Module): # 1. Self Attention self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - self.attn1 = Attention( + self.attn1 = CogVideoXAttention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, @@ -268,7 +265,6 @@ class CogVideoXBlock(nn.Module): eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, - processor=CogVideoXAttnProcessor2_0(), ) # 2. Feed Forward @@ -325,7 +321,7 @@ class CogVideoXBlock(nn.Module): return hidden_states, encoder_hidden_states -class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin, AttentionMixin): """ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). @@ -499,105 +495,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} + # Using inherited methods from AttentionMixin - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) + # Using inherited methods from AttentionMixin def forward( self, diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py index 9e5f5e5db3..8fdad47838 100644 --- a/src/diffusers/models/transformers/consisid_transformer_3d.py +++ b/src/diffusers/models/transformers/consisid_transformer_3d.py @@ -22,8 +22,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention -from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0 +from ..attention import Attention, AttentionMixin +from ..attention_processor import CogVideoXAttnProcessor2_0 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -349,7 +349,7 @@ class ConsisIDBlock(nn.Module): return hidden_states, encoder_hidden_states -class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, AttentionMixin): """ A Transformer model for video-like data in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID). @@ -621,65 +621,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ] ) - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) + # Using methods from AttentionMixin def forward( self, diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index e4f2a4e8c8..844eb26fa4 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -19,16 +19,17 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging -from .modeling_common BasicTransformerBlock +from ..attention import AttentionMixin from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin +from .modeling_common import BasicTransformerBlock logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class DiTTransformer2DModel(ModelMixin, ConfigMixin): +class DiTTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): r""" A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748). diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index dc40ebcac2..78e929b485 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Union +from typing import Optional import torch from torch import nn @@ -19,7 +19,8 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0 +from ..attention import Attention, AttentionMixin +from ..attention_processor import HunyuanAttnProcessor2_0 from ..embeddings import ( HunyuanCombinedTimestepTextSizeStyleEmbedding, PatchEmbed, @@ -200,7 +201,7 @@ class HunyuanDiTBlock(nn.Module): return hidden_states -class HunyuanDiT2DModel(ModelMixin, ConfigMixin): +class HunyuanDiT2DModel(ModelMixin, ConfigMixin, AttentionMixin): """ HunYuanDiT: Diffusion model with a Transformer backbone. @@ -318,105 +319,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin): self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedHunyuanAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) + # Using methods from AttentionMixin def set_default_attn_processor(self): """ diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 102a27b472..f2f610054e 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -19,15 +19,16 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid -from .modeling_common BasicTransformerBlock +from ..attention import AttentionMixin from ..cache_utils import CacheMixin from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle +from .modeling_common import BasicTransformerBlock -class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): +class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin, AttentionMixin): _supports_gradient_checkpointing = True """ diff --git a/src/diffusers/models/transformers/modeling_common.py b/src/diffusers/models/transformers/modeling_common.py index 93b11c2b43..a2ab97769e 100644 --- a/src/diffusers/models/transformers/modeling_common.py +++ b/src/diffusers/models/transformers/modeling_common.py @@ -17,12 +17,12 @@ import torch import torch.nn.functional as F from torch import nn -from ..utils import deprecate, logging -from ..utils.torch_utils import maybe_allow_in_graph -from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU -from .attention_processor import Attention, JointAttnProcessor2_0 -from .embeddings import SinusoidalPositionalEmbedding -from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX +from ...utils import deprecate, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU +from ..attention_processor import Attention, JointAttnProcessor2_0 +from ..embeddings import SinusoidalPositionalEmbedding +from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX logger = logging.get_logger(__name__) diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 50a0e3e67d..741697147e 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -11,25 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional import torch from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging -from .modeling_common BasicTransformerBlock -from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0 +from ..attention import AttentionMixin +from ..attention_processor import AttnProcessor from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle +from .modeling_common import BasicTransformerBlock logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class PixArtTransformer2DModel(ModelMixin, ConfigMixin): +class PixArtTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): r""" A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426, https://arxiv.org/abs/2403.04692). @@ -184,65 +185,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): in_features=self.config.caption_channels, hidden_size=self.inner_dim ) - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) + # Using inherited method from AttentionMixin def set_default_attn_processor(self): """ @@ -252,45 +195,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): """ self.set_attn_processor(AttnProcessor()) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) + # Using inherited methods from AttentionMixin def forward( self, diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py index c0b00da671..7ef83f2f30 100644 --- a/src/diffusers/models/transformers/prior_transformer.py +++ b/src/diffusers/models/transformers/prior_transformer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, Optional, Union +from typing import Optional, Union import torch import torch.nn.functional as F @@ -8,16 +8,16 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin from ...utils import BaseOutput -from .modeling_common BasicTransformerBlock +from ..attention import AttentionMixin from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, - AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin +from .modeling_common import BasicTransformerBlock @dataclass @@ -33,7 +33,7 @@ class PriorTransformerOutput(BaseOutput): predicted_image_embedding: torch.Tensor -class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): +class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin, AttentionMixin): """ A Prior Transformer model. @@ -166,65 +166,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim)) self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) + # Using inherited methods from AttentionMixin # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index f844130fab..eb93f0f521 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -21,10 +21,9 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import Attention, AttentionMixin from ..attention_processor import ( - Attention, AttentionModuleMixin, - AttentionProcessor, SanaLinearAttnProcessor2_0, ) from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps @@ -388,7 +387,7 @@ class SanaTransformerBlock(nn.Module): return hidden_states -class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin): r""" A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models. @@ -513,65 +512,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) + # Using methods from AttentionMixin def forward( self, diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py index d81b6447ad..c37a972dc9 100644 --- a/src/diffusers/models/transformers/stable_audio_transformer.py +++ b/src/diffusers/models/transformers/stable_audio_transformer.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Dict, Optional, Union +from typing import Optional, Union import numpy as np import torch @@ -21,10 +21,8 @@ import torch.nn as nn import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, register_to_config -from ...models.attention import FeedForward +from ...models.attention import Attention, AttentionMixin, FeedForward from ...models.attention_processor import ( - Attention, - AttentionProcessor, StableAudioAttnProcessor2_0, ) from ...models.modeling_utils import ModelMixin @@ -187,7 +185,7 @@ class StableAudioDiTBlock(nn.Module): return hidden_states -class StableAudioDiTModel(ModelMixin, ConfigMixin): +class StableAudioDiTModel(ModelMixin, ConfigMixin, AttentionMixin): """ The Diffusion Transformer model introduced in Stable Audio. @@ -279,65 +277,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin): self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) + # Using methods from AttentionMixin # Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio def set_default_attn_processor(self): diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index bef841d8ce..52a9f3b75a 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -19,11 +19,12 @@ from torch import nn from ...configuration_utils import LegacyConfigMixin, register_to_config from ...utils import deprecate, logging -from .modeling_common BasicTransformerBlock +from ..attention import AttentionMixin from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import LegacyModelMixin from ..normalization import AdaLayerNormSingle +from .modeling_common import BasicTransformerBlock logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -36,7 +37,7 @@ class Transformer2DModelOutput(Transformer2DModelOutput): super().__init__(*args, **kwargs) -class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): +class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin, AttentionMixin): """ A 2D Transformer model for image-like data. diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index da7133791f..02af3304fb 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -13,16 +13,14 @@ # limitations under the License. -from typing import Dict, Union +from typing import Union import torch import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...models.attention import FeedForward +from ...models.attention import Attention, AttentionMixin, FeedForward from ...models.attention_processor import ( - Attention, - AttentionProcessor, CogVideoXAttnProcessor2_0, ) from ...models.modeling_utils import ModelMixin @@ -130,7 +128,7 @@ class CogView3PlusTransformerBlock(nn.Module): return hidden_states, encoder_hidden_states -class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): +class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): r""" The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121). @@ -229,65 +227,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) + # Using methods from AttentionMixin def forward( self, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 7abd0d6d10..8111686349 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -24,12 +24,13 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...models.attention import FeedForward -from ...models.attention_processor import AttentionModuleMixin, AttentionProcessor +from ...models.attention_processor import AttentionModuleMixin from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_torch_xla_version from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionMixin from ..cache_utils import CacheMixin from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -592,7 +593,7 @@ class FluxTransformerBlock(nn.Module): class FluxTransformer2DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin, AttentionMixin ): """ The Transformer model introduced in Flux. @@ -687,97 +688,9 @@ class FluxTransformer2DModel( self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} + # Using inherited methods from AttentionMixin - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, AttentionModuleMixin): - module.fuse_projections(fuse=True) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) + # Using inherited methods from AttentionMixin def forward( self, diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 39ffb038ce..8b4149ac4d 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -23,7 +23,7 @@ from diffusers.loaders import FromOriginalModelMixin from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention_processor import Attention, AttentionProcessor +from ..attention import Attention, AttentionMixin from ..cache_utils import CacheMixin from ..embeddings import ( CombinedTimestepTextProjEmbeddings, @@ -819,7 +819,7 @@ class HunyuanVideoTokenReplaceTransformerBlock(nn.Module): return hidden_states, encoder_hidden_states -class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin): r""" A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). @@ -962,65 +962,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) + # Using methods from AttentionMixin def forward( self, diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 3c2d6904bd..e0a06597f7 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -20,15 +20,13 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin from ...models.attention import FeedForward, JointTransformerBlock from ...models.attention_processor import ( - Attention, AttentionModuleMixin, - AttentionProcessor, - FusedJointAttnProcessor2_0, ) from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionMixin from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -280,7 +278,7 @@ class SD3SingleTransformerBlock(nn.Module): class SD3Transformer2DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin, AttentionMixin ): """ The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206). @@ -416,105 +414,9 @@ class SD3Transformer2DModel( for module in self.children(): fn_recursive_feed_forward(module, None, 0) - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} + # Using inherited methods from AttentionMixin - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedJointAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) + # Using inherited methods from AttentionMixin def forward( self, diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py index b69b04fdcb..1288c904f6 100644 --- a/src/diffusers/models/transformers/transformer_temporal.py +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -19,10 +19,11 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput -from .modeling_common BasicTransformerBlock, TemporalBasicTransformerBlock +from ..attention import AttentionMixin from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin from ..resnet import AlphaBlender +from .modeling_common import BasicTransformerBlock, TemporalBasicTransformerBlock @dataclass @@ -38,7 +39,7 @@ class TransformerTemporalModelOutput(BaseOutput): sample: torch.Tensor -class TransformerTemporalModel(ModelMixin, ConfigMixin): +class TransformerTemporalModel(ModelMixin, ConfigMixin, AttentionMixin): """ A Transformer model for video-like data. @@ -202,7 +203,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): return TransformerTemporalModelOutput(sample=output) -class TransformerSpatioTemporalModel(nn.Module): +class TransformerSpatioTemporalModel(nn.Module, AttentionMixin): """ A Transformer model for video-like data. diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index e082d524e7..dd1b9eae6e 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -21,7 +21,8 @@ from torch import nn from ...utils import deprecate, logging from ...utils.torch_utils import apply_freeu from ..activations import get_activation -from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from ..attention import Attention +from ..attention_processor import AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from ..normalization import AdaGroupNorm from ..resnet import ( Downsample2D, diff --git a/src/diffusers/models/unets/unet_kandinsky3.py b/src/diffusers/models/unets/unet_kandinsky3.py index 73bf0020b4..ee01ae933e 100644 --- a/src/diffusers/models/unets/unet_kandinsky3.py +++ b/src/diffusers/models/unets/unet_kandinsky3.py @@ -21,7 +21,8 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput, logging -from ..attention_processor import Attention, AttentionProcessor, AttnProcessor +from ..attention import Attention +from ..attention_processor import AttentionProcessor, AttnProcessor from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index f57754435f..b90bd1898a 100644 --- a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -23,7 +23,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin from ...utils import BaseOutput -from ..attention_processor import Attention +from ..attention import Attention from ..modeling_utils import ModelMixin