Commit "update" — mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
@@ -107,8 +107,8 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    if is_torch_available():
        from .adapter import MultiAdapter, T2IAdapter
        from .auto_model import AutoModel
        from .attention_modules import FluxAttention, SanaAttention, SD3Attention
        from .auto_model import AutoModel
        from .autoencoders import (
            AsymmetricAutoencoderKL,
            AutoencoderDC,
@@ -15,21 +15,17 @@ import inspect
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import logging
from ..utils.torch_utils import maybe_allow_in_graph
from .attention_processor import (
    AttentionModuleMixin,
    AttnProcessorSDPA,
    FluxAttnProcessorSDPA,
    FusedFluxAttnProcessorSDPA,
    JointAttnProcessorSDPA,
    FusedJointAttnProcessorSDPA,
    JointAttnProcessorSDPA,
    SanaLinearAttnProcessorSDPA,
)
from .normalization import RMSNorm, get_normalization
from .normalization import get_normalization


logger = logging.get_logger(__name__)
@@ -56,8 +56,13 @@ class AttentionModuleMixin:

    # Default processor classes to be overridden by subclasses
    default_processor_cls = None
    fused_processor_cls = None
    _available_processors = None
    _available_processors = []

    def _get_compatible_processor(self, backend):
        for processor_cls in self._available_processors:
            if backend in processor_cls.compatible_backends:
                processor = processor_cls()
                return processor

    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
        """
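The hunk above turns processor selection into a class-level registry: subclasses declare `default_processor_cls` and `_available_processors`, and `_get_compatible_processor` returns the first registered processor that lists the requested backend. A minimal, self-contained sketch of that pattern (all class names below are invented for illustration; only the attribute names mirror the diff):

class DemoNPUProcessor:
    # assumed attribute consulted by _get_compatible_processor
    compatible_backends = ("npu",)

class DemoSDPAProcessor:
    compatible_backends = ("cuda", "cpu")

class DemoAttention:
    default_processor_cls = DemoSDPAProcessor
    _available_processors = [DemoNPUProcessor, DemoSDPAProcessor]

    def _get_compatible_processor(self, backend):
        # first processor whose compatible_backends contains the backend wins
        for processor_cls in self._available_processors:
            if backend in processor_cls.compatible_backends:
                return processor_cls()

attn = DemoAttention()
print(type(attn._get_compatible_processor("npu")).__name__)   # DemoNPUProcessor
print(type(attn._get_compatible_processor("cuda")).__name__)  # DemoSDPAProcessor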
@@ -66,18 +71,11 @@ class AttentionModuleMixin:
        Args:
            use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
        """
        processor = self.default_processor_cls()

        if use_npu_flash_attention:
            processor = AttnProcessorNPU()
        else:
            # set attention processor
            # We use the AttnProcessorSDPA by default when torch 2.x is used which uses
            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
            processor = (
                AttnProcessorSDPA()
                if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
                else AttnProcessor()
            )
            processor = self._get_compatible_processor("npu")

        self.set_processor(processor)

    def set_use_xla_flash_attention(
@@ -97,24 +95,17 @@ class AttentionModuleMixin:
            is_flux (`bool`, *optional*, defaults to `False`):
                Whether the model is a Flux model.
        """
        processor = self.default_processor_cls()
        if use_xla_flash_attention:
            if not is_torch_xla_available:
            if not is_torch_xla_available():
                raise "torch_xla is not available"
            elif is_torch_xla_version("<", "2.3"):
                raise "flash attention pallas kernel is supported from torch_xla version 2.3"
            elif is_spmd() and is_torch_xla_version("<", "2.4"):
                raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
            else:
                if is_flux:
                    processor = XLAFluxFlashAttnProcessorSDPA(partition_spec)
                else:
                    processor = XLAFlashAttnProcessorSDPA(partition_spec)
        else:
            processor = (
                AttnProcessorSDPA()
                if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
                else AttnProcessor()
            )
            processor = self._get_compatible_processor("xla")

        self.set_processor(processor)

    @torch.no_grad()
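With the refactor shown in the two hunks above, `set_use_npu_flash_attention` and `set_use_xla_flash_attention` no longer hard-code processor classes; they start from `default_processor_cls()` and ask `_get_compatible_processor("npu")` / `_get_compatible_processor("xla")` for a backend-specific processor. A hedged usage sketch; the model class and checkpoint path are placeholders for illustration, only the per-module setter comes from this diff:

import torch
from diffusers import FluxTransformer2DModel  # illustrative model class, not named in this diff

model = FluxTransformer2DModel.from_pretrained(
    "<some-flux-checkpoint>", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Enable NPU flash attention on every attention module that exposes the setter.
for module in model.modules():
    if hasattr(module, "set_use_npu_flash_attention"):
        # With this change, the module resolves its processor via
        # _get_compatible_processor("npu") instead of hard-coding AttnProcessorNPU.
        module.set_use_npu_flash_attention(True)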
@@ -179,11 +170,7 @@ class AttentionModuleMixin:
                self.to_added_qkv.bias.copy_(concatenated_bias)

        self.fused_projections = fuse

        # Update processor based on fusion state
        processor_class = self.fused_processor_class if fuse else self.default_processor_class
        if processor_class is not None:
            self.set_processor(processor_class())
            self.processor.is_fused = fuse

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
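Per the hunk above, toggling fusion also swaps the attention processor: the module picks `fused_processor_class` or `default_processor_class`, instantiates it, and records the state on `processor.is_fused`. A hedged usage sketch; the pipeline and checkpoint names are illustrative and not taken from this diff:

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Fuse QKV projections on the transformer; with this change each attention
# module should also switch to its fused processor and mark processor.is_fused = True.
pipe.transformer.fuse_qkv_projections()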
@@ -557,8 +544,9 @@ class AttentionModuleMixin:
@maybe_allow_in_graph
class Attention(nn.Module, AttentionModuleMixin):
    # Set default and fused processor classes
    default_processor_class = AttnProcessorSDPA
    fused_processor_class = None # Will be set appropriately in the future
    default_processor_class = None
    _available_processors = []

    r"""
    A cross attention layer.
@@ -958,7 +946,10 @@ class SanaMultiscaleLinearAttention(nn.Module):
        return self.processor(self, hidden_states)


class MochiAttention(nn.Module):
class MochiAttention(nn.Module, AttentionModuleMixin):
    default_processor_cls = MochiAttnProcessorSDPA
    _available_processors = [MochiAttnProcessorSDPA]

    def __init__(
        self,
        query_dim: int,
@@ -1006,7 +997,8 @@ class MochiAttention(nn.Module):
        if not self.context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)

        self.processor = processor
        processor = processor if processor is not None else self.default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
@@ -23,8 +23,8 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention_processor import (
    Attention,
    AttentionProcessor,
    AttentionModuleMixin,
    AttentionProcessor,
    SanaLinearAttnProcessor2_0,
)
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
@@ -45,7 +45,7 @@ class SanaAttention(nn.Module, AttentionModuleMixin):
    """
    # Set Sana-specific processor classes
    default_processor_class = SanaLinearAttnProcessor2_0

    def __init__(
        self,
        in_channels: int,
@@ -59,13 +59,13 @@ class SanaAttention(nn.Module, AttentionModuleMixin):
        residual_connection: bool = False,
    ):
        super().__init__()

        # Core parameters
        self.eps = eps
        self.attention_head_dim = attention_head_dim
        self.norm_type = norm_type
        self.residual_connection = residual_connection

        # Calculate dimensions
        num_attention_heads = (
            int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
@@ -73,23 +73,23 @@ class SanaAttention(nn.Module, AttentionModuleMixin):
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.heads = num_attention_heads

        # Query, key, value projections
        self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
        self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
        self.to_v = nn.Linear(in_channels, inner_dim, bias=False)

        # Multi-scale attention
        self.to_qkv_multiscale = nn.ModuleList()
        for kernel_size in kernel_sizes:
            self.to_qkv_multiscale.append(
                SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
            )

        # Output layers
        self.nonlinearity = nn.ReLU()
        self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)

        # Get normalization based on type
        if norm_type == "batch_norm":
            self.norm_out = nn.BatchNorm1d(out_channels)
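A small worked example of the output width used above: the base attention output is concatenated with one output per multi-scale branch before `to_out`, hence the `1 + len(kernel_sizes)` factor (the numbers below are illustrative, not taken from the diff):

attention_head_dim = 32
num_attention_heads = 16
kernel_sizes = (5,)

inner_dim = num_attention_heads * attention_head_dim      # 512
to_out_in_features = inner_dim * (1 + len(kernel_sizes))  # 1024: base branch + 1 multiscale branch
assert to_out_in_features == 1024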
@@ -101,14 +101,14 @@ class SanaAttention(nn.Module, AttentionModuleMixin):
            self.norm_out = nn.InstanceNorm1d(out_channels)
        else:
            self.norm_out = nn.Identity()

        # Set processor
        self.processor = self.default_processor_class()


class SanaMultiscaleAttentionProjection(nn.Module):
    """Projection layer for Sana multi-scale attention."""

    def __init__(
        self,
        in_channels: int,
@@ -116,7 +116,7 @@ class SanaMultiscaleAttentionProjection(nn.Module):
        kernel_size: int,
    ) -> None:
        super().__init__()

        channels = 3 * in_channels
        self.proj_in = nn.Conv2d(
            channels,
@@ -127,7 +127,7 @@ class SanaMultiscaleAttentionProjection(nn.Module):
            bias=False,
        )
        self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.proj_in(hidden_states)
        hidden_states = self.proj_out(hidden_states)
File diff suppressed because it is too large.
@@ -24,7 +24,7 @@ from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0, AttentionModuleMixin
from ..attention_processor import AttentionModuleMixin, MochiAttention, MochiAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -44,7 +44,7 @@ class MochiAttention(nn.Module, AttentionModuleMixin):
    """
    # Set Mochi-specific processor classes
    default_processor_class = MochiAttnProcessor2_0

    def __init__(
        self,
        query_dim: int,
@@ -60,10 +60,10 @@ class MochiAttention(nn.Module, AttentionModuleMixin):
        eps: float = 1e-5,
    ):
        super().__init__()

        # Import here to avoid circular imports
        from ..normalization import MochiRMSNorm

        # Core parameters
        self.inner_dim = dim_head * heads
        self.query_dim = query_dim
@@ -73,43 +73,43 @@ class MochiAttention(nn.Module, AttentionModuleMixin):
        self.scale_qk = True # Always use scaled attention
        self.context_pre_only = context_pre_only
        self.eps = eps

        # Set output dimensions
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = out_context_dim if out_context_dim else query_dim

        # Self-attention projections
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

        # Normalization for queries and keys
        self.norm_q = MochiRMSNorm(dim_head, eps, True)
        self.norm_k = MochiRMSNorm(dim_head, eps, True)

        # Added key/value projections for joint processing
        self.added_kv_proj_dim = added_kv_proj_dim
        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
        self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)

        # Normalization for added projections
        self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
        self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
        self.added_proj_bias = added_proj_bias

        # Output projections
        self.to_out = nn.ModuleList([
            nn.Linear(self.inner_dim, self.out_dim, bias=bias),
            nn.Dropout(dropout)
        ])

        # Context output projection
        if not context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=added_proj_bias)
        else:
            self.to_add_out = None

        # Initialize attention processor using the default class
        self.processor = self.default_processor_class()
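A hypothetical instantiation of this refactored MochiAttention. The keyword names are taken from the parameters referenced in the `__init__` above; the sizes are invented, roughly Mochi-like values, and the remaining parameters are assumed to have defaults:

attn = MochiAttention(
    query_dim=3072,
    added_kv_proj_dim=1536,
    heads=24,
    dim_head=128,
    context_pre_only=False,
)
# Per the last lines of the hunk, the constructor assigns
# self.processor = self.default_processor_class(), i.e. MochiAttnProcessor2_0 here.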
@@ -21,10 +21,9 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2
from ...models.attention import FeedForward, JointTransformerBlock
from ...models.attention_processor import (
    Attention,
    AttentionProcessor,
    AttentionModuleMixin,
    AttentionProcessor,
    FusedJointAttnProcessor2_0,
    JointAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
@@ -39,11 +38,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name

class JointAttnProcessor:
    """Attention processor used for processing joint attention."""

    def __init__(self):
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise ImportError("JointAttnProcessor requires PyTorch 2.0, please upgrade PyTorch.")

    def __call__(
        self,
        attn,
@@ -54,10 +53,10 @@ class JointAttnProcessor:
        **kwargs,
    ) -> torch.FloatTensor:
        batch_size, sequence_length, _ = hidden_states.shape

        # Project query from hidden states
        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # Self-attention: Use hidden_states for key and value
            key = attn.to_k(hidden_states)
@@ -66,77 +65,77 @@ class JointAttnProcessor:
            # Cross-attention: Use encoder_hidden_states for key and value
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        # Handle additional context for joint attention
        if hasattr(attn, "added_kv_proj_dim") and attn.added_kv_proj_dim is not None:
            context_key = attn.add_k_proj(encoder_hidden_states)
            context_value = attn.add_v_proj(encoder_hidden_states)
            context_query = attn.add_q_proj(encoder_hidden_states)

            # Joint query, key, value with context
            inner_dim = key.shape[-1]
            head_dim = inner_dim // attn.heads

            # Reshape for multi-head attention
            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            context_query = context_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            context_key = context_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            context_value = context_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # Concatenate for joint attention
            query = torch.cat([context_query, query], dim=2)
            key = torch.cat([context_key, key], dim=2)
            value = torch.cat([context_value, value], dim=2)

            # Apply joint attention
            hidden_states = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

            # Reshape back to original dimensions
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

            # Split context and hidden states
            context_len = encoder_hidden_states.shape[1]
            encoder_hidden_states, hidden_states = (
                hidden_states[:, :context_len],
                hidden_states[:, context_len:],
            )

            # Apply output projections
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)

            if not attn.context_pre_only and hasattr(attn, "to_add_out") and attn.to_add_out is not None:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
                return hidden_states, encoder_hidden_states

            return hidden_states

        # Handle standard attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Reshape for multi-head attention
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Apply attention
        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Reshape output
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # Apply output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
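A self-contained sketch of the joint-attention token flow implemented above: context tokens and latent tokens are concatenated along the sequence axis, attended jointly with scaled_dot_product_attention, then split back. The shapes are illustrative assumptions:

import torch
import torch.nn.functional as F

batch, heads, head_dim = 2, 8, 64
latent_len, context_len = 1024, 77

latent = torch.randn(batch, heads, latent_len, head_dim)
context = torch.randn(batch, heads, context_len, head_dim)

q = k = v = torch.cat([context, latent], dim=2)   # joint sequence, context first
out = F.scaled_dot_product_attention(q, k, v)     # (batch, heads, context+latent, head_dim)
out = out.transpose(1, 2).reshape(batch, context_len + latent_len, heads * head_dim)

context_out, latent_out = out[:, :context_len], out[:, context_len:]
print(context_out.shape, latent_out.shape)  # torch.Size([2, 77, 512]) torch.Size([2, 1024, 512])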
@@ -148,7 +147,7 @@ class SD3Attention(nn.Module, AttentionModuleMixin):
    Features joint attention mechanisms and custom handling of
    context projections.
    """

    def __init__(
        self,
        query_dim: int,
@@ -163,7 +162,7 @@ class SD3Attention(nn.Module, AttentionModuleMixin):
        eps: float = 1e-6,
    ):
        super().__init__()

        # Core parameters
        self.inner_dim = dim_head * heads
        self.query_dim = query_dim
@@ -173,19 +172,19 @@ class SD3Attention(nn.Module, AttentionModuleMixin):
        self.use_bias = bias
        self.context_pre_only = context_pre_only
        self.eps = eps

        # Set output dimension
        out_dim = out_dim if out_dim is not None else query_dim

        # Set cross-attention parameters
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

        # Linear projections for self-attention
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

        # Optional added key/value projections for joint attention
        self.added_kv_proj_dim = added_kv_proj_dim
        if added_kv_proj_dim is not None:
@@ -193,22 +192,22 @@ class SD3Attention(nn.Module, AttentionModuleMixin):
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
            self.added_proj_bias = bias

        # Output projection for context
        if not context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, out_dim, bias=bias)
        else:
            self.to_add_out = None

        # Output projection and dropout
        self.to_out = nn.ModuleList([
            nn.Linear(self.inner_dim, out_dim, bias=bias),
            nn.Dropout(dropout)
        ])

        # Set processor
        self.processor = JointAttnProcessor()

    def forward(
        self,
        hidden_states: torch.Tensor,
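A hypothetical end-to-end sketch for the SD3Attention block above. The constructor keywords come from the parameters referenced in the `__init__` hunks; the tensor sizes, the assumption that the remaining parameters have defaults, and the processor's keyword name `encoder_hidden_states` are illustrative assumptions:

import torch

attn = SD3Attention(
    query_dim=1536,
    heads=24,
    dim_head=64,
    added_kv_proj_dim=1536,
    context_pre_only=False,
)

hidden_states = torch.randn(1, 4096, 1536)         # latent tokens
encoder_hidden_states = torch.randn(1, 333, 1536)  # context/text tokens

# With added projections and context_pre_only=False, the JointAttnProcessor
# defined earlier in this file returns both streams.
latent_out, context_out = attn.processor(
    attn, hidden_states, encoder_hidden_states=encoder_hidden_states
)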