From 200e4ac462ff96c244f93af2859c775bda54295a Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 29 Apr 2025 22:57:59 +0530 Subject: [PATCH] update --- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/attention_modules.py | 8 +- src/diffusers/models/attention_processor.py | 58 +- .../models/transformers/sana_transformer.py | 24 +- .../models/transformers/transformer_flux.py | 885 +++++------------- .../models/transformers/transformer_mochi.py | 24 +- .../models/transformers/transformer_sd3.py | 65 +- 7 files changed, 298 insertions(+), 768 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 83a78c3fb8..d8b36c81de 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -107,8 +107,8 @@ if is_flax_available(): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .adapter import MultiAdapter, T2IAdapter - from .auto_model import AutoModel from .attention_modules import FluxAttention, SanaAttention, SD3Attention + from .auto_model import AutoModel from .autoencoders import ( AsymmetricAutoencoderKL, AutoencoderDC, diff --git a/src/diffusers/models/attention_modules.py b/src/diffusers/models/attention_modules.py index 5f36135b82..96b8438fb5 100644 --- a/src/diffusers/models/attention_modules.py +++ b/src/diffusers/models/attention_modules.py @@ -15,21 +15,17 @@ import inspect from typing import Optional, Tuple, Union import torch -import torch.nn.functional as F from torch import nn from ..utils import logging from ..utils.torch_utils import maybe_allow_in_graph from .attention_processor import ( AttentionModuleMixin, - AttnProcessorSDPA, - FluxAttnProcessorSDPA, - FusedFluxAttnProcessorSDPA, - JointAttnProcessorSDPA, FusedJointAttnProcessorSDPA, + JointAttnProcessorSDPA, SanaLinearAttnProcessorSDPA, ) -from .normalization import RMSNorm, get_normalization +from .normalization import get_normalization logger = logging.get_logger(__name__) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4b02a3f924..2bef1a0218 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -56,8 +56,13 @@ class AttentionModuleMixin: # Default processor classes to be overridden by subclasses default_processor_cls = None - fused_processor_cls = None - _available_processors = None + _available_processors = [] + + def _get_compatible_processor(self, backend): + for processor_cls in self._available_processors: + if backend in processor_cls.compatible_backends: + processor = processor_cls() + return processor def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: """ @@ -66,18 +71,11 @@ class AttentionModuleMixin: Args: use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not. """ + processor = self.default_processor_cls() + if use_npu_flash_attention: - processor = AttnProcessorNPU() - else: - # set attention processor - # We use the AttnProcessorSDPA by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1 - processor = ( - AttnProcessorSDPA() - if hasattr(F, "scaled_dot_product_attention") and self.scale_qk - else AttnProcessor() - ) + processor = self._get_compatible_processor("npu") + self.set_processor(processor) def set_use_xla_flash_attention( @@ -97,24 +95,17 @@ class AttentionModuleMixin: is_flux (`bool`, *optional*, defaults to `False`): Whether the model is a Flux model. """ + processor = self.default_processor_cls() if use_xla_flash_attention: - if not is_torch_xla_available: + if not is_torch_xla_available(): raise "torch_xla is not available" elif is_torch_xla_version("<", "2.3"): raise "flash attention pallas kernel is supported from torch_xla version 2.3" elif is_spmd() and is_torch_xla_version("<", "2.4"): raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4" else: - if is_flux: - processor = XLAFluxFlashAttnProcessorSDPA(partition_spec) - else: - processor = XLAFlashAttnProcessorSDPA(partition_spec) - else: - processor = ( - AttnProcessorSDPA() - if hasattr(F, "scaled_dot_product_attention") and self.scale_qk - else AttnProcessor() - ) + processor = self._get_compatible_processor("xla") + self.set_processor(processor) @torch.no_grad() @@ -179,11 +170,7 @@ class AttentionModuleMixin: self.to_added_qkv.bias.copy_(concatenated_bias) self.fused_projections = fuse - - # Update processor based on fusion state - processor_class = self.fused_processor_class if fuse else self.default_processor_class - if processor_class is not None: - self.set_processor(processor_class()) + self.processor.is_fused = fuse def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None @@ -557,8 +544,9 @@ class AttentionModuleMixin: @maybe_allow_in_graph class Attention(nn.Module, AttentionModuleMixin): # Set default and fused processor classes - default_processor_class = AttnProcessorSDPA - fused_processor_class = None # Will be set appropriately in the future + default_processor_class = None + _available_processors = [] + r""" A cross attention layer. 
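Note: the `AttentionModuleMixin` hunks above replace per-backend processor swapping with a registry searched by backend string (`compatible_backends`). A minimal, self-contained sketch of that dispatch pattern follows; the Toy* names are hypothetical stand-ins, not classes from this patch:

    class ToySDPAProcessor:
        is_fused = False
        compatible_backends = ["cuda", "cpu"]

    class ToyNPUProcessor:
        is_fused = False
        compatible_backends = ["npu"]

    class ToyAttention:
        # Mirrors the mixin: a default processor class plus a registry searched by backend.
        default_processor_cls = ToySDPAProcessor
        _available_processors = [ToySDPAProcessor, ToyNPUProcessor]

        def _get_compatible_processor(self, backend):
            for processor_cls in self._available_processors:
                if backend in processor_cls.compatible_backends:
                    return processor_cls()
            return None  # no registered processor supports this backend

    attn = ToyAttention()
    print(type(attn._get_compatible_processor("npu")).__name__)  # ToyNPUProcessor
    print(attn._get_compatible_processor("mps"))                 # None
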
@@ -958,7 +946,10 @@ class SanaMultiscaleLinearAttention(nn.Module): return self.processor(self, hidden_states) -class MochiAttention(nn.Module): +class MochiAttention(nn.Module, AttentionModuleMixin): + default_processor_cls = MochiAttnProcessorSDPA + _available_processors = [MochiAttnProcessorSDPA] + def __init__( self, query_dim: int, @@ -1006,7 +997,8 @@ class MochiAttention(nn.Module): if not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) - self.processor = processor + processor = processor if processor is not None else self.default_processor_cls() + self.set_processor(processor) def forward( self, diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index ff730e9454..f844130fab 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -23,8 +23,8 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, - AttentionProcessor, AttentionModuleMixin, + AttentionProcessor, SanaLinearAttnProcessor2_0, ) from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps @@ -45,7 +45,7 @@ class SanaAttention(nn.Module, AttentionModuleMixin): """ # Set Sana-specific processor classes default_processor_class = SanaLinearAttnProcessor2_0 - + def __init__( self, in_channels: int, @@ -59,13 +59,13 @@ class SanaAttention(nn.Module, AttentionModuleMixin): residual_connection: bool = False, ): super().__init__() - + # Core parameters self.eps = eps self.attention_head_dim = attention_head_dim self.norm_type = norm_type self.residual_connection = residual_connection - + # Calculate dimensions num_attention_heads = ( int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads @@ -73,23 +73,23 @@ class SanaAttention(nn.Module, AttentionModuleMixin): inner_dim = num_attention_heads * attention_head_dim self.inner_dim = inner_dim self.heads = num_attention_heads - + # Query, key, value projections self.to_q = nn.Linear(in_channels, inner_dim, bias=False) self.to_k = nn.Linear(in_channels, inner_dim, bias=False) self.to_v = nn.Linear(in_channels, inner_dim, bias=False) - + # Multi-scale attention self.to_qkv_multiscale = nn.ModuleList() for kernel_size in kernel_sizes: self.to_qkv_multiscale.append( SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size) ) - + # Output layers self.nonlinearity = nn.ReLU() self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False) - + # Get normalization based on type if norm_type == "batch_norm": self.norm_out = nn.BatchNorm1d(out_channels) @@ -101,14 +101,14 @@ class SanaAttention(nn.Module, AttentionModuleMixin): self.norm_out = nn.InstanceNorm1d(out_channels) else: self.norm_out = nn.Identity() - + # Set processor self.processor = self.default_processor_class() class SanaMultiscaleAttentionProjection(nn.Module): """Projection layer for Sana multi-scale attention.""" - + def __init__( self, in_channels: int, @@ -116,7 +116,7 @@ class SanaMultiscaleAttentionProjection(nn.Module): kernel_size: int, ) -> None: super().__init__() - + channels = 3 * in_channels self.proj_in = nn.Conv2d( channels, @@ -127,7 +127,7 @@ class SanaMultiscaleAttentionProjection(nn.Module): bias=False, ) self.proj_out = nn.Conv2d(channels, 
channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
-
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.proj_in(hidden_states)
         hidden_states = self.proj_out(hidden_states)
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index b28cea6421..c479fa939d 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -13,6 +13,7 @@
 #  limitations under the License.
 
 
+import math
 from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
@@ -23,19 +24,11 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import FeedForward
-from ...models.attention_processor import (
-    Attention,
-    AttentionProcessor,
-    AttentionModuleMixin,
-    FluxAttnProcessor2_0,
-    FluxAttnProcessor2_0_NPU,
-    FusedFluxAttnProcessor2_0,
-)
+from ...models.attention_processor import AttentionModuleMixin
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
-from ...utils.torch_utils import maybe_allow_in_graph
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
-from ...utils.import_utils import is_torch_npu_available
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_torch_xla_version
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..cache_utils import CacheMixin
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
@@ -45,56 +38,75 @@ from ..modeling_outputs import Transformer2DModelOutput
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class FluxAttnProcessor:
-    """Flux-specific attention processor that implements normalized attention with support for rotary embeddings."""
+class BaseFluxAttnProcessor:
+    """Base attention processor for Flux models with common functionality."""
+
+    is_fused = False
+    compatible_backends = []
 
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("FluxAttnProcessor requires PyTorch 2.0, please upgrade PyTorch.")
+            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
+
+    def get_projections(self, attn, hidden_states, encoder_hidden_states=None):
+        """Public method to get projections based on whether we're using fused mode or not."""
+        if self.is_fused and hasattr(attn, "to_qkv"):
+            return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
+
+        return self._get_projections(attn, hidden_states, encoder_hidden_states)
+
+    def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
+        """Get projections using standard separate projection matrices."""
+        # Standard separate projections
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # Handle encoder projections if present
+        encoder_projections = None
+        if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
+            encoder_query = attn.add_q_proj(encoder_hidden_states)
+            encoder_key = attn.add_k_proj(encoder_hidden_states)
+            encoder_value = 
attn.add_v_proj(encoder_hidden_states) + encoder_projections = (encoder_query, encoder_key, encoder_value) + + return query, key, value, encoder_projections + + def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None): + """Get projections using fused QKV projection matrices.""" + # Fused QKV projection + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + # Handle encoder projections if present + encoder_projections = None + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1) + encoder_projections = (encoder_query, encoder_key, encoder_value) + + return query, key, value, encoder_projections + + def _compute_attention(self, query, key, value, attention_mask=None): + """Computes the attention. Can be overridden by hardware-specific implementations.""" + return F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) def __call__( self, - attn, + attn: "FluxAttention", hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - **kwargs, ) -> torch.FloatTensor: - batch_size, seq_len, _ = hidden_states.shape + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - # Project query from hidden states - query = attn.to_q(hidden_states) + query, key, value, encoder_projections = self.get_projections(attn, hidden_states, encoder_hidden_states) - # Handle cross-attention vs self-attention - if encoder_hidden_states is None: - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - else: - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - # If we have added_kv_proj_dim, handle additional projections - if hasattr(attn, "added_kv_proj_dim") and attn.added_kv_proj_dim is not None: - encoder_key = attn.add_k_proj(encoder_hidden_states) - encoder_value = attn.add_v_proj(encoder_hidden_states) - encoder_query = attn.add_q_proj(encoder_hidden_states) - - # Reshape - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # Apply normalization if available - if hasattr(attn, "norm_added_q") and attn.norm_added_q is not None: - encoder_query = attn.norm_added_q(encoder_query) - if hasattr(attn, "norm_added_k") and attn.norm_added_k is not None: - encoder_key = attn.norm_added_k(encoder_key) - - # Reshape for multi-head attention inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -102,614 +114,124 @@ class FluxAttnProcessor: key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # Apply normalization if available - if hasattr(attn, "norm_q") and attn.norm_q is not None: + if attn.norm_q is not None: query = attn.norm_q(query) - if hasattr(attn, "norm_k") and attn.norm_k is not None: + 
if attn.norm_k is not None:
             key = attn.norm_k(key)
 
-        # Handle rotary embeddings if provided
-        if image_rotary_emb is not None:
-            from ...models.embeddings import apply_rotary_emb
+        if encoder_projections is not None:
+            encoder_query, encoder_key, encoder_value = encoder_projections
+            encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-            query = apply_rotary_emb(query, image_rotary_emb)
-            # Only apply to key in self-attention
-            if encoder_hidden_states is None:
-                key = apply_rotary_emb(key, image_rotary_emb)
+            if attn.norm_added_q is not None:
+                encoder_query = attn.norm_added_q(encoder_query)
+            if attn.norm_added_k is not None:
+                encoder_key = attn.norm_added_k(encoder_key)
 
-        # Concatenate encoder projections if we have them
-        if (
-            encoder_hidden_states is not None
-            and hasattr(attn, "added_kv_proj_dim")
-            and attn.added_kv_proj_dim is not None
-        ):
             # Concatenate for joint attention
             query = torch.cat([encoder_query, query], dim=2)
             key = torch.cat([encoder_key, key], dim=2)
             value = torch.cat([encoder_value, value], dim=2)
 
-        # Compute attention
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        hidden_states = self._compute_attention(query, key, value, attention_mask)
 
-        # Reshape back
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        # Split back if we did joint attention
-        if (
-            encoder_hidden_states is not None
-            and hasattr(attn, "added_kv_proj_dim")
-            and attn.added_kv_proj_dim is not None
-            and hasattr(attn, "to_add_out")
-            and attn.to_add_out is not None
-        ):
-            context_len = encoder_hidden_states.shape[1]
+        if encoder_hidden_states is not None:
             encoder_hidden_states, hidden_states = (
-                hidden_states[:, :context_len],
-                hidden_states[:, context_len:],
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
             )
 
-            # Project output
             hidden_states = attn.to_out[0](hidden_states)
             hidden_states = attn.to_out[1](hidden_states)
             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
             return hidden_states, encoder_hidden_states
         else:
-            # Project output
-            hidden_states = attn.to_out[0](hidden_states)
-            hidden_states = attn.to_out[1](hidden_states)
-
             return hidden_states
 
 
-@maybe_allow_in_graph
-class FluxAttention(nn.Module, AttentionModuleMixin):
-    """
-    Specialized attention implementation for Flux models.
+class FluxAttnProcessorSDPA(BaseFluxAttnProcessor):
+    compatible_backends = ["cuda", "xpu", "cpu"]
 
-    This attention module provides optimized implementation for Flux models,
-    with support for RMSNorm, rotary embeddings, and added key/value projections.
- """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - added_kv_proj_dim: Optional[int] = None, - ): + def __init__(self): super().__init__() - # Core parameters - self.inner_dim = dim_head * heads - self.heads = heads - self.scale = dim_head**-0.5 - self.use_bias = bias - self.scale_qk = True # Flux always uses scaled QK - # Set cross-attention parameters - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim +class FluxAttnProcessorNPU(BaseFluxAttnProcessor): + """NPU-specific implementation of Flux attention processor.""" - # Query, Key, Value projections - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) - self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + compatible_backends = ["npu"] - # RMSNorm for Flux models - self.norm_q = RMSNorm(dim_head, eps=1e-6) - self.norm_k = RMSNorm(dim_head, eps=1e-6) - - # Optional added key/value projections for joint attention - self.added_kv_proj_dim = added_kv_proj_dim - if added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) - self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) - self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) - - # Normalization for added projections - self.norm_added_q = RMSNorm(dim_head, eps=1e-6) - self.norm_added_k = RMSNorm(dim_head, eps=1e-6) - self.added_proj_bias = bias - - # Output projection for context - self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias) - - # Output projection and dropout - self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, query_dim, bias=bias), nn.Dropout(dropout)]) - - # Set processor - self.processor = FluxAttnProcessor() - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Forward pass for Flux attention. 
- - Args: - hidden_states: Input hidden states - encoder_hidden_states: Optional encoder hidden states for cross-attention - attention_mask: Optional attention mask - image_rotary_emb: Optional rotary embeddings for image tokens - - Returns: - Output hidden states, and optionally encoder hidden states for joint attention - """ - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - image_rotary_emb=image_rotary_emb, - **kwargs, - ) - - -@maybe_allow_in_graph -class FluxSingleTransformerBlock(nn.Module): - def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + def __init__(self): super().__init__() - self.mlp_hidden_dim = int(dim * mlp_ratio) + if not is_torch_npu_available(): + raise ImportError("FluxAttnProcessorNPU requires torch_npu, please install it.") + import torch_npu - self.norm = AdaLayerNormZeroSingle(dim) - self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) - self.act_mlp = nn.GELU(approximate="tanh") - self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + self.attn_fn = torch_npu.npu_fusion_attention - # Use specialized FluxAttention instead of generic Attention - self.attn = FluxAttention( - query_dim=dim, - cross_attention_dim=None, - dim_head=attention_head_dim, - heads=num_attention_heads, - dropout=0.0, - bias=True, - ) + def _compute_attention(self, query, key, value, attention_mask=None): + if query.dtype in (torch.float16, torch.bfloat16): + # NPU-specific implementation + return self.attn_fn( + query, + key, + value, + query.shape[1], # number of heads + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + else: + # Fall back to standard implementation for other dtypes + return F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) - def forward( - self, - hidden_states: torch.Tensor, - temb: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - ) -> torch.Tensor: - residual = hidden_states - norm_hidden_states, gate = self.norm(hidden_states, emb=temb) - mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) - joint_attention_kwargs = joint_attention_kwargs or {} - attn_output = self.attn( - hidden_states=norm_hidden_states, - image_rotary_emb=image_rotary_emb, - **joint_attention_kwargs, - ) - hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) - gate = gate.unsqueeze(1) - hidden_states = gate * self.proj_out(hidden_states) - hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16: - hidden_states = hidden_states.clip(-65504, 65504) +class FluxAttnProcessorXLA(BaseFluxAttnProcessor): + """XLA-specific implementation of Flux attention processor.""" + + compatible_backends = ["xla"] + + def __init__(self): + super().__init__() + + if not is_torch_xla_available(): + raise ImportError( + "FluxAttnProcessorXLA requires torch_xla, please install it using `pip install torch_xla`" + ) + if is_torch_xla_version("<", "2.3"): + raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") + + from torch_xla.experimental.custom_kernel import flash_attention + + self.attn_fn = flash_attention + + def _compute_attention(self, query, key, value, attention_mask=None): + query /= math.sqrt(query.shape[3]) + 
hidden_states = self.attn_fn(query, key, value, causal=False) return hidden_states -class FluxAttnProcessorSDPA: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("FluxAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FluxAttnProcessorNPU: - """Attention processor used 
typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessorSDPA_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU" - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if query.dtype in (torch.float16, torch.bfloat16): - hidden_states = torch_npu.npu_fusion_attention( - query, - key, - value, - attn.heads, - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - 
encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FusedFluxAttnProcessorSDPA: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FusedFluxAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - # `context` projections. - if encoder_hidden_states is not None: - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - split_size = encoder_qkv.shape[-1] // 3 - ( - encoder_hidden_states_query_proj, - encoder_hidden_states_key_proj, - encoder_hidden_states_value_proj, - ) = torch.split(encoder_qkv, split_size, dim=-1) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, 
encoder_hidden_states - else: - return hidden_states - - -class FusedFluxAttnProcessorNPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessorSDPA_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU" - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - # `context` projections. - if encoder_hidden_states is not None: - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - split_size = encoder_qkv.shape[-1] // 3 - ( - encoder_hidden_states_query_proj, - encoder_hidden_states_key_proj, - encoder_hidden_states_value_proj, - ) = torch.split(encoder_qkv, split_size, dim=-1) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if query.dtype in (torch.float16, torch.bfloat16): - hidden_states = torch_npu.npu_fusion_attention( - query, - key, - value, - attn.heads, - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, 
encoder_hidden_states.shape[1] :],
             )
 
             # linear proj
             hidden_states = attn.to_out[0](hidden_states)
             # dropout
             hidden_states = attn.to_out[1](hidden_states)
             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
             return hidden_states, encoder_hidden_states
         else:
             return hidden_states
 
 
-class FluxIPAdapterJointAttnProcessorSDPA(torch.nn.Module):
+class FluxIPAdapterAttnProcessorSDPA(torch.nn.Module):
     """Flux Attention processor for IP-Adapter."""
 
     def __init__(
@@ -856,33 +378,23 @@ class FluxIPAdapterJointAttnProcessorSDPA(torch.nn.Module):
 @maybe_allow_in_graph
 class FluxAttention(nn.Module, AttentionModuleMixin):
     """
+    Specialized attention implementation for Flux models.
 
-    Args:
-        query_dim (`int`): Number of channels in query.
-        cross_attention_dim (`int`, *optional*): Number of channels in encoder states.
-        heads (`int`, defaults to 8): Number of attention heads.
-        dim_head (`int`, defaults to 64): Dimension of each attention head.
-        dropout (`float`, defaults to 0.0): Dropout probability.
-        bias (`bool`, defaults to False): Whether to use bias in linear projections.
-        added_kv_proj_dim (`int`, *optional*): Dimension for added key/value projections.
+    This attention module provides an optimized implementation for Flux models,
+    with support for RMSNorm, rotary embeddings, and added key/value projections.
     """
 
-    # Set Flux-specific processor classes
-    default_processor_cls = FluxAttnProcessorSDPA
-    fused_processor_cls = FusedFluxAttnProcessorSDPA
-
+    default_processor_cls = FluxAttnProcessorSDPA
     _available_processors = [
         FluxAttnProcessorSDPA,
-        FusedFluxAttnProcessorSDPA,
         FluxAttnProcessorNPU,
-        FusedFluxAttnProcessorNPU,
-        FluxIPAdapterJointAttnProcessorSDPA,
+        FluxAttnProcessorXLA,
+        FluxIPAdapterAttnProcessorSDPA,
     ]
 
     def __init__(
         self,
         query_dim: int,
-        cross_attention_dim: Optional[int] = None,
         heads: int = 8,
         dim_head: int = 64,
         dropout: float = 0.0,
@@ -893,26 +405,20 @@ class FluxAttention(nn.Module, AttentionModuleMixin):
 
         # Core parameters
         self.inner_dim = dim_head * heads
-        self.query_dim = query_dim
         self.heads = heads
         self.scale = dim_head**-0.5
         self.use_bias = bias
-        self.scale_qk = True  # Flux always uses scale_qk
 
-        # Cross-attention setup
-        self.is_cross_attention = cross_attention_dim is not None
-        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
-
-        # Projections
+        # Query, Key, Value projections
         self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
-        self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
-        self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
 
-        # Flux-specific normalization
+        # RMSNorm for Flux models
         self.norm_q = RMSNorm(dim_head, eps=1e-6)
         self.norm_k = RMSNorm(dim_head, eps=1e-6)
 
-        # Added projections for cross-attention
+        # Optional added key/value projections for joint attention
         self.added_kv_proj_dim = added_kv_proj_dim
         if added_kv_proj_dim is not None:
             self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
@@ -924,51 +430,91 @@ class FluxAttention(nn.Module, AttentionModuleMixin):
             self.norm_added_k = RMSNorm(dim_head, eps=1e-6)
             self.added_proj_bias = bias
 
-        # Output projection
+        # Output projection for context
+        self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias)
+
+        # Output projection and dropout
         self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, query_dim, bias=bias), nn.Dropout(dropout)])
 
-        # For cross-attention with added projections
-        if added_kv_proj_dim is not None:
-            self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias)
-        else:
-            self.to_add_out = None
-
-        # Set default processor and fusion state
-        self.fused_projections = False
-        self.set_processor(self.default_processor_class())
+        # Set processor
+        self.set_processor(self.default_processor_cls())
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """Process attention for Flux model inputs."""
-        # Filter parameters to only those expected by the processor
-        processor_params = inspect.signature(self.processor.__call__).parameters.keys()
-        quiet_params = {"ip_adapter_masks", "ip_hidden_states"}
+        """
+        Forward pass for Flux attention.
 
-        # Check for unexpected parameters
-        unexpected_params = [k for k, _ in kwargs.items() if k not in processor_params and k not in quiet_params]
-        if unexpected_params:
-            logger.warning(
-                f"Parameters {unexpected_params} are not expected by {self.processor.__class__.__name__} and will be ignored."
-            )
+        Args:
+            hidden_states: Input hidden states
+            encoder_hidden_states: Optional encoder hidden states for cross-attention
+            attention_mask: Optional attention mask
+            image_rotary_emb: Optional rotary embeddings for image tokens
 
-        # Filter to only expected parameters
-        filtered_kwargs = {k: v for k, v in kwargs.items() if k in processor_params}
-
-        # Process with appropriate processor
+        Returns:
+            Output hidden states, and optionally encoder hidden states for joint attention
+        """
         return self.processor(
             self,
             hidden_states,
            encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,
-            **filtered_kwargs,
+            image_rotary_emb=image_rotary_emb,
+            **kwargs,
         )
 
 
+@maybe_allow_in_graph
+class FluxSingleTransformerBlock(nn.Module):
+    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
+        super().__init__()
+        self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+        self.norm = AdaLayerNormZeroSingle(dim)
+        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+        self.act_mlp = nn.GELU(approximate="tanh")
+        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+        self.attn = FluxAttention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            dropout=0.0,
+            bias=True,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+        joint_attention_kwargs = joint_attention_kwargs or {}
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
+        )
+
+        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+        gate = gate.unsqueeze(1)
+        hidden_states = gate * self.proj_out(hidden_states)
+        hidden_states = residual + hidden_states
+        if hidden_states.dtype == torch.float16:
+            hidden_states = hidden_states.clip(-65504, 65504)
+
+        return hidden_states
+
+
 @maybe_allow_in_graph
 class FluxTransformerBlock(nn.Module):
     def __init__(
@@ -1209,7 +755,6 @@ class 
FluxTransformer2DModel( for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) @@ -1233,8 +778,6 @@ class FluxTransformer2DModel( if isinstance(module, Attention): module.fuse_projections(fuse=True) - self.set_attn_processor(FusedFluxAttnProcessor2_0()) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 4a559968ba..f67e16feaa 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -24,7 +24,7 @@ from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import MochiAttention, MochiAttnProcessor2_0, AttentionModuleMixin +from ..attention_processor import AttentionModuleMixin, MochiAttention, MochiAttnProcessor2_0 from ..cache_utils import CacheMixin from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -44,7 +44,7 @@ class MochiAttention(nn.Module, AttentionModuleMixin): """ # Set Mochi-specific processor classes default_processor_class = MochiAttnProcessor2_0 - + def __init__( self, query_dim: int, @@ -60,10 +60,10 @@ class MochiAttention(nn.Module, AttentionModuleMixin): eps: float = 1e-5, ): super().__init__() - + # Import here to avoid circular imports from ..normalization import MochiRMSNorm - + # Core parameters self.inner_dim = dim_head * heads self.query_dim = query_dim @@ -73,43 +73,43 @@ class MochiAttention(nn.Module, AttentionModuleMixin): self.scale_qk = True # Always use scaled attention self.context_pre_only = context_pre_only self.eps = eps - + # Set output dimensions self.out_dim = out_dim if out_dim is not None else query_dim self.out_context_dim = out_context_dim if out_context_dim else query_dim - + # Self-attention projections self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) - + # Normalization for queries and keys self.norm_q = MochiRMSNorm(dim_head, eps, True) self.norm_k = MochiRMSNorm(dim_head, eps, True) - + # Added key/value projections for joint processing self.added_kv_proj_dim = added_kv_proj_dim self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - + # Normalization for added projections self.norm_added_q = MochiRMSNorm(dim_head, eps, True) self.norm_added_k = MochiRMSNorm(dim_head, eps, True) self.added_proj_bias = added_proj_bias - + # Output projections self.to_out = nn.ModuleList([ nn.Linear(self.inner_dim, self.out_dim, bias=bias), nn.Dropout(dropout) ]) - + # Context output 
projection if not context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=added_proj_bias) else: self.to_add_out = None - + # Initialize attention processor using the default class self.processor = self.default_processor_class() diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index a0dd576727..3c2d6904bd 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -21,10 +21,9 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2 from ...models.attention import FeedForward, JointTransformerBlock from ...models.attention_processor import ( Attention, - AttentionProcessor, AttentionModuleMixin, + AttentionProcessor, FusedJointAttnProcessor2_0, - JointAttnProcessor2_0, ) from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero @@ -39,11 +38,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name class JointAttnProcessor: """Attention processor used for processing joint attention.""" - + def __init__(self): if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): raise ImportError("JointAttnProcessor requires PyTorch 2.0, please upgrade PyTorch.") - + def __call__( self, attn, @@ -54,10 +53,10 @@ class JointAttnProcessor: **kwargs, ) -> torch.FloatTensor: batch_size, sequence_length, _ = hidden_states.shape - + # Project query from hidden states query = attn.to_q(hidden_states) - + if encoder_hidden_states is None: # Self-attention: Use hidden_states for key and value key = attn.to_k(hidden_states) @@ -66,77 +65,77 @@ class JointAttnProcessor: # Cross-attention: Use encoder_hidden_states for key and value key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) - + # Handle additional context for joint attention if hasattr(attn, "added_kv_proj_dim") and attn.added_kv_proj_dim is not None: context_key = attn.add_k_proj(encoder_hidden_states) context_value = attn.add_v_proj(encoder_hidden_states) context_query = attn.add_q_proj(encoder_hidden_states) - + # Joint query, key, value with context inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads - + # Reshape for multi-head attention query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - + context_query = context_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) context_key = context_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) context_value = context_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - + # Concatenate for joint attention query = torch.cat([context_query, query], dim=2) key = torch.cat([context_key, key], dim=2) value = torch.cat([context_value, value], dim=2) - + # Apply joint attention hidden_states = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - + # Reshape back to original dimensions hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - + # Split context and hidden states context_len = encoder_hidden_states.shape[1] encoder_hidden_states, hidden_states = ( hidden_states[:, :context_len], hidden_states[:, context_len:], ) - + # Apply output projections hidden_states = 
attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) - + if not attn.context_pre_only and hasattr(attn, "to_add_out") and attn.to_add_out is not None: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states - + return hidden_states - + # Handle standard attention inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads - + # Reshape for multi-head attention query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - + # Apply attention hidden_states = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - + # Reshape output hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - + # Apply output projection hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) - + return hidden_states @@ -148,7 +147,7 @@ class SD3Attention(nn.Module, AttentionModuleMixin): Features joint attention mechanisms and custom handling of context projections. """ - + def __init__( self, query_dim: int, @@ -163,7 +162,7 @@ class SD3Attention(nn.Module, AttentionModuleMixin): eps: float = 1e-6, ): super().__init__() - + # Core parameters self.inner_dim = dim_head * heads self.query_dim = query_dim @@ -173,19 +172,19 @@ class SD3Attention(nn.Module, AttentionModuleMixin): self.use_bias = bias self.context_pre_only = context_pre_only self.eps = eps - + # Set output dimension out_dim = out_dim if out_dim is not None else query_dim - + # Set cross-attention parameters self.is_cross_attention = cross_attention_dim is not None self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - + # Linear projections for self-attention self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) - + # Optional added key/value projections for joint attention self.added_kv_proj_dim = added_kv_proj_dim if added_kv_proj_dim is not None: @@ -193,22 +192,22 @@ class SD3Attention(nn.Module, AttentionModuleMixin): self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) self.added_proj_bias = bias - + # Output projection for context if not context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, out_dim, bias=bias) else: self.to_add_out = None - + # Output projection and dropout self.to_out = nn.ModuleList([ nn.Linear(self.inner_dim, out_dim, bias=bias), nn.Dropout(dropout) ]) - + # Set processor self.processor = JointAttnProcessor() - + def forward( self, hidden_states: torch.Tensor,
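Closing sketch of how QKV fusion behaves after this patch: `fuse_projections(fuse=True)` now concatenates the weights into `to_qkv`/`to_added_qkv` and flips `processor.is_fused`, rather than swapping in a separate Fused* processor class, so `BaseFluxAttnProcessor.get_projections()` routes through the fused weights. Illustrative only; the import path and constructor arguments below are assumptions based on the hunks above, not tested code from the patch.

    import torch
    from diffusers.models import FluxAttention  # re-exported per the models/__init__.py hunk in this patch

    attn = FluxAttention(query_dim=128, heads=4, dim_head=32, added_kv_proj_dim=128, bias=True)
    print(type(attn.processor).__name__)  # FluxAttnProcessorSDPA

    attn.fuse_projections(fuse=True)
    # Same processor instance as before fusing; only its flag changed.
    print(type(attn.processor).__name__, attn.processor.is_fused)  # FluxAttnProcessorSDPA True
    print(hasattr(attn, "to_qkv"), hasattr(attn, "to_added_qkv"))  # True True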