mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-04-15 21:32:36 +05:30
parent 1b4067c0d1
commit 37de8e790c
7 changed files with 1053 additions and 86 deletions

View File

@@ -27,6 +27,7 @@ _import_structure = {}
if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["auto_model"] = ["AutoModel"]
_import_structure["attention_modules"] = ["FluxAttention", "SanaAttention", "SD3Attention"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
@@ -107,6 +108,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
    from .attention_modules import FluxAttention, SanaAttention, SD3Attention
    from .auto_model import AutoModel
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderDC,

View File

@@ -0,0 +1,371 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from ..utils import logging
from ..utils.torch_utils import maybe_allow_in_graph
from .attention_processor import (
AttentionModuleMixin,
AttnProcessorSDPA,
FluxAttnProcessorSDPA,
FusedFluxAttnProcessorSDPA,
    FusedJointAttnProcessorSDPA,
    JointAttnProcessorSDPA,
SanaLinearAttnProcessorSDPA,
)
from .normalization import RMSNorm, get_normalization
logger = logging.get_logger(__name__)
@maybe_allow_in_graph
class SanaAttention(nn.Module, AttentionModuleMixin):
"""
Attention implementation specialized for Sana models.
This module implements lightweight multi-scale linear attention as used in Sana.
Args:
in_channels (`int`): Number of input channels.
out_channels (`int`): Number of output channels.
num_attention_heads (`int`, *optional*): Number of attention heads.
attention_head_dim (`int`, defaults to 8): Dimension of each attention head.
mult (`float`, defaults to 1.0): Multiplier for inner dimension.
norm_type (`str`, defaults to "batch_norm"): Type of normalization.
kernel_sizes (`Tuple[int, ...]`, defaults to (5,)): Kernel sizes for multi-scale attention.
"""
# Set Sana-specific processor classes
default_processor_class = SanaLinearAttnProcessorSDPA
fused_processor_class = None # Sana doesn't have a fused processor yet
def __init__(
self,
in_channels: int,
out_channels: int,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 8,
mult: float = 1.0,
norm_type: str = "batch_norm",
kernel_sizes: Tuple[int, ...] = (5,),
eps: float = 1e-15,
residual_connection: bool = False,
):
super().__init__()
# Core parameters
self.eps = eps
self.attention_head_dim = attention_head_dim
self.norm_type = norm_type
self.residual_connection = residual_connection
# Calculate dimensions
num_attention_heads = (
int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
)
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
self.heads = num_attention_heads
# Query, key, value projections
self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
# Multi-scale attention
self.to_qkv_multiscale = nn.ModuleList()
for kernel_size in kernel_sizes:
self.to_qkv_multiscale.append(
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
)
# Output layers
self.nonlinearity = nn.ReLU()
self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
self.norm_out = get_normalization(norm_type, num_features=out_channels)
# Set default processor
self.fused_projections = False
self.set_processor(self.default_processor_class())
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""Process linear attention for Sana model inputs."""
return self.processor(self, hidden_states)
class SanaMultiscaleAttentionProjection(nn.Module):
"""Projection layer for Sana multi-scale attention."""
def __init__(
self,
in_channels: int,
num_attention_heads: int,
kernel_size: int,
) -> None:
super().__init__()
channels = 3 * in_channels
self.proj_in = nn.Conv2d(
channels,
channels,
kernel_size,
padding=kernel_size // 2,
groups=channels,
bias=False,
)
self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_out(hidden_states)
return hidden_states
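A hedged shape sketch, not part of the commit: the projection above is a depthwise conv over the stacked q/k/v channels followed by a grouped 1x1 conv, so the input shape is preserved. The dimensions below are illustrative.
import torch

inner_dim, heads, kernel_size = 64, 8, 5
proj = SanaMultiscaleAttentionProjection(inner_dim, heads, kernel_size)
qkv = torch.randn(2, 3 * inner_dim, 32, 32)  # (batch, 3 * inner_dim, height, width)
assert proj(qkv).shape == qkv.shape  # padding=kernel_size // 2 keeps spatial dims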
@maybe_allow_in_graph
class FluxAttention(nn.Module, AttentionModuleMixin):
"""
Attention implementation specialized for Flux models.
This module uses RMSNorm for query and key normalization and supports
rotary embeddings through its processor.
Args:
query_dim (`int`): Number of channels in query.
cross_attention_dim (`int`, *optional*): Number of channels in encoder states.
heads (`int`, defaults to 8): Number of attention heads.
dim_head (`int`, defaults to 64): Dimension of each attention head.
dropout (`float`, defaults to 0.0): Dropout probability.
bias (`bool`, defaults to False): Whether to use bias in linear projections.
added_kv_proj_dim (`int`, *optional*): Dimension for added key/value projections.
"""
# Set Flux-specific processor classes
default_processor_class = FluxAttnProcessorSDPA
fused_processor_class = FusedFluxAttnProcessorSDPA
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
):
super().__init__()
# Core parameters
self.inner_dim = dim_head * heads
self.query_dim = query_dim
self.heads = heads
self.scale = dim_head ** -0.5
self.use_bias = bias
self.scale_qk = True # Flux always uses scale_qk
# Cross-attention setup
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
# Projections
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
# Flux-specific normalization
self.norm_q = RMSNorm(dim_head, eps=1e-6)
self.norm_k = RMSNorm(dim_head, eps=1e-6)
# Added projections for cross-attention
self.added_kv_proj_dim = added_kv_proj_dim
if added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
# Normalization for added projections
self.norm_added_q = RMSNorm(dim_head, eps=1e-6)
self.norm_added_k = RMSNorm(dim_head, eps=1e-6)
self.added_proj_bias = bias
# Output projection
self.to_out = nn.ModuleList([
nn.Linear(self.inner_dim, query_dim, bias=bias),
nn.Dropout(dropout)
])
# For cross-attention with added projections
if added_kv_proj_dim is not None:
self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias)
else:
self.to_add_out = None
# Set default processor and fusion state
self.fused_projections = False
self.set_processor(self.default_processor_class())
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Process attention for Flux model inputs."""
# Filter parameters to only those expected by the processor
processor_params = inspect.signature(self.processor.__call__).parameters.keys()
quiet_params = {"ip_adapter_masks", "ip_hidden_states"}
# Check for unexpected parameters
        unexpected_params = [k for k in kwargs if k not in processor_params and k not in quiet_params]
if unexpected_params:
logger.warning(
f"Parameters {unexpected_params} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
# Filter to only expected parameters
filtered_kwargs = {k: v for k, v in kwargs.items() if k in processor_params}
# Process with appropriate processor
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**filtered_kwargs,
)
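A standalone sketch of the kwarg-filtering pattern used in the forward above; the helper and processor names here are illustrative, not part of the commit. Only keyword arguments accepted by the callable's signature survive.
import inspect

def call_with_supported_kwargs(fn, *args, **kwargs):
    # Keep only the kwargs that fn's signature actually declares
    supported = inspect.signature(fn).parameters.keys()
    filtered = {k: v for k, v in kwargs.items() if k in supported}
    return fn(*args, **filtered)

def processor(hidden_states, attention_mask=None):
    return hidden_states

out = call_with_supported_kwargs(processor, [1, 2, 3], unknown_kwarg=True)  # unknown_kwarg dropped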
@maybe_allow_in_graph
class SD3Attention(nn.Module, AttentionModuleMixin):
"""
Attention implementation specialized for SD3 models.
This module implements the joint attention mechanism used in SD3,
with native support for context pre-processing.
Args:
query_dim (`int`): Number of channels in query.
cross_attention_dim (`int`, *optional*): Number of channels in encoder states.
heads (`int`, defaults to 8): Number of attention heads.
dim_head (`int`, defaults to 64): Dimension of each attention head.
dropout (`float`, defaults to 0.0): Dropout probability.
bias (`bool`, defaults to False): Whether to use bias in linear projections.
added_kv_proj_dim (`int`, *optional*): Dimension for added key/value projections.
"""
# Set SD3-specific processor classes
default_processor_class = JointAttnProcessorSDPA
fused_processor_class = FusedJointAttnProcessorSDPA
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
context_pre_only: bool = False,
):
super().__init__()
# Core parameters
self.inner_dim = dim_head * heads
self.query_dim = query_dim
self.heads = heads
self.scale = dim_head ** -0.5
self.use_bias = bias
self.scale_qk = True
self.context_pre_only = context_pre_only
# Cross-attention setup
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
# Projections for self-attention
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
# Added projections for context processing
self.added_kv_proj_dim = added_kv_proj_dim
if added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
self.added_proj_bias = bias
# Output projection
self.to_out = nn.ModuleList([
nn.Linear(self.inner_dim, query_dim, bias=bias),
nn.Dropout(dropout)
])
# Context output projection
if added_kv_proj_dim is not None and not context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias)
else:
self.to_add_out = None
# Set default processor and fusion state
self.fused_projections = False
self.set_processor(self.default_processor_class())
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Process joint attention for SD3 model inputs."""
# Filter parameters to only those expected by the processor
processor_params = inspect.signature(self.processor.__call__).parameters.keys()
quiet_params = {"ip_adapter_masks", "ip_hidden_states"}
# Check for unexpected parameters
        unexpected_params = [k for k in kwargs if k not in processor_params and k not in quiet_params]
if unexpected_params:
logger.warning(
f"Parameters {unexpected_params} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
# Filter to only expected parameters
filtered_kwargs = {k: v for k, v in kwargs.items() if k in processor_params}
# Process with appropriate processor
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**filtered_kwargs,
)
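Hedged usage sketch, not part of the commit: instantiating one of the new specialized modules and running a dummy forward pass. The import path assumes the attention_modules registration added in the first hunk; shapes are illustrative.
import torch
from diffusers.models import FluxAttention

attn = FluxAttention(query_dim=128, heads=4, dim_head=32)
hidden_states = torch.randn(1, 16, 128)
out = attn(hidden_states)  # dispatched to the default FluxAttnProcessorSDPA
print(out.shape)  # expected: torch.Size([1, 16, 128])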

View File

@@ -53,6 +53,10 @@ class AttentionModuleMixin:
This mixin adds functionality to set different attention processors, handle attention masks, compute attention
scores, and manage projections.
"""
# Default processor classes to be overridden by subclasses
default_processor_class = None
fused_processor_class = None
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
"""
@@ -111,6 +115,74 @@ class AttentionModuleMixin:
else AttnProcessor()
)
self.set_processor(processor)
@torch.no_grad()
def fuse_projections(self, fuse=True):
"""
Fuse the query, key, and value projections into a single projection for efficiency.
Args:
fuse (`bool`): Whether to fuse the projections or not.
"""
# Skip if already in desired state
if getattr(self, "fused_projections", False) == fuse:
return
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if not self.is_cross_attention:
# Fuse self-attention projections
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
self.to_qkv.bias.copy_(concatenated_bias)
else:
# Fuse cross-attention key-value projections
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
self.to_kv.bias.copy_(concatenated_bias)
# Handle added projections for models like SD3, Flux, etc.
if (
getattr(self, "add_q_proj", None) is not None
and getattr(self, "add_k_proj", None) is not None
and getattr(self, "add_v_proj", None) is not None
):
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_added_qkv = nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
self.to_added_qkv.weight.copy_(concatenated_weights)
if self.added_proj_bias:
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
)
self.to_added_qkv.bias.copy_(concatenated_bias)
self.fused_projections = fuse
# Update processor based on fusion state
processor_class = self.fused_processor_class if fuse else self.default_processor_class
if processor_class is not None:
self.set_processor(processor_class())
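An editorial aside, not part of the commit: a minimal pure-torch sketch of why concatenating the weight matrices fuses the projections. A single Linear built from cat([Wq, Wk, Wv]) reproduces the three separate projections exactly.
import torch
from torch import nn

in_features, out_features = 8, 8
to_q, to_k, to_v = (nn.Linear(in_features, out_features, bias=False) for _ in range(3))
to_qkv = nn.Linear(in_features, 3 * out_features, bias=False)
with torch.no_grad():
    # nn.Linear stores weight as (out_features, in_features), so cat along dim 0
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
x = torch.randn(2, in_features)
assert torch.allclose(to_qkv(x), torch.cat([to_q(x), to_k(x), to_v(x)], dim=-1), atol=1e-6)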
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
@@ -480,68 +552,12 @@ class AttentionModuleMixin:
return encoder_hidden_states
@maybe_allow_in_graph
class Attention(nn.Module, AttentionModuleMixin):
# Set default and fused processor classes
default_processor_class = AttnProcessorSDPA
fused_processor_class = None # Will be set appropriately in the future
r"""
A cross attention layer.

View File

@@ -24,6 +24,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
from ..attention_processor import (
Attention,
AttentionProcessor,
AttentionModuleMixin,
SanaLinearAttnProcessor2_0,
)
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
@@ -35,6 +36,104 @@ from ..normalization import AdaLayerNormSingle, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class SanaAttention(nn.Module, AttentionModuleMixin):
"""
Attention implementation specialized for Sana models.
This module implements lightweight multi-scale linear attention as used in Sana.
"""
# Set Sana-specific processor classes
default_processor_class = SanaLinearAttnProcessor2_0
def __init__(
self,
in_channels: int,
out_channels: int,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 8,
mult: float = 1.0,
norm_type: str = "batch_norm",
kernel_sizes: Tuple[int, ...] = (5,),
eps: float = 1e-15,
residual_connection: bool = False,
):
super().__init__()
# Core parameters
self.eps = eps
self.attention_head_dim = attention_head_dim
self.norm_type = norm_type
self.residual_connection = residual_connection
# Calculate dimensions
num_attention_heads = (
int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
)
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
self.heads = num_attention_heads
# Query, key, value projections
self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
# Multi-scale attention
self.to_qkv_multiscale = nn.ModuleList()
for kernel_size in kernel_sizes:
self.to_qkv_multiscale.append(
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
)
# Output layers
self.nonlinearity = nn.ReLU()
self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
# Get normalization based on type
if norm_type == "batch_norm":
self.norm_out = nn.BatchNorm1d(out_channels)
elif norm_type == "layer_norm":
self.norm_out = nn.LayerNorm(out_channels)
elif norm_type == "group_norm":
self.norm_out = nn.GroupNorm(32, out_channels)
elif norm_type == "instance_norm":
self.norm_out = nn.InstanceNorm1d(out_channels)
else:
self.norm_out = nn.Identity()
# Set processor
self.processor = self.default_processor_class()
class SanaMultiscaleAttentionProjection(nn.Module):
"""Projection layer for Sana multi-scale attention."""
def __init__(
self,
in_channels: int,
num_attention_heads: int,
kernel_size: int,
) -> None:
super().__init__()
channels = 3 * in_channels
self.proj_in = nn.Conv2d(
channels,
channels,
kernel_size,
padding=kernel_size // 2,
groups=channels,
bias=False,
)
self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_out(hidden_states)
return hidden_states
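Hedged sketch, not part of the commit, of the ReLU linear-attention identity behind Sana's lightweight attention: with ReLU feature maps in place of softmax, K^T V is built once, so the cost is linear in sequence length rather than quadratic.
import torch
import torch.nn.functional as F

seq_len, head_dim = 1024, 32
q, k, v = (torch.randn(seq_len, head_dim) for _ in range(3))
q, k = F.relu(q), F.relu(k)
kv = k.transpose(0, 1) @ v                                    # (head_dim, head_dim), built once
numerator = q @ kv                                            # (seq_len, head_dim): linear in seq_len
denominator = q @ k.sum(dim=0, keepdim=True).transpose(0, 1)  # (seq_len, 1) normalizer
out = numerator / (denominator + 1e-15)                       # eps matches the module default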
class GLUMBConv(nn.Module):
def __init__(
self,

View File

@@ -18,6 +18,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
@@ -25,12 +26,14 @@ from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
AttentionProcessor,
AttentionModuleMixin,
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
@@ -42,6 +45,216 @@ from ..modeling_outputs import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FluxAttnProcessor:
"""Flux-specific attention processor that implements normalized attention with support for rotary embeddings."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("FluxAttnProcessor requires PyTorch 2.0, please upgrade PyTorch.")
def __call__(
self,
attn,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.FloatTensor:
batch_size, seq_len, _ = hidden_states.shape
# Project query from hidden states
query = attn.to_q(hidden_states)
        # Key/value projections. With added projections (joint attention), Flux
        # always projects key/value from hidden_states; the text tokens enter
        # through add_k_proj/add_v_proj below. Plain cross-attention falls back
        # to projecting from the encoder states.
        if encoder_hidden_states is None or getattr(attn, "added_kv_proj_dim", None) is not None:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)
        # With encoder states and added projections, project the text stream too
        if encoder_hidden_states is not None and getattr(attn, "added_kv_proj_dim", None) is not None:
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = attn.add_q_proj(encoder_hidden_states)
# Reshape
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Apply normalization if available
if hasattr(attn, "norm_added_q") and attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if hasattr(attn, "norm_added_k") and attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
# Reshape for multi-head attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Apply normalization if available
if hasattr(attn, "norm_q") and attn.norm_q is not None:
query = attn.norm_q(query)
if hasattr(attn, "norm_k") and attn.norm_k is not None:
key = attn.norm_k(key)
        # Concatenate encoder projections if we have them (joint attention)
        if encoder_hidden_states is not None and getattr(attn, "added_kv_proj_dim", None) is not None:
            query = torch.cat([encoder_query, query], dim=2)
            key = torch.cat([encoder_key, key], dim=2)
            value = torch.cat([encoder_value, value], dim=2)
        # Apply rotary embeddings over the full (text + image) sequence
        if image_rotary_emb is not None:
            from ...models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)
# Compute attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# Reshape back
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Split back if we did joint attention
if (
encoder_hidden_states is not None
and hasattr(attn, "added_kv_proj_dim")
and attn.added_kv_proj_dim is not None
and hasattr(attn, "to_add_out")
and attn.to_add_out is not None
):
context_len = encoder_hidden_states.shape[1]
encoder_hidden_states, hidden_states = (
hidden_states[:, :context_len],
hidden_states[:, context_len:],
)
# Project output
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
# Project output
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
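Hedged pure-torch sketch, not part of the commit, of the joint-attention token layout used above: encoder (text) projections are prepended along the sequence axis, attention runs over the combined sequence, and the result is split back at context_len. Shapes are illustrative.
import torch

batch, heads, text_len, image_len, head_dim = 1, 2, 4, 8, 16
text_q = torch.randn(batch, heads, text_len, head_dim)
image_q = torch.randn(batch, heads, image_len, head_dim)
joint_q = torch.cat([text_q, image_q], dim=2)  # text tokens first
# ...after attention and the reshape to (batch, seq, heads * head_dim):
joint_out = torch.randn(batch, text_len + image_len, heads * head_dim)
text_out, image_out = joint_out[:, :text_len], joint_out[:, text_len:]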
@maybe_allow_in_graph
class FluxAttention(nn.Module, AttentionModuleMixin):
"""
Specialized attention implementation for Flux models.
    This attention module provides an optimized implementation for Flux models,
    with support for RMSNorm, rotary embeddings, and added key/value projections.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
):
super().__init__()
# Core parameters
self.inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.use_bias = bias
self.scale_qk = True # Flux always uses scaled QK
# Set cross-attention parameters
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
# Query, Key, Value projections
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
# RMSNorm for Flux models
self.norm_q = RMSNorm(dim_head, eps=1e-6)
self.norm_k = RMSNorm(dim_head, eps=1e-6)
# Optional added key/value projections for joint attention
self.added_kv_proj_dim = added_kv_proj_dim
if added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
# Normalization for added projections
self.norm_added_q = RMSNorm(dim_head, eps=1e-6)
self.norm_added_k = RMSNorm(dim_head, eps=1e-6)
self.added_proj_bias = bias
# Output projection for context
self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias)
# Output projection and dropout
self.to_out = nn.ModuleList([
nn.Linear(self.inner_dim, query_dim, bias=bias),
nn.Dropout(dropout)
])
# Set processor
self.processor = FluxAttnProcessor()
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Forward pass for Flux attention.
Args:
hidden_states: Input hidden states
encoder_hidden_states: Optional encoder hidden states for cross-attention
attention_mask: Optional attention mask
image_rotary_emb: Optional rotary embeddings for image tokens
Returns:
Output hidden states, and optionally encoder hidden states for joint attention
"""
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
**kwargs,
)
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
@@ -53,27 +266,14 @@ class FluxSingleTransformerBlock(nn.Module):
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
        # Use specialized FluxAttention instead of generic Attention
        self.attn = FluxAttention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            dropout=0.0,
            bias=True,
        )
def forward(
@@ -113,18 +313,15 @@ class FluxTransformerBlock(nn.Module):
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
        # Use specialized FluxAttention instead of generic Attention
        self.attn = FluxAttention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            dropout=0.0,
            bias=True,
        )
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

View File

@@ -24,7 +24,7 @@ from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import AttentionModuleMixin, MochiAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -35,6 +35,85 @@ from ..normalization import AdaLayerNormContinuous, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class MochiAttention(nn.Module, AttentionModuleMixin):
"""
Specialized attention module for Mochi video models.
Features RMSNorm normalization and rotary position embeddings.
"""
# Set Mochi-specific processor classes
default_processor_class = MochiAttnProcessor2_0
def __init__(
self,
query_dim: int,
added_kv_proj_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_proj_bias: bool = True,
out_dim: Optional[int] = None,
out_context_dim: Optional[int] = None,
context_pre_only: bool = False,
eps: float = 1e-5,
):
super().__init__()
# Import here to avoid circular imports
from ..normalization import MochiRMSNorm
# Core parameters
self.inner_dim = dim_head * heads
self.query_dim = query_dim
self.heads = heads
self.scale = dim_head ** -0.5
self.use_bias = bias
self.scale_qk = True # Always use scaled attention
self.context_pre_only = context_pre_only
self.eps = eps
# Set output dimensions
self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
# Self-attention projections
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
# Normalization for queries and keys
self.norm_q = MochiRMSNorm(dim_head, eps, True)
self.norm_k = MochiRMSNorm(dim_head, eps, True)
# Added key/value projections for joint processing
self.added_kv_proj_dim = added_kv_proj_dim
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
# Normalization for added projections
self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
self.added_proj_bias = added_proj_bias
# Output projections
self.to_out = nn.ModuleList([
nn.Linear(self.inner_dim, self.out_dim, bias=bias),
nn.Dropout(dropout)
])
# Context output projection
if not context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=added_proj_bias)
else:
self.to_add_out = None
# Initialize attention processor using the default class
self.processor = self.default_processor_class()
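Hedged sketch, not part of the commit: constructing the specialized MochiAttention. added_kv_proj_dim is required because Mochi always runs joint text/video attention; the dimensions below are illustrative, not canonical Mochi config values.
attn = MochiAttention(
    query_dim=3072,
    added_kv_proj_dim=1536,
    heads=24,
    dim_head=128,
    context_pre_only=False,
)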
class MochiModulatedRMSNorm(nn.Module):
def __init__(self, eps: float):
super().__init__()

View File

@@ -22,6 +22,7 @@ from ...models.attention import FeedForward, JointTransformerBlock
from ...models.attention_processor import (
Attention,
AttentionProcessor,
AttentionModuleMixin,
FusedJointAttnProcessor2_0,
JointAttnProcessor2_0,
)
@@ -36,6 +37,208 @@ from ..modeling_outputs import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class JointAttnProcessor:
"""Attention processor used for processing joint attention."""
def __init__(self):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
raise ImportError("JointAttnProcessor requires PyTorch 2.0, please upgrade PyTorch.")
def __call__(
self,
attn,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> torch.FloatTensor:
batch_size, sequence_length, _ = hidden_states.shape
# Project query from hidden states
query = attn.to_q(hidden_states)
        # Key/value projections. For joint attention (added projections below),
        # to_k/to_v always operate on hidden_states; the encoder context is
        # projected separately through add_k_proj/add_v_proj.
        if encoder_hidden_states is None or getattr(attn, "added_kv_proj_dim", None) is not None:
            # Self/joint attention: use hidden_states for key and value
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
        else:
            # Plain cross-attention: use encoder_hidden_states for key and value
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)
        # Handle additional context for joint attention
        if encoder_hidden_states is not None and getattr(attn, "added_kv_proj_dim", None) is not None:
context_key = attn.add_k_proj(encoder_hidden_states)
context_value = attn.add_v_proj(encoder_hidden_states)
context_query = attn.add_q_proj(encoder_hidden_states)
# Joint query, key, value with context
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
# Reshape for multi-head attention
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
context_query = context_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
context_key = context_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
context_value = context_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Concatenate for joint attention
query = torch.cat([context_query, query], dim=2)
key = torch.cat([context_key, key], dim=2)
value = torch.cat([context_value, value], dim=2)
# Apply joint attention
hidden_states = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# Reshape back to original dimensions
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# Split context and hidden states
context_len = encoder_hidden_states.shape[1]
encoder_hidden_states, hidden_states = (
hidden_states[:, :context_len],
hidden_states[:, context_len:],
)
# Apply output projections
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if not attn.context_pre_only and hasattr(attn, "to_add_out") and attn.to_add_out is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
return hidden_states
# Handle standard attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
# Reshape for multi-head attention
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Apply attention
hidden_states = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# Reshape output
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# Apply output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
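A pure-torch sketch, not part of the commit, of the multi-head reshape used throughout these processors: (batch, seq, inner_dim) becomes (batch, heads, seq, head_dim) for scaled_dot_product_attention, and the inverse transpose/reshape restores the original layout exactly.
import torch

batch, seq, heads, head_dim = 2, 10, 4, 16
x = torch.randn(batch, seq, heads * head_dim)
h = x.view(batch, -1, heads, head_dim).transpose(1, 2)      # (2, 4, 10, 16)
y = h.transpose(1, 2).reshape(batch, -1, heads * head_dim)  # (2, 10, 64)
assert torch.equal(x, y)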
@maybe_allow_in_graph
class SD3Attention(nn.Module, AttentionModuleMixin):
"""
Specialized attention implementation for SD3 models.
Features joint attention mechanisms and custom handling of
context projections.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
out_dim: Optional[int] = None,
context_pre_only: bool = False,
eps: float = 1e-6,
):
super().__init__()
# Core parameters
self.inner_dim = dim_head * heads
self.query_dim = query_dim
self.heads = heads
self.scale = dim_head ** -0.5
self.scale_qk = True # SD3 always scales query-key dot products
self.use_bias = bias
self.context_pre_only = context_pre_only
self.eps = eps
# Set output dimension
out_dim = out_dim if out_dim is not None else query_dim
# Set cross-attention parameters
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
# Linear projections for self-attention
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
# Optional added key/value projections for joint attention
self.added_kv_proj_dim = added_kv_proj_dim
if added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
self.added_proj_bias = bias
# Output projection for context
if not context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, out_dim, bias=bias)
else:
self.to_add_out = None
# Output projection and dropout
self.to_out = nn.ModuleList([
nn.Linear(self.inner_dim, out_dim, bias=bias),
nn.Dropout(dropout)
])
# Set processor
self.processor = JointAttnProcessor()
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Forward pass for SD3 attention.
Args:
hidden_states: Input hidden states
encoder_hidden_states: Optional encoder hidden states for cross/joint attention
attention_mask: Optional attention mask
position_ids: Optional position IDs
Returns:
Output hidden states, and optionally encoder hidden states for joint attention
"""
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
**kwargs,
)
@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
def __init__(
@@ -47,13 +250,13 @@ class SD3SingleTransformerBlock(nn.Module):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
        # Use specialized SD3Attention instead of generic Attention
        self.attn = SD3Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            eps=1e-6,
        )