mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
commit 200e4ac462 (parent 94ae28edea)
Author: DN6
Date: 2025-04-29 22:57:59 +05:30
7 changed files with 298 additions and 768 deletions

View File

@@ -107,8 +107,8 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
from .auto_model import AutoModel
from .attention_modules import FluxAttention, SanaAttention, SD3Attention
from .auto_model import AutoModel
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderDC,

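A hedged import sketch for the re-export added above; it assumes the file's lazy `_import_structure` table receives a matching `attention_modules` entry, which is not visible in this hunk, so treat the runtime path as an assumption rather than confirmed API.

# Hypothetical usage after this commit; the class names come from the hunk above.
from diffusers.models import FluxAttention, SD3Attention, SanaAttention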
View File

@@ -15,21 +15,17 @@ import inspect
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from ..utils import logging
from ..utils.torch_utils import maybe_allow_in_graph
from .attention_processor import (
AttentionModuleMixin,
AttnProcessorSDPA,
FluxAttnProcessorSDPA,
FusedFluxAttnProcessorSDPA,
JointAttnProcessorSDPA,
FusedJointAttnProcessorSDPA,
JointAttnProcessorSDPA,
SanaLinearAttnProcessorSDPA,
)
from .normalization import RMSNorm, get_normalization
from .normalization import get_normalization
logger = logging.get_logger(__name__)

View File

@@ -56,8 +56,13 @@ class AttentionModuleMixin:
# Default processor classes to be overridden by subclasses
default_processor_cls = None
fused_processor_cls = None
_available_processors = None
_available_processors = []
def _get_compatible_processor(self, backend):
for processor_cls in self._available_processors:
if backend in processor_cls.compatible_backends:
processor = processor_cls()
return processor
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
"""
@@ -66,18 +71,11 @@ class AttentionModuleMixin:
Args:
use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
"""
processor = self.default_processor_cls()
if use_npu_flash_attention:
processor = AttnProcessorNPU()
else:
# set attention processor
# We use the AttnProcessorSDPA by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
processor = (
AttnProcessorSDPA()
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
else AttnProcessor()
)
processor = self._get_compatible_processor("npu")
self.set_processor(processor)
def set_use_xla_flash_attention(
@@ -97,24 +95,17 @@ class AttentionModuleMixin:
is_flux (`bool`, *optional*, defaults to `False`):
Whether the model is a Flux model.
"""
processor = self.default_processor_cls()
if use_xla_flash_attention:
if not is_torch_xla_available:
if not is_torch_xla_available():
raise "torch_xla is not available"
elif is_torch_xla_version("<", "2.3"):
raise "flash attention pallas kernel is supported from torch_xla version 2.3"
elif is_spmd() and is_torch_xla_version("<", "2.4"):
raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
else:
if is_flux:
processor = XLAFluxFlashAttnProcessorSDPA(partition_spec)
else:
processor = XLAFlashAttnProcessorSDPA(partition_spec)
else:
processor = (
AttnProcessorSDPA()
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
else AttnProcessor()
)
processor = self._get_compatible_processor("xla")
self.set_processor(processor)
@torch.no_grad()
@@ -179,11 +170,7 @@ class AttentionModuleMixin:
self.to_added_qkv.bias.copy_(concatenated_bias)
self.fused_projections = fuse
# Update processor based on fusion state
processor_class = self.fused_processor_class if fuse else self.default_processor_class
if processor_class is not None:
self.set_processor(processor_class())
self.processor.is_fused = fuse
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
@@ -557,8 +544,9 @@ class AttentionModuleMixin:
@maybe_allow_in_graph
class Attention(nn.Module, AttentionModuleMixin):
# Set default and fused processor classes
default_processor_class = AttnProcessorSDPA
fused_processor_class = None # Will be set appropriately in the future
default_processor_class = None
_available_processors = []
r"""
A cross attention layer.
@@ -958,7 +946,10 @@ class SanaMultiscaleLinearAttention(nn.Module):
return self.processor(self, hidden_states)
class MochiAttention(nn.Module):
class MochiAttention(nn.Module, AttentionModuleMixin):
default_processor_cls = MochiAttnProcessorSDPA
_available_processors = [MochiAttnProcessorSDPA]
def __init__(
self,
query_dim: int,
@@ -1006,7 +997,8 @@ class MochiAttention(nn.Module):
if not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
self.processor = processor
processor = processor if processor is not None else self.default_processor_cls()
self.set_processor(processor)
def forward(
self,

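A condensed sketch of the dispatch pattern this file introduces: subclasses list processor classes in `_available_processors`, and `_get_compatible_processor` returns an instance of the first one whose `compatible_backends` contains the requested backend. The stub classes below are invented for illustration; only the mixin attribute and method names come from the hunks above, and the import path is taken from the other files in this commit.

import torch.nn as nn

# Import path as used by the Sana, Mochi, and SD3 files in this commit.
from diffusers.models.attention_processor import AttentionModuleMixin

class _SDPAProcessorStub:          # invented stand-in; real processors also implement __call__
    compatible_backends = ("cuda", "cpu")

class _NPUProcessorStub:           # invented stand-in
    compatible_backends = ("npu",)

class ToyAttention(nn.Module, AttentionModuleMixin):
    default_processor_cls = _SDPAProcessorStub
    _available_processors = [_SDPAProcessorStub, _NPUProcessorStub]

attn = ToyAttention()              # assumes the mixin itself needs no extra initialization
print(type(attn._get_compatible_processor("npu")).__name__)   # _NPUProcessorStub
print(attn._get_compatible_processor("rocm"))                  # None: no registered backend matched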
View File

@@ -23,8 +23,8 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention_processor import (
Attention,
AttentionProcessor,
AttentionModuleMixin,
AttentionProcessor,
SanaLinearAttnProcessor2_0,
)
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
@@ -45,7 +45,7 @@ class SanaAttention(nn.Module, AttentionModuleMixin):
"""
# Set Sana-specific processor classes
default_processor_class = SanaLinearAttnProcessor2_0
def __init__(
self,
in_channels: int,
@@ -59,13 +59,13 @@ class SanaAttention(nn.Module, AttentionModuleMixin):
residual_connection: bool = False,
):
super().__init__()
# Core parameters
self.eps = eps
self.attention_head_dim = attention_head_dim
self.norm_type = norm_type
self.residual_connection = residual_connection
# Calculate dimensions
num_attention_heads = (
int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
@@ -73,23 +73,23 @@ class SanaAttention(nn.Module, AttentionModuleMixin):
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
self.heads = num_attention_heads
# Query, key, value projections
self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
# Multi-scale attention
self.to_qkv_multiscale = nn.ModuleList()
for kernel_size in kernel_sizes:
self.to_qkv_multiscale.append(
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
)
# Output layers
self.nonlinearity = nn.ReLU()
self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
# Get normalization based on type
if norm_type == "batch_norm":
self.norm_out = nn.BatchNorm1d(out_channels)
@@ -101,14 +101,14 @@ class SanaAttention(nn.Module, AttentionModuleMixin):
self.norm_out = nn.InstanceNorm1d(out_channels)
else:
self.norm_out = nn.Identity()
# Set processor
self.processor = self.default_processor_class()
class SanaMultiscaleAttentionProjection(nn.Module):
"""Projection layer for Sana multi-scale attention."""
def __init__(
self,
in_channels: int,
@@ -116,7 +116,7 @@ class SanaMultiscaleAttentionProjection(nn.Module):
kernel_size: int,
) -> None:
super().__init__()
channels = 3 * in_channels
self.proj_in = nn.Conv2d(
channels,
@@ -127,7 +127,7 @@ class SanaMultiscaleAttentionProjection(nn.Module):
bias=False,
)
self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_out(hidden_states)

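A shape sketch for the multi-scale projection above: the packed query/key/value tensor carries 3 * inner_dim channels, and proj_out keeps the per-head grouping via groups = 3 * num_attention_heads. The proj_in arguments hidden by the hunk break are filled with an assumed depthwise, shape-preserving configuration purely so the toy example runs.

import torch
import torch.nn as nn

inner_dim, heads, kernel_size = 64, 4, 5
channels = 3 * inner_dim                        # packed query/key/value channels
# proj_in arguments between `channels` and `bias=False` are not shown above; depthwise
# grouping with shape-preserving padding is an assumption made only for this sketch.
proj_in = nn.Conv2d(channels, channels, kernel_size, padding=kernel_size // 2, groups=channels, bias=False)
proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * heads, bias=False)  # as in the hunk

x = torch.randn(1, channels, 8, 8)
print(proj_out(proj_in(x)).shape)               # torch.Size([1, 192, 8, 8]); spatial size preserved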
File diff suppressed because it is too large.

View File

@@ -24,7 +24,7 @@ from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0, AttentionModuleMixin
from ..attention_processor import AttentionModuleMixin, MochiAttention, MochiAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -44,7 +44,7 @@ class MochiAttention(nn.Module, AttentionModuleMixin):
"""
# Set Mochi-specific processor classes
default_processor_class = MochiAttnProcessor2_0
def __init__(
self,
query_dim: int,
@@ -60,10 +60,10 @@ class MochiAttention(nn.Module, AttentionModuleMixin):
eps: float = 1e-5,
):
super().__init__()
# Import here to avoid circular imports
from ..normalization import MochiRMSNorm
# Core parameters
self.inner_dim = dim_head * heads
self.query_dim = query_dim
@@ -73,43 +73,43 @@ class MochiAttention(nn.Module, AttentionModuleMixin):
self.scale_qk = True # Always use scaled attention
self.context_pre_only = context_pre_only
self.eps = eps
# Set output dimensions
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim else query_dim
# Self-attention projections
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
# Normalization for queries and keys
self.norm_q = MochiRMSNorm(dim_head, eps, True)
self.norm_k = MochiRMSNorm(dim_head, eps, True)
# Added key/value projections for joint processing
self.added_kv_proj_dim = added_kv_proj_dim
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
# Normalization for added projections
self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
self.added_proj_bias = added_proj_bias
# Output projections
self.to_out = nn.ModuleList([
nn.Linear(self.inner_dim, self.out_dim, bias=bias),
nn.Dropout(dropout)
])
# Context output projection
if not context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=added_proj_bias)
else:
self.to_add_out = None
# Initialize attention processor using the default class
self.processor = self.default_processor_class()

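MochiAttention above wires MochiAttnProcessor2_0 in as its default; replacing it afterwards goes through the mixin's set_processor. A minimal sketch, assuming the usual (attn, hidden_states, ...) processor call contract; the subclass name is hypothetical.

from diffusers.models.attention_processor import MochiAttnProcessor2_0

class LoggingMochiAttnProcessor(MochiAttnProcessor2_0):
    """Hypothetical subclass that logs the visual sequence length before delegating."""

    def __call__(self, attn, hidden_states, *args, **kwargs):
        print("visual tokens:", hidden_states.shape[1])
        return super().__call__(attn, hidden_states, *args, **kwargs)

# For any built MochiAttention instance `attn` (set_processor comes from AttentionModuleMixin):
# attn.set_processor(LoggingMochiAttnProcessor())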
View File

@@ -21,10 +21,9 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2
from ...models.attention import FeedForward, JointTransformerBlock
from ...models.attention_processor import (
Attention,
AttentionProcessor,
AttentionModuleMixin,
AttentionProcessor,
FusedJointAttnProcessor2_0,
JointAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
@@ -39,11 +38,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class JointAttnProcessor:
"""Attention processor used for processing joint attention."""
def __init__(self):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
raise ImportError("JointAttnProcessor requires PyTorch 2.0, please upgrade PyTorch.")
def __call__(
self,
attn,
@@ -54,10 +53,10 @@ class JointAttnProcessor:
**kwargs,
) -> torch.FloatTensor:
batch_size, sequence_length, _ = hidden_states.shape
# Project query from hidden states
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
# Self-attention: Use hidden_states for key and value
key = attn.to_k(hidden_states)
@@ -66,77 +65,77 @@ class JointAttnProcessor:
# Cross-attention: Use encoder_hidden_states for key and value
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# Handle additional context for joint attention
if hasattr(attn, "added_kv_proj_dim") and attn.added_kv_proj_dim is not None:
context_key = attn.add_k_proj(encoder_hidden_states)
context_value = attn.add_v_proj(encoder_hidden_states)
context_query = attn.add_q_proj(encoder_hidden_states)
# Joint query, key, value with context
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
# Reshape for multi-head attention
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
context_query = context_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
context_key = context_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
context_value = context_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Concatenate for joint attention
query = torch.cat([context_query, query], dim=2)
key = torch.cat([context_key, key], dim=2)
value = torch.cat([context_value, value], dim=2)
# Apply joint attention
hidden_states = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# Reshape back to original dimensions
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# Split context and hidden states
context_len = encoder_hidden_states.shape[1]
encoder_hidden_states, hidden_states = (
hidden_states[:, :context_len],
hidden_states[:, context_len:],
)
# Apply output projections
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if not attn.context_pre_only and hasattr(attn, "to_add_out") and attn.to_add_out is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
return hidden_states
# Handle standard attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
# Reshape for multi-head attention
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Apply attention
hidden_states = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# Reshape output
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# Apply output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
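A shape walk-through of the joint path above, with invented sizes: the context (text) tokens are concatenated in front of the image tokens before scaled_dot_product_attention, so the first encoder-length slice of the output returns to the context stream (the processor performs this split after merging heads; the sketch does it in 4D for brevity).

import torch
import torch.nn.functional as F

batch, heads, head_dim = 2, 8, 64
img_len, txt_len = 4096, 77                               # invented sequence lengths

query = torch.randn(batch, heads, img_len, head_dim)       # projected hidden_states
ctx_query = torch.randn(batch, heads, txt_len, head_dim)   # projected encoder_hidden_states

joint_q = torch.cat([ctx_query, query], dim=2)             # (2, 8, 4173, 64), context first
out = F.scaled_dot_product_attention(joint_q, joint_q, joint_q)

# Mirror of the split in the processor: the first txt_len tokens go back to the encoder stream.
encoder_out, image_out = out[:, :, :txt_len], out[:, :, txt_len:]
print(encoder_out.shape, image_out.shape)                  # (2, 8, 77, 64) (2, 8, 4096, 64)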
@@ -148,7 +147,7 @@ class SD3Attention(nn.Module, AttentionModuleMixin):
Features joint attention mechanisms and custom handling of
context projections.
"""
def __init__(
self,
query_dim: int,
@@ -163,7 +162,7 @@ class SD3Attention(nn.Module, AttentionModuleMixin):
eps: float = 1e-6,
):
super().__init__()
# Core parameters
self.inner_dim = dim_head * heads
self.query_dim = query_dim
@@ -173,19 +172,19 @@ class SD3Attention(nn.Module, AttentionModuleMixin):
self.use_bias = bias
self.context_pre_only = context_pre_only
self.eps = eps
# Set output dimension
out_dim = out_dim if out_dim is not None else query_dim
# Set cross-attention parameters
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
# Linear projections for self-attention
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
# Optional added key/value projections for joint attention
self.added_kv_proj_dim = added_kv_proj_dim
if added_kv_proj_dim is not None:
@@ -193,22 +192,22 @@ class SD3Attention(nn.Module, AttentionModuleMixin):
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
self.added_proj_bias = bias
# Output projection for context
if not context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, out_dim, bias=bias)
else:
self.to_add_out = None
# Output projection and dropout
self.to_out = nn.ModuleList([
nn.Linear(self.inner_dim, out_dim, bias=bias),
nn.Dropout(dropout)
])
# Set processor
self.processor = JointAttnProcessor()
def forward(
self,
hidden_states: torch.Tensor,