From 94ae28edea860bd6b88b24fe1feeed8eeb184cef Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 28 Apr 2025 22:39:21 +0530 Subject: [PATCH] update --- src/diffusers/models/attention_modules.py | 188 +---- src/diffusers/models/attention_processor.py | 559 +------------ .../models/transformers/transformer_flux.py | 750 +++++++++++++++++- 3 files changed, 748 insertions(+), 749 deletions(-) diff --git a/src/diffusers/models/attention_modules.py b/src/diffusers/models/attention_modules.py index 473781bed3..5f36135b82 100644 --- a/src/diffusers/models/attention_modules.py +++ b/src/diffusers/models/attention_modules.py @@ -39,9 +39,9 @@ logger = logging.get_logger(__name__) class SanaAttention(nn.Module, AttentionModuleMixin): """ Attention implementation specialized for Sana models. - + This module implements lightweight multi-scale linear attention as used in Sana. - + Args: in_channels (`int`): Number of input channels. out_channels (`int`): Number of output channels. @@ -51,10 +51,11 @@ class SanaAttention(nn.Module, AttentionModuleMixin): norm_type (`str`, defaults to "batch_norm"): Type of normalization. kernel_sizes (`Tuple[int, ...]`, defaults to (5,)): Kernel sizes for multi-scale attention. """ + # Set Sana-specific processor classes default_processor_class = SanaLinearAttnProcessorSDPA fused_processor_class = None # Sana doesn't have a fused processor yet - + def __init__( self, in_channels: int, @@ -68,13 +69,13 @@ class SanaAttention(nn.Module, AttentionModuleMixin): residual_connection: bool = False, ): super().__init__() - + # Core parameters self.eps = eps self.attention_head_dim = attention_head_dim self.norm_type = norm_type self.residual_connection = residual_connection - + # Calculate dimensions num_attention_heads = ( int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads @@ -82,28 +83,28 @@ class SanaAttention(nn.Module, AttentionModuleMixin): inner_dim = num_attention_heads * attention_head_dim self.inner_dim = inner_dim self.heads = num_attention_heads - + # Query, key, value projections self.to_q = nn.Linear(in_channels, inner_dim, bias=False) self.to_k = nn.Linear(in_channels, inner_dim, bias=False) self.to_v = nn.Linear(in_channels, inner_dim, bias=False) - + # Multi-scale attention self.to_qkv_multiscale = nn.ModuleList() for kernel_size in kernel_sizes: self.to_qkv_multiscale.append( SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size) ) - + # Output layers self.nonlinearity = nn.ReLU() self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False) self.norm_out = get_normalization(norm_type, num_features=out_channels) - + # Set default processor self.fused_projections = False self.set_processor(self.default_processor_class()) - + def forward( self, hidden_states: torch.Tensor, @@ -117,7 +118,7 @@ class SanaAttention(nn.Module, AttentionModuleMixin): class SanaMultiscaleAttentionProjection(nn.Module): """Projection layer for Sana multi-scale attention.""" - + def __init__( self, in_channels: int, @@ -125,7 +126,7 @@ class SanaMultiscaleAttentionProjection(nn.Module): kernel_size: int, ) -> None: super().__init__() - + channels = 3 * in_channels self.proj_in = nn.Conv2d( channels, @@ -136,138 +137,21 @@ class SanaMultiscaleAttentionProjection(nn.Module): bias=False, ) self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False) - + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.proj_in(hidden_states) 
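        # Note: `hidden_states` here is the concatenated Q/K/V feature map with
        # 3 * in_channels channels. `proj_in` (the k x k convolution defined above) gives this
        # multi-scale branch its local spatial receptive field, while `proj_out` is a grouped
        # 1x1 convolution (groups = 3 * num_attention_heads) that mixes channels only within
        # each group. Neither convolution changes the tensor shape.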
hidden_states = self.proj_out(hidden_states) return hidden_states -@maybe_allow_in_graph -class FluxAttention(nn.Module, AttentionModuleMixin): - """ - Attention implementation specialized for Flux models. - - This module uses RMSNorm for query and key normalization and supports - rotary embeddings through its processor. - - Args: - query_dim (`int`): Number of channels in query. - cross_attention_dim (`int`, *optional*): Number of channels in encoder states. - heads (`int`, defaults to 8): Number of attention heads. - dim_head (`int`, defaults to 64): Dimension of each attention head. - dropout (`float`, defaults to 0.0): Dropout probability. - bias (`bool`, defaults to False): Whether to use bias in linear projections. - added_kv_proj_dim (`int`, *optional*): Dimension for added key/value projections. - """ - # Set Flux-specific processor classes - default_processor_class = FluxAttnProcessorSDPA - fused_processor_class = FusedFluxAttnProcessorSDPA - - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - added_kv_proj_dim: Optional[int] = None, - ): - super().__init__() - - # Core parameters - self.inner_dim = dim_head * heads - self.query_dim = query_dim - self.heads = heads - self.scale = dim_head ** -0.5 - self.use_bias = bias - self.scale_qk = True # Flux always uses scale_qk - - # Cross-attention setup - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - - # Projections - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) - self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) - - # Flux-specific normalization - self.norm_q = RMSNorm(dim_head, eps=1e-6) - self.norm_k = RMSNorm(dim_head, eps=1e-6) - - # Added projections for cross-attention - self.added_kv_proj_dim = added_kv_proj_dim - if added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) - self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) - self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) - - # Normalization for added projections - self.norm_added_q = RMSNorm(dim_head, eps=1e-6) - self.norm_added_k = RMSNorm(dim_head, eps=1e-6) - self.added_proj_bias = bias - - # Output projection - self.to_out = nn.ModuleList([ - nn.Linear(self.inner_dim, query_dim, bias=bias), - nn.Dropout(dropout) - ]) - - # For cross-attention with added projections - if added_kv_proj_dim is not None: - self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias) - else: - self.to_add_out = None - - # Set default processor and fusion state - self.fused_projections = False - self.set_processor(self.default_processor_class()) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Process attention for Flux model inputs.""" - # Filter parameters to only those expected by the processor - processor_params = inspect.signature(self.processor.__call__).parameters.keys() - quiet_params = {"ip_adapter_masks", "ip_hidden_states"} - - # Check for unexpected parameters - unexpected_params = [ - k for k, _ in kwargs.items() - if k not in processor_params 
and k not in quiet_params - ] - if unexpected_params: - logger.warning( - f"Parameters {unexpected_params} are not expected by {self.processor.__class__.__name__} and will be ignored." - ) - - # Filter to only expected parameters - filtered_kwargs = {k: v for k, v in kwargs.items() if k in processor_params} - - # Process with appropriate processor - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **filtered_kwargs, - ) - - @maybe_allow_in_graph class SD3Attention(nn.Module, AttentionModuleMixin): """ Attention implementation specialized for SD3 models. - + This module implements the joint attention mechanism used in SD3, with native support for context pre-processing. - + Args: query_dim (`int`): Number of channels in query. cross_attention_dim (`int`, *optional*): Number of channels in encoder states. @@ -277,10 +161,11 @@ class SD3Attention(nn.Module, AttentionModuleMixin): bias (`bool`, defaults to False): Whether to use bias in linear projections. added_kv_proj_dim (`int`, *optional*): Dimension for added key/value projections. """ + # Set SD3-specific processor classes default_processor_class = JointAttnProcessorSDPA fused_processor_class = FusedJointAttnProcessorSDPA - + def __init__( self, query_dim: int, @@ -293,25 +178,25 @@ class SD3Attention(nn.Module, AttentionModuleMixin): context_pre_only: bool = False, ): super().__init__() - + # Core parameters self.inner_dim = dim_head * heads self.query_dim = query_dim self.heads = heads - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.use_bias = bias self.scale_qk = True self.context_pre_only = context_pre_only - + # Cross-attention setup self.is_cross_attention = cross_attention_dim is not None self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - + # Projections for self-attention self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) - + # Added projections for context processing self.added_kv_proj_dim = added_kv_proj_dim if added_kv_proj_dim is not None: @@ -319,23 +204,20 @@ class SD3Attention(nn.Module, AttentionModuleMixin): self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) self.added_proj_bias = bias - + # Output projection - self.to_out = nn.ModuleList([ - nn.Linear(self.inner_dim, query_dim, bias=bias), - nn.Dropout(dropout) - ]) - + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, query_dim, bias=bias), nn.Dropout(dropout)]) + # Context output projection if added_kv_proj_dim is not None and not context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias) else: self.to_add_out = None - + # Set default processor and fusion state self.fused_projections = False self.set_processor(self.default_processor_class()) - + def forward( self, hidden_states: torch.Tensor, @@ -347,20 +229,17 @@ class SD3Attention(nn.Module, AttentionModuleMixin): # Filter parameters to only those expected by the processor processor_params = inspect.signature(self.processor.__call__).parameters.keys() quiet_params = {"ip_adapter_masks", "ip_hidden_states"} - + # Check for unexpected parameters - unexpected_params = [ - k for k, _ in kwargs.items() - if k not in processor_params and k not in quiet_params - ] + unexpected_params = [k for k, _ in kwargs.items() if k not in 
processor_params and k not in quiet_params] if unexpected_params: logger.warning( f"Parameters {unexpected_params} are not expected by {self.processor.__class__.__name__} and will be ignored." ) - + # Filter to only expected parameters filtered_kwargs = {k: v for k, v in kwargs.items() if k in processor_params} - + # Process with appropriate processor return self.processor( self, @@ -368,4 +247,5 @@ class SD3Attention(nn.Module, AttentionModuleMixin): encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **filtered_kwargs, - ) \ No newline at end of file + ) + diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index acf315aac8..4b02a3f924 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -53,10 +53,11 @@ class AttentionModuleMixin: This mixin adds functionality to set different attention processors, handle attention masks, compute attention scores, and manage projections. """ - + # Default processor classes to be overridden by subclasses - default_processor_class = None - fused_processor_class = None + default_processor_cls = None + fused_processor_cls = None + _available_processors = None def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: """ @@ -115,7 +116,7 @@ class AttentionModuleMixin: else AttnProcessor() ) self.set_processor(processor) - + @torch.no_grad() def fuse_projections(self, fuse=True): """ @@ -178,7 +179,7 @@ class AttentionModuleMixin: self.to_added_qkv.bias.copy_(concatenated_bias) self.fused_projections = fuse - + # Update processor based on fusion state processor_class = self.fused_processor_class if fuse else self.default_processor_class if processor_class is not None: @@ -2163,554 +2164,6 @@ class FusedAuraFlowAttnProcessorSDPA: return hidden_states -class FluxAttnProcessorSDPA: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("FluxAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
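            # Note on the joint-attention path that follows: the text stream
            # (`encoder_hidden_states`) gets its own Q/K/V projections, reshaped to
            # (batch, heads, text_len, head_dim); its query and key are RMS-normalized
            # (norm_added_q / norm_added_k) and the three tensors are concatenated in front of
            # the image stream, so a single scaled_dot_product_attention call attends over the
            # combined sequence. The result is split back at encoder_hidden_states.shape[1].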
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FluxAttnProcessorSDPA_NPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessorSDPA_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU" - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if query.dtype in (torch.float16, torch.bfloat16): - hidden_states = torch_npu.npu_fusion_attention( - query, - key, - value, - attn.heads, - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FusedFluxAttnProcessorSDPA: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FusedFluxAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. 
- qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - # `context` projections. - if encoder_hidden_states is not None: - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - split_size = encoder_qkv.shape[-1] // 3 - ( - encoder_hidden_states_query_proj, - encoder_hidden_states_key_proj, - encoder_hidden_states_value_proj, - ) = torch.split(encoder_qkv, split_size, dim=-1) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FusedFluxAttnProcessorSDPA_NPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessorSDPA_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU" - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. 
- qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - # `context` projections. - if encoder_hidden_states is not None: - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - split_size = encoder_qkv.shape[-1] // 3 - ( - encoder_hidden_states_query_proj, - encoder_hidden_states_key_proj, - encoder_hidden_states_value_proj, - ) = torch.split(encoder_qkv, split_size, dim=-1) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if query.dtype in (torch.float16, torch.bfloat16): - hidden_states = torch_npu.npu_fusion_attention( - query, - key, - value, - attn.heads, - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FluxIPAdapterJointAttnProcessorSDPA(torch.nn.Module): - """Flux Attention processor for IP-Adapter.""" - - def __init__( - self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None - ): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
- ) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - - if not isinstance(num_tokens, (tuple, list)): - num_tokens = [num_tokens] - - if not isinstance(scale, list): - scale = [scale] * len(num_tokens) - if len(scale) != len(num_tokens): - raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") - self.scale = scale - - self.to_k_ip = nn.ModuleList( - [ - nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) - for _ in range(len(num_tokens)) - ] - ) - self.to_v_ip = nn.ModuleList( - [ - nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) - for _ in range(len(num_tokens)) - ] - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ip_hidden_states: Optional[List[torch.Tensor]] = None, - ip_adapter_masks: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - hidden_states_query_proj = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - hidden_states_query_proj = attn.norm_q(hidden_states_query_proj) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - # IP-adapter - ip_query = hidden_states_query_proj - ip_attn_output = torch.zeros_like(hidden_states) - - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip - ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = F.scaled_dot_product_attention( - ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) - ip_attn_output += scale * current_ip_hidden_states - - return hidden_states, encoder_hidden_states, ip_attn_output - else: - return hidden_states - - class CogVideoXAttnProcessorSDPA: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. 
It applies a rotary embedding on diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index cfdafca63c..b28cea6421 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -47,11 +47,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name class FluxAttnProcessor: """Flux-specific attention processor that implements normalized attention with support for rotary embeddings.""" - + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("FluxAttnProcessor requires PyTorch 2.0, please upgrade PyTorch.") - + def __call__( self, attn, @@ -62,10 +62,10 @@ class FluxAttnProcessor: **kwargs, ) -> torch.FloatTensor: batch_size, seq_len, _ = hidden_states.shape - + # Project query from hidden states query = attn.to_q(hidden_states) - + # Handle cross-attention vs self-attention if encoder_hidden_states is None: key = attn.to_k(hidden_states) @@ -73,71 +73,76 @@ class FluxAttnProcessor: else: key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) - + # If we have added_kv_proj_dim, handle additional projections if hasattr(attn, "added_kv_proj_dim") and attn.added_kv_proj_dim is not None: encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) encoder_query = attn.add_q_proj(encoder_hidden_states) - + # Reshape inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads - + encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - + # Apply normalization if available if hasattr(attn, "norm_added_q") and attn.norm_added_q is not None: encoder_query = attn.norm_added_q(encoder_query) if hasattr(attn, "norm_added_k") and attn.norm_added_k is not None: encoder_key = attn.norm_added_k(encoder_key) - + # Reshape for multi-head attention inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads - + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - + # Apply normalization if available if hasattr(attn, "norm_q") and attn.norm_q is not None: query = attn.norm_q(query) if hasattr(attn, "norm_k") and attn.norm_k is not None: key = attn.norm_k(key) - + # Handle rotary embeddings if provided if image_rotary_emb is not None: from ...models.embeddings import apply_rotary_emb + query = apply_rotary_emb(query, image_rotary_emb) # Only apply to key in self-attention if encoder_hidden_states is None: key = apply_rotary_emb(key, image_rotary_emb) - + # Concatenate encoder projections if we have them - if encoder_hidden_states is not None and hasattr(attn, "added_kv_proj_dim") and attn.added_kv_proj_dim is not None: + if ( + encoder_hidden_states is not None + and hasattr(attn, "added_kv_proj_dim") + and attn.added_kv_proj_dim is not None + ): # Concatenate for joint attention query = torch.cat([encoder_query, query], dim=2) key = torch.cat([encoder_key, key], dim=2) value = torch.cat([encoder_value, value], dim=2) - + # Compute attention hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - + # Reshape back 
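        # Note: SDPA returns (batch, heads, seq_len, head_dim); the transpose + reshape below
        # flattens the heads back into (batch, seq_len, heads * head_dim). When the added
        # projections were used, the first `encoder_hidden_states.shape[1]` tokens of that
        # sequence belong to the text stream and are split off and routed through `to_add_out`,
        # while the remaining image tokens go through `to_out`.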
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - + # Split back if we did joint attention if ( - encoder_hidden_states is not None - and hasattr(attn, "added_kv_proj_dim") + encoder_hidden_states is not None + and hasattr(attn, "added_kv_proj_dim") and attn.added_kv_proj_dim is not None - and hasattr(attn, "to_add_out") + and hasattr(attn, "to_add_out") and attn.to_add_out is not None ): context_len = encoder_hidden_states.shape[1] @@ -145,18 +150,18 @@ class FluxAttnProcessor: hidden_states[:, :context_len], hidden_states[:, context_len:], ) - + # Project output hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - + return hidden_states, encoder_hidden_states else: # Project output hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) - + return hidden_states @@ -164,11 +169,11 @@ class FluxAttnProcessor: class FluxAttention(nn.Module, AttentionModuleMixin): """ Specialized attention implementation for Flux models. - + This attention module provides optimized implementation for Flux models, with support for RMSNorm, rotary embeddings, and added key/value projections. """ - + def __init__( self, query_dim: int, @@ -180,51 +185,48 @@ class FluxAttention(nn.Module, AttentionModuleMixin): added_kv_proj_dim: Optional[int] = None, ): super().__init__() - + # Core parameters self.inner_dim = dim_head * heads self.heads = heads - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.use_bias = bias self.scale_qk = True # Flux always uses scaled QK - + # Set cross-attention parameters self.is_cross_attention = cross_attention_dim is not None self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - + # Query, Key, Value projections self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) - + # RMSNorm for Flux models self.norm_q = RMSNorm(dim_head, eps=1e-6) self.norm_k = RMSNorm(dim_head, eps=1e-6) - + # Optional added key/value projections for joint attention self.added_kv_proj_dim = added_kv_proj_dim if added_kv_proj_dim is not None: self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) - + # Normalization for added projections self.norm_added_q = RMSNorm(dim_head, eps=1e-6) self.norm_added_k = RMSNorm(dim_head, eps=1e-6) self.added_proj_bias = bias - + # Output projection for context self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias) - + # Output projection and dropout - self.to_out = nn.ModuleList([ - nn.Linear(self.inner_dim, query_dim, bias=bias), - nn.Dropout(dropout) - ]) - + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, query_dim, bias=bias), nn.Dropout(dropout)]) + # Set processor self.processor = FluxAttnProcessor() - + def forward( self, hidden_states: torch.Tensor, @@ -235,13 +237,13 @@ class FluxAttention(nn.Module, AttentionModuleMixin): ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Forward pass for Flux attention. 
- + Args: hidden_states: Input hidden states encoder_hidden_states: Optional encoder hidden states for cross-attention attention_mask: Optional attention mask image_rotary_emb: Optional rotary embeddings for image tokens - + Returns: Output hidden states, and optionally encoder hidden states for joint attention """ @@ -303,6 +305,670 @@ class FluxSingleTransformerBlock(nn.Module): return hidden_states +class FluxAttnProcessorSDPA: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("FluxAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. 
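            # Note: after the text projections below are prepended to the image stream,
            # `apply_rotary_emb` is applied to the full concatenated query and key, so the
            # supplied `image_rotary_emb` is expected to cover the text and image positions
            # combined.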
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FluxAttnProcessorNPU: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FluxAttnProcessorSDPA_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU" + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. 
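            # Note: further down, this NPU variant routes float16/bfloat16 tensors through
            # torch_npu.npu_fusion_attention (input_layout="BNSD", scale = 1/sqrt(head_dim))
            # and falls back to F.scaled_dot_product_attention for other dtypes; the projection
            # and normalization logic around it matches FluxAttnProcessorSDPA.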
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if query.dtype in (torch.float16, torch.bfloat16): + hidden_states = torch_npu.npu_fusion_attention( + query, + key, + value, + attn.heads, + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + else: + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FusedFluxAttnProcessorSDPA: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedFluxAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. 
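# Aside: a minimal, self-contained sketch (plain PyTorch, arbitrary sizes) of the fused-projection
# equivalence the `to_qkv` split below relies on. `to_qkv` is created by
# AttentionModuleMixin.fuse_projections, which concatenates the to_q / to_k / to_v weights along
# the output dimension, so splitting the single matmul's output into three equal chunks recovers
# the separate query, key and value projections.
import torch
import torch.nn as nn

dim = 64
x = torch.randn(2, 16, dim)
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
to_qkv.weight.data.copy_(torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data], dim=0))
qkv = to_qkv(x)
q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
assert torch.allclose(q, to_q(x), atol=1e-5) and torch.allclose(v, to_v(x), atol=1e-5)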
+ qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + # `context` projections. + if encoder_hidden_states is not None: + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FusedFluxAttnProcessorNPU: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FluxAttnProcessorSDPA_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU" + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. 
+ qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + # `context` projections. + if encoder_hidden_states is not None: + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if query.dtype in (torch.float16, torch.bfloat16): + hidden_states = torch_npu.npu_fusion_attention( + query, + key, + value, + attn.heads, + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + else: + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FluxIPAdapterJointAttnProcessorSDPA(torch.nn.Module): + """Flux Attention processor for IP-Adapter.""" + + def __init__( + self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") + self.scale = scale + + self.to_k_ip = nn.ModuleList( + [ + nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) + for _ in range(len(num_tokens)) + ] + ) + self.to_v_ip = nn.ModuleList( + [ + nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) + for _ in range(len(num_tokens)) + ] + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ip_hidden_states: Optional[List[torch.Tensor]] = None, + ip_adapter_masks: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + hidden_states_query_proj = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + hidden_states_query_proj = attn.norm_q(hidden_states_query_proj) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. 
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # IP-adapter + ip_query = hidden_states_query_proj + ip_attn_output = torch.zeros_like(hidden_states) + + for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip + ): + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + current_ip_hidden_states = F.scaled_dot_product_attention( + ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) + ip_attn_output += scale * current_ip_hidden_states + + return hidden_states, encoder_hidden_states, ip_attn_output + else: + return hidden_states + + +@maybe_allow_in_graph +class FluxAttention(nn.Module, AttentionModuleMixin): + """ + + Args: + query_dim (`int`): Number of channels in query. + cross_attention_dim (`int`, *optional*): Number of channels in encoder states. + heads (`int`, defaults to 8): Number of attention heads. + dim_head (`int`, defaults to 64): Dimension of each attention head. + dropout (`float`, defaults to 0.0): Dropout probability. + bias (`bool`, defaults to False): Whether to use bias in linear projections. 
+ added_kv_proj_dim (`int`, *optional*): Dimension for added key/value projections. + """ + + # Set Flux-specific processor classes + default_processor_cls = FluxAttnProcessorSDPA + fused_processor_cls = FusedFluxAttnProcessorSDPA + + _available_processors = [ + FluxAttnProcessorSDPA, + FusedFluxAttnProcessorSDPA, + FluxAttnProcessorNPU, + FusedFluxAttnProcessorNPU, + FluxIPAdapterJointAttnProcessorSDPA, + ] + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + ): + super().__init__() + + # Core parameters + self.inner_dim = dim_head * heads + self.query_dim = query_dim + self.heads = heads + self.scale = dim_head**-0.5 + self.use_bias = bias + self.scale_qk = True # Flux always uses scale_qk + + # Cross-attention setup + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + # Projections + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + + # Flux-specific normalization + self.norm_q = RMSNorm(dim_head, eps=1e-6) + self.norm_k = RMSNorm(dim_head, eps=1e-6) + + # Added projections for cross-attention + self.added_kv_proj_dim = added_kv_proj_dim + if added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias) + + # Normalization for added projections + self.norm_added_q = RMSNorm(dim_head, eps=1e-6) + self.norm_added_k = RMSNorm(dim_head, eps=1e-6) + self.added_proj_bias = bias + + # Output projection + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, query_dim, bias=bias), nn.Dropout(dropout)]) + + # For cross-attention with added projections + if added_kv_proj_dim is not None: + self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias) + else: + self.to_add_out = None + + # Set default processor and fusion state + self.fused_projections = False + self.set_processor(self.default_processor_class()) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Process attention for Flux model inputs.""" + # Filter parameters to only those expected by the processor + processor_params = inspect.signature(self.processor.__call__).parameters.keys() + quiet_params = {"ip_adapter_masks", "ip_hidden_states"} + + # Check for unexpected parameters + unexpected_params = [k for k, _ in kwargs.items() if k not in processor_params and k not in quiet_params] + if unexpected_params: + logger.warning( + f"Parameters {unexpected_params} are not expected by {self.processor.__class__.__name__} and will be ignored." 
+ ) + + # Filter to only expected parameters + filtered_kwargs = {k: v for k, v in kwargs.items() if k in processor_params} + + # Process with appropriate processor + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **filtered_kwargs, + ) + + @maybe_allow_in_graph class FluxTransformerBlock(nn.Module): def __init__(