From 37de8e790ca8d3908b85bd536f292ce06722698e Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 15 Apr 2025 21:32:36 +0530 Subject: [PATCH] update --- src/diffusers/models/__init__.py | 2 + src/diffusers/models/attention_modules.py | 371 ++++++++++++++++++ src/diffusers/models/attention_processor.py | 134 ++++--- .../models/transformers/sana_transformer.py | 99 +++++ .../models/transformers/transformer_flux.py | 245 ++++++++++-- .../models/transformers/transformer_mochi.py | 81 +++- .../models/transformers/transformer_sd3.py | 207 +++++++++- 7 files changed, 1053 insertions(+), 86 deletions(-) create mode 100644 src/diffusers/models/attention_modules.py diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 276b1836a7..83a78c3fb8 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -27,6 +27,7 @@ _import_structure = {} if is_torch_available(): _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["auto_model"] = ["AutoModel"] + _import_structure["attention_modules"] = ["FluxAttention", "SanaAttention", "SD3Attention"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] @@ -107,6 +108,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .adapter import MultiAdapter, T2IAdapter from .auto_model import AutoModel + from .attention_modules import FluxAttention, SanaAttention, SD3Attention from .autoencoders import ( AsymmetricAutoencoderKL, AutoencoderDC, diff --git a/src/diffusers/models/attention_modules.py b/src/diffusers/models/attention_modules.py new file mode 100644 index 0000000000..473781bed3 --- /dev/null +++ b/src/diffusers/models/attention_modules.py @@ -0,0 +1,371 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ..utils import logging +from ..utils.torch_utils import maybe_allow_in_graph +from .attention_processor import ( + AttentionModuleMixin, + AttnProcessorSDPA, + FluxAttnProcessorSDPA, + FusedFluxAttnProcessorSDPA, + JointAttnProcessorSDPA, + FusedJointAttnProcessorSDPA, + SanaLinearAttnProcessorSDPA, +) +from .normalization import RMSNorm, get_normalization + + +logger = logging.get_logger(__name__) + + +@maybe_allow_in_graph +class SanaAttention(nn.Module, AttentionModuleMixin): + """ + Attention implementation specialized for Sana models. + + This module implements lightweight multi-scale linear attention as used in Sana. + + Args: + in_channels (`int`): Number of input channels. + out_channels (`int`): Number of output channels. + num_attention_heads (`int`, *optional*): Number of attention heads. + attention_head_dim (`int`, defaults to 8): Dimension of each attention head. 
+ mult (`float`, defaults to 1.0): Multiplier for inner dimension. + norm_type (`str`, defaults to "batch_norm"): Type of normalization. + kernel_sizes (`Tuple[int, ...]`, defaults to (5,)): Kernel sizes for multi-scale attention. + """ + # Set Sana-specific processor classes + default_processor_class = SanaLinearAttnProcessorSDPA + fused_processor_class = None # Sana doesn't have a fused processor yet + + def __init__( + self, + in_channels: int, + out_channels: int, + num_attention_heads: Optional[int] = None, + attention_head_dim: int = 8, + mult: float = 1.0, + norm_type: str = "batch_norm", + kernel_sizes: Tuple[int, ...] = (5,), + eps: float = 1e-15, + residual_connection: bool = False, + ): + super().__init__() + + # Core parameters + self.eps = eps + self.attention_head_dim = attention_head_dim + self.norm_type = norm_type + self.residual_connection = residual_connection + + # Calculate dimensions + num_attention_heads = ( + int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads + ) + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + self.heads = num_attention_heads + + # Query, key, value projections + self.to_q = nn.Linear(in_channels, inner_dim, bias=False) + self.to_k = nn.Linear(in_channels, inner_dim, bias=False) + self.to_v = nn.Linear(in_channels, inner_dim, bias=False) + + # Multi-scale attention + self.to_qkv_multiscale = nn.ModuleList() + for kernel_size in kernel_sizes: + self.to_qkv_multiscale.append( + SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size) + ) + + # Output layers + self.nonlinearity = nn.ReLU() + self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False) + self.norm_out = get_normalization(norm_type, num_features=out_channels) + + # Set default processor + self.fused_projections = False + self.set_processor(self.default_processor_class()) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Process linear attention for Sana model inputs.""" + return self.processor(self, hidden_states) + + +class SanaMultiscaleAttentionProjection(nn.Module): + """Projection layer for Sana multi-scale attention.""" + + def __init__( + self, + in_channels: int, + num_attention_heads: int, + kernel_size: int, + ) -> None: + super().__init__() + + channels = 3 * in_channels + self.proj_in = nn.Conv2d( + channels, + channels, + kernel_size, + padding=kernel_size // 2, + groups=channels, + bias=False, + ) + self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_out(hidden_states) + return hidden_states + + +@maybe_allow_in_graph +class FluxAttention(nn.Module, AttentionModuleMixin): + """ + Attention implementation specialized for Flux models. + + This module uses RMSNorm for query and key normalization and supports + rotary embeddings through its processor. + + Args: + query_dim (`int`): Number of channels in query. + cross_attention_dim (`int`, *optional*): Number of channels in encoder states. + heads (`int`, defaults to 8): Number of attention heads. + dim_head (`int`, defaults to 64): Dimension of each attention head. + dropout (`float`, defaults to 0.0): Dropout probability. 
+ bias (`bool`, defaults to False): Whether to use bias in linear projections. + added_kv_proj_dim (`int`, *optional*): Dimension for added key/value projections. + """ + # Set Flux-specific processor classes + default_processor_class = FluxAttnProcessorSDPA + fused_processor_class = FusedFluxAttnProcessorSDPA + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + ): + super().__init__() + + # Core parameters + self.inner_dim = dim_head * heads + self.query_dim = query_dim + self.heads = heads + self.scale = dim_head ** -0.5 + self.use_bias = bias + self.scale_qk = True # Flux always uses scale_qk + + # Cross-attention setup + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + # Projections + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + + # Flux-specific normalization + self.norm_q = RMSNorm(dim_head, eps=1e-6) + self.norm_k = RMSNorm(dim_head, eps=1e-6) + + # Added projections for cross-attention + self.added_kv_proj_dim = added_kv_proj_dim + if added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + + # Normalization for added projections + self.norm_added_q = RMSNorm(dim_head, eps=1e-6) + self.norm_added_k = RMSNorm(dim_head, eps=1e-6) + self.added_proj_bias = bias + + # Output projection + self.to_out = nn.ModuleList([ + nn.Linear(self.inner_dim, query_dim, bias=bias), + nn.Dropout(dropout) + ]) + + # For cross-attention with added projections + if added_kv_proj_dim is not None: + self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias) + else: + self.to_add_out = None + + # Set default processor and fusion state + self.fused_projections = False + self.set_processor(self.default_processor_class()) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Process attention for Flux model inputs.""" + # Filter parameters to only those expected by the processor + processor_params = inspect.signature(self.processor.__call__).parameters.keys() + quiet_params = {"ip_adapter_masks", "ip_hidden_states"} + + # Check for unexpected parameters + unexpected_params = [ + k for k, _ in kwargs.items() + if k not in processor_params and k not in quiet_params + ] + if unexpected_params: + logger.warning( + f"Parameters {unexpected_params} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + + # Filter to only expected parameters + filtered_kwargs = {k: v for k, v in kwargs.items() if k in processor_params} + + # Process with appropriate processor + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **filtered_kwargs, + ) + + +@maybe_allow_in_graph +class SD3Attention(nn.Module, AttentionModuleMixin): + """ + Attention implementation specialized for SD3 models. 
+ + This module implements the joint attention mechanism used in SD3, + with native support for context pre-processing. + + Args: + query_dim (`int`): Number of channels in query. + cross_attention_dim (`int`, *optional*): Number of channels in encoder states. + heads (`int`, defaults to 8): Number of attention heads. + dim_head (`int`, defaults to 64): Dimension of each attention head. + dropout (`float`, defaults to 0.0): Dropout probability. + bias (`bool`, defaults to False): Whether to use bias in linear projections. + added_kv_proj_dim (`int`, *optional*): Dimension for added key/value projections. + """ + # Set SD3-specific processor classes + default_processor_class = JointAttnProcessorSDPA + fused_processor_class = FusedJointAttnProcessorSDPA + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + context_pre_only: bool = False, + ): + super().__init__() + + # Core parameters + self.inner_dim = dim_head * heads + self.query_dim = query_dim + self.heads = heads + self.scale = dim_head ** -0.5 + self.use_bias = bias + self.scale_qk = True + self.context_pre_only = context_pre_only + + # Cross-attention setup + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + # Projections for self-attention + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) + + # Added projections for context processing + self.added_kv_proj_dim = added_kv_proj_dim + if added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + self.added_proj_bias = bias + + # Output projection + self.to_out = nn.ModuleList([ + nn.Linear(self.inner_dim, query_dim, bias=bias), + nn.Dropout(dropout) + ]) + + # Context output projection + if added_kv_proj_dim is not None and not context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias) + else: + self.to_add_out = None + + # Set default processor and fusion state + self.fused_projections = False + self.set_processor(self.default_processor_class()) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Process joint attention for SD3 model inputs.""" + # Filter parameters to only those expected by the processor + processor_params = inspect.signature(self.processor.__call__).parameters.keys() + quiet_params = {"ip_adapter_masks", "ip_hidden_states"} + + # Check for unexpected parameters + unexpected_params = [ + k for k, _ in kwargs.items() + if k not in processor_params and k not in quiet_params + ] + if unexpected_params: + logger.warning( + f"Parameters {unexpected_params} are not expected by {self.processor.__class__.__name__} and will be ignored." 
+ ) + + # Filter to only expected parameters + filtered_kwargs = {k: v for k, v in kwargs.items() if k in processor_params} + + # Process with appropriate processor + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **filtered_kwargs, + ) \ No newline at end of file diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5999032c3d..acf315aac8 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -53,6 +53,10 @@ class AttentionModuleMixin: This mixin adds functionality to set different attention processors, handle attention masks, compute attention scores, and manage projections. """ + + # Default processor classes to be overridden by subclasses + default_processor_class = None + fused_processor_class = None def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: """ @@ -111,6 +115,74 @@ class AttentionModuleMixin: else AttnProcessor() ) self.set_processor(processor) + + @torch.no_grad() + def fuse_projections(self, fuse=True): + """ + Fuse the query, key, and value projections into a single projection for efficiency. + + Args: + fuse (`bool`): Whether to fuse the projections or not. + """ + # Skip if already in desired state + if getattr(self, "fused_projections", False) == fuse: + return + + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not self.is_cross_attention: + # Fuse self-attention projections + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + else: + # Fuse cross-attention key-value projections + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + + # Handle added projections for models like SD3, Flux, etc. 
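+        # (These added projections exist only when the module was built with added_kv_proj_dim, e.g. joint text/image attention.)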
+ if ( + getattr(self, "add_q_proj", None) is not None + and getattr(self, "add_k_proj", None) is not None + and getattr(self, "add_v_proj", None) is not None + ): + concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype + ) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) + + self.fused_projections = fuse + + # Update processor based on fusion state + processor_class = self.fused_processor_class if fuse else self.default_processor_class + if processor_class is not None: + self.set_processor(processor_class()) def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None @@ -480,68 +552,12 @@ class AttentionModuleMixin: return encoder_hidden_states - @torch.no_grad() - def fuse_projections(self, fuse=True): - """ - Fuse the query, key, and value projections into a single projection for efficiency. - - Args: - fuse (`bool`): Whether to fuse the projections or not. - """ - device = self.to_q.weight.data.device - dtype = self.to_q.weight.data.dtype - - if not self.is_cross_attention: - # fetch weight matrices. - concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) - in_features = concatenated_weights.shape[1] - out_features = concatenated_weights.shape[0] - - # create a new single projection layer and copy over the weights. - self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) - self.to_qkv.weight.copy_(concatenated_weights) - if self.use_bias: - concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) - self.to_qkv.bias.copy_(concatenated_bias) - - else: - concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) - in_features = concatenated_weights.shape[1] - out_features = concatenated_weights.shape[0] - - self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) - self.to_kv.weight.copy_(concatenated_weights) - if self.use_bias: - concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) - self.to_kv.bias.copy_(concatenated_bias) - - # handle added projections for SD3 and others. 
- if ( - getattr(self, "add_q_proj", None) is not None - and getattr(self, "add_k_proj", None) is not None - and getattr(self, "add_v_proj", None) is not None - ): - concatenated_weights = torch.cat( - [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] - ) - in_features = concatenated_weights.shape[1] - out_features = concatenated_weights.shape[0] - - self.to_added_qkv = nn.Linear( - in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype - ) - self.to_added_qkv.weight.copy_(concatenated_weights) - if self.added_proj_bias: - concatenated_bias = torch.cat( - [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] - ) - self.to_added_qkv.bias.copy_(concatenated_bias) - - self.fused_projections = fuse - @maybe_allow_in_graph class Attention(nn.Module, AttentionModuleMixin): + # Set default and fused processor classes + default_processor_class = AttnProcessorSDPA + fused_processor_class = None # Will be set appropriately in the future r""" A cross attention layer. diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 52236275dc..ff730e9454 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -24,6 +24,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_ from ..attention_processor import ( Attention, AttentionProcessor, + AttentionModuleMixin, SanaLinearAttnProcessor2_0, ) from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps @@ -35,6 +36,104 @@ from ..normalization import AdaLayerNormSingle, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name +@maybe_allow_in_graph +class SanaAttention(nn.Module, AttentionModuleMixin): + """ + Attention implementation specialized for Sana models. + + This module implements lightweight multi-scale linear attention as used in Sana. + """ + # Set Sana-specific processor classes + default_processor_class = SanaLinearAttnProcessor2_0 + + def __init__( + self, + in_channels: int, + out_channels: int, + num_attention_heads: Optional[int] = None, + attention_head_dim: int = 8, + mult: float = 1.0, + norm_type: str = "batch_norm", + kernel_sizes: Tuple[int, ...] 
= (5,), + eps: float = 1e-15, + residual_connection: bool = False, + ): + super().__init__() + + # Core parameters + self.eps = eps + self.attention_head_dim = attention_head_dim + self.norm_type = norm_type + self.residual_connection = residual_connection + + # Calculate dimensions + num_attention_heads = ( + int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads + ) + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + self.heads = num_attention_heads + + # Query, key, value projections + self.to_q = nn.Linear(in_channels, inner_dim, bias=False) + self.to_k = nn.Linear(in_channels, inner_dim, bias=False) + self.to_v = nn.Linear(in_channels, inner_dim, bias=False) + + # Multi-scale attention + self.to_qkv_multiscale = nn.ModuleList() + for kernel_size in kernel_sizes: + self.to_qkv_multiscale.append( + SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size) + ) + + # Output layers + self.nonlinearity = nn.ReLU() + self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False) + + # Get normalization based on type + if norm_type == "batch_norm": + self.norm_out = nn.BatchNorm1d(out_channels) + elif norm_type == "layer_norm": + self.norm_out = nn.LayerNorm(out_channels) + elif norm_type == "group_norm": + self.norm_out = nn.GroupNorm(32, out_channels) + elif norm_type == "instance_norm": + self.norm_out = nn.InstanceNorm1d(out_channels) + else: + self.norm_out = nn.Identity() + + # Set processor + self.processor = self.default_processor_class() + + +class SanaMultiscaleAttentionProjection(nn.Module): + """Projection layer for Sana multi-scale attention.""" + + def __init__( + self, + in_channels: int, + num_attention_heads: int, + kernel_size: int, + ) -> None: + super().__init__() + + channels = 3 * in_channels + self.proj_in = nn.Conv2d( + channels, + channels, + kernel_size, + padding=kernel_size // 2, + groups=channels, + bias=False, + ) + self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_out(hidden_states) + return hidden_states + + class GLUMBConv(nn.Module): def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 87537890d2..cfdafca63c 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -18,6 +18,7 @@ from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin @@ -25,12 +26,14 @@ from ...models.attention import FeedForward from ...models.attention_processor import ( Attention, AttentionProcessor, + AttentionModuleMixin, FluxAttnProcessor2_0, FluxAttnProcessor2_0_NPU, FusedFluxAttnProcessor2_0, ) from ...models.modeling_utils import ModelMixin -from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm +from ...utils.torch_utils import maybe_allow_in_graph from ...utils import USE_PEFT_BACKEND, deprecate, logging, 
scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph @@ -42,6 +45,216 @@ from ..modeling_outputs import Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class FluxAttnProcessor: + """Flux-specific attention processor that implements normalized attention with support for rotary embeddings.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("FluxAttnProcessor requires PyTorch 2.0, please upgrade PyTorch.") + + def __call__( + self, + attn, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.FloatTensor: + batch_size, seq_len, _ = hidden_states.shape + + # Project query from hidden states + query = attn.to_q(hidden_states) + + # Handle cross-attention vs self-attention + if encoder_hidden_states is None: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # If we have added_kv_proj_dim, handle additional projections + if hasattr(attn, "added_kv_proj_dim") and attn.added_kv_proj_dim is not None: + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_query = attn.add_q_proj(encoder_hidden_states) + + # Reshape + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Apply normalization if available + if hasattr(attn, "norm_added_q") and attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if hasattr(attn, "norm_added_k") and attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + # Reshape for multi-head attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Apply normalization if available + if hasattr(attn, "norm_q") and attn.norm_q is not None: + query = attn.norm_q(query) + if hasattr(attn, "norm_k") and attn.norm_k is not None: + key = attn.norm_k(key) + + # Handle rotary embeddings if provided + if image_rotary_emb is not None: + from ...models.embeddings import apply_rotary_emb + query = apply_rotary_emb(query, image_rotary_emb) + # Only apply to key in self-attention + if encoder_hidden_states is None: + key = apply_rotary_emb(key, image_rotary_emb) + + # Concatenate encoder projections if we have them + if encoder_hidden_states is not None and hasattr(attn, "added_kv_proj_dim") and attn.added_kv_proj_dim is not None: + # Concatenate for joint attention + query = torch.cat([encoder_query, query], dim=2) + key = torch.cat([encoder_key, key], dim=2) + value = torch.cat([encoder_value, value], dim=2) + + # Compute attention + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + # Reshape back + 
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split back if we did joint attention + if ( + encoder_hidden_states is not None + and hasattr(attn, "added_kv_proj_dim") + and attn.added_kv_proj_dim is not None + and hasattr(attn, "to_add_out") + and attn.to_add_out is not None + ): + context_len = encoder_hidden_states.shape[1] + encoder_hidden_states, hidden_states = ( + hidden_states[:, :context_len], + hidden_states[:, context_len:], + ) + + # Project output + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + # Project output + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +@maybe_allow_in_graph +class FluxAttention(nn.Module, AttentionModuleMixin): + """ + Specialized attention implementation for Flux models. + + This attention module provides optimized implementation for Flux models, + with support for RMSNorm, rotary embeddings, and added key/value projections. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + ): + super().__init__() + + # Core parameters + self.inner_dim = dim_head * heads + self.heads = heads + self.scale = dim_head ** -0.5 + self.use_bias = bias + self.scale_qk = True # Flux always uses scaled QK + + # Set cross-attention parameters + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + # Query, Key, Value projections + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + + # RMSNorm for Flux models + self.norm_q = RMSNorm(dim_head, eps=1e-6) + self.norm_k = RMSNorm(dim_head, eps=1e-6) + + # Optional added key/value projections for joint attention + self.added_kv_proj_dim = added_kv_proj_dim + if added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + + # Normalization for added projections + self.norm_added_q = RMSNorm(dim_head, eps=1e-6) + self.norm_added_k = RMSNorm(dim_head, eps=1e-6) + self.added_proj_bias = bias + + # Output projection for context + self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias) + + # Output projection and dropout + self.to_out = nn.ModuleList([ + nn.Linear(self.inner_dim, query_dim, bias=bias), + nn.Dropout(dropout) + ]) + + # Set processor + self.processor = FluxAttnProcessor() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass for Flux attention. 
+ + Args: + hidden_states: Input hidden states + encoder_hidden_states: Optional encoder hidden states for cross-attention + attention_mask: Optional attention mask + image_rotary_emb: Optional rotary embeddings for image tokens + + Returns: + Output hidden states, and optionally encoder hidden states for joint attention + """ + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + **kwargs, + ) + + @maybe_allow_in_graph class FluxSingleTransformerBlock(nn.Module): def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): @@ -53,27 +266,14 @@ class FluxSingleTransformerBlock(nn.Module): self.act_mlp = nn.GELU(approximate="tanh") self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) - if is_torch_npu_available(): - deprecation_message = ( - "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors " - "should be set explicitly using the `set_attn_processor` method." - ) - deprecate("npu_processor", "0.34.0", deprecation_message) - processor = FluxAttnProcessor2_0_NPU() - else: - processor = FluxAttnProcessor2_0() - - self.attn = Attention( + # Use specialized FluxAttention instead of generic Attention + self.attn = FluxAttention( query_dim=dim, cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, - out_dim=dim, + dropout=0.0, bias=True, - processor=processor, - qk_norm="rms_norm", - eps=1e-6, - pre_only=True, ) def forward( @@ -113,18 +313,15 @@ class FluxTransformerBlock(nn.Module): self.norm1 = AdaLayerNormZero(dim) self.norm1_context = AdaLayerNormZero(dim) - self.attn = Attention( + # Use specialized FluxAttention instead of generic Attention + self.attn = FluxAttention( query_dim=dim, cross_attention_dim=None, - added_kv_proj_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, - out_dim=dim, - context_pre_only=False, + dropout=0.0, bias=True, - processor=FluxAttnProcessor2_0(), - qk_norm=qk_norm, - eps=eps, + added_kv_proj_dim=dim, ) self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index e6532f080d..4a559968ba 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -24,7 +24,7 @@ from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import MochiAttention, MochiAttnProcessor2_0 +from ..attention_processor import MochiAttention, MochiAttnProcessor2_0, AttentionModuleMixin from ..cache_utils import CacheMixin from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -35,6 +35,85 @@ from ..normalization import AdaLayerNormContinuous, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name +@maybe_allow_in_graph +class MochiAttention(nn.Module, AttentionModuleMixin): + """ + Specialized attention module for Mochi video models. + + Features RMSNorm normalization and rotary position embeddings. 
+ """ + # Set Mochi-specific processor classes + default_processor_class = MochiAttnProcessor2_0 + + def __init__( + self, + query_dim: int, + added_kv_proj_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_proj_bias: bool = True, + out_dim: Optional[int] = None, + out_context_dim: Optional[int] = None, + context_pre_only: bool = False, + eps: float = 1e-5, + ): + super().__init__() + + # Import here to avoid circular imports + from ..normalization import MochiRMSNorm + + # Core parameters + self.inner_dim = dim_head * heads + self.query_dim = query_dim + self.heads = heads + self.scale = dim_head ** -0.5 + self.use_bias = bias + self.scale_qk = True # Always use scaled attention + self.context_pre_only = context_pre_only + self.eps = eps + + # Set output dimensions + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim else query_dim + + # Self-attention projections + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) + + # Normalization for queries and keys + self.norm_q = MochiRMSNorm(dim_head, eps, True) + self.norm_k = MochiRMSNorm(dim_head, eps, True) + + # Added key/value projections for joint processing + self.added_kv_proj_dim = added_kv_proj_dim + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + + # Normalization for added projections + self.norm_added_q = MochiRMSNorm(dim_head, eps, True) + self.norm_added_k = MochiRMSNorm(dim_head, eps, True) + self.added_proj_bias = added_proj_bias + + # Output projections + self.to_out = nn.ModuleList([ + nn.Linear(self.inner_dim, self.out_dim, bias=bias), + nn.Dropout(dropout) + ]) + + # Context output projection + if not context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=added_proj_bias) + else: + self.to_add_out = None + + # Initialize attention processor using the default class + self.processor = self.default_processor_class() + + class MochiModulatedRMSNorm(nn.Module): def __init__(self, eps: float): super().__init__() diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index e41fad220d..a0dd576727 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -22,6 +22,7 @@ from ...models.attention import FeedForward, JointTransformerBlock from ...models.attention_processor import ( Attention, AttentionProcessor, + AttentionModuleMixin, FusedJointAttnProcessor2_0, JointAttnProcessor2_0, ) @@ -36,6 +37,208 @@ from ..modeling_outputs import Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class JointAttnProcessor: + """Attention processor used for processing joint attention.""" + + def __init__(self): + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + raise ImportError("JointAttnProcessor requires PyTorch 2.0, please upgrade PyTorch.") + + def __call__( + self, + attn, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: 
Optional[torch.LongTensor] = None, + **kwargs, + ) -> torch.FloatTensor: + batch_size, sequence_length, _ = hidden_states.shape + + # Project query from hidden states + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + # Self-attention: Use hidden_states for key and value + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + else: + # Cross-attention: Use encoder_hidden_states for key and value + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Handle additional context for joint attention + if hasattr(attn, "added_kv_proj_dim") and attn.added_kv_proj_dim is not None: + context_key = attn.add_k_proj(encoder_hidden_states) + context_value = attn.add_v_proj(encoder_hidden_states) + context_query = attn.add_q_proj(encoder_hidden_states) + + # Joint query, key, value with context + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + # Reshape for multi-head attention + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + context_query = context_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + context_key = context_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + context_value = context_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Concatenate for joint attention + query = torch.cat([context_query, query], dim=2) + key = torch.cat([context_key, key], dim=2) + value = torch.cat([context_value, value], dim=2) + + # Apply joint attention + hidden_states = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + # Reshape back to original dimensions + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # Split context and hidden states + context_len = encoder_hidden_states.shape[1] + encoder_hidden_states, hidden_states = ( + hidden_states[:, :context_len], + hidden_states[:, context_len:], + ) + + # Apply output projections + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if not attn.context_pre_only and hasattr(attn, "to_add_out") and attn.to_add_out is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + return hidden_states, encoder_hidden_states + + return hidden_states + + # Handle standard attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + # Reshape for multi-head attention + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Apply attention + hidden_states = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + # Reshape output + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # Apply output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +@maybe_allow_in_graph +class SD3Attention(nn.Module, AttentionModuleMixin): + """ + Specialized attention implementation for SD3 models. + + Features joint attention mechanisms and custom handling of + context projections. 
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + out_dim: Optional[int] = None, + context_pre_only: bool = False, + eps: float = 1e-6, + ): + super().__init__() + + # Core parameters + self.inner_dim = dim_head * heads + self.query_dim = query_dim + self.heads = heads + self.scale = dim_head ** -0.5 + self.scale_qk = True # SD3 always scales query-key dot products + self.use_bias = bias + self.context_pre_only = context_pre_only + self.eps = eps + + # Set output dimension + out_dim = out_dim if out_dim is not None else query_dim + + # Set cross-attention parameters + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + # Linear projections for self-attention + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) + + # Optional added key/value projections for joint attention + self.added_kv_proj_dim = added_kv_proj_dim + if added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + self.added_proj_bias = bias + + # Output projection for context + if not context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, out_dim, bias=bias) + else: + self.to_add_out = None + + # Output projection and dropout + self.to_out = nn.ModuleList([ + nn.Linear(self.inner_dim, out_dim, bias=bias), + nn.Dropout(dropout) + ]) + + # Set processor + self.processor = JointAttnProcessor() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass for SD3 attention. + + Args: + hidden_states: Input hidden states + encoder_hidden_states: Optional encoder hidden states for cross/joint attention + attention_mask: Optional attention mask + position_ids: Optional position IDs + + Returns: + Output hidden states, and optionally encoder hidden states for joint attention + """ + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + **kwargs, + ) + + @maybe_allow_in_graph class SD3SingleTransformerBlock(nn.Module): def __init__( @@ -47,13 +250,13 @@ class SD3SingleTransformerBlock(nn.Module): super().__init__() self.norm1 = AdaLayerNormZero(dim) - self.attn = Attention( + # Use specialized SD3Attention instead of generic Attention + self.attn = SD3Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, bias=True, - processor=JointAttnProcessor2_0(), eps=1e-6, )
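
Usage sketch (not part of the diff): the snippet below shows how the new `FluxAttention` module from `attention_modules.py` and the `fuse_projections` helper it inherits from `AttentionModuleMixin` are meant to fit together. Tensor shapes are illustrative only, and the example assumes the `FluxAttnProcessorSDPA` / `FusedFluxAttnProcessorSDPA` classes imported at the top of `attention_modules.py` are available in `attention_processor.py`.

import torch
from diffusers.models.attention_modules import FluxAttention

# Illustrative, Flux-like dimensions; not taken from a real checkpoint config.
attn = FluxAttention(query_dim=3072, heads=24, dim_head=128, bias=True, added_kv_proj_dim=3072)

image_tokens = torch.randn(1, 4096, 3072)
text_tokens = torch.randn(1, 512, 3072)

# Runs the default (unfused) SDPA processor installed in __init__; with added
# projections the processor is expected to return both token streams.
image_out, text_out = attn(image_tokens, encoder_hidden_states=text_tokens)

# Fuse to_q/to_k/to_v (and the added projections) into single linear layers.
# AttentionModuleMixin.fuse_projections also swaps the processor to
# fused_processor_class (FusedFluxAttnProcessorSDPA here).
attn.fuse_projections(True)
assert hasattr(attn, "to_qkv") and hasattr(attn, "to_added_qkv")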