1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-04-30 14:39:37 +05:30
parent 9b6a062adf
commit a7462302cd
3 changed files with 233 additions and 64 deletions

View File

@@ -5115,6 +5115,54 @@ class PAGIdentitySanaLinearAttnProcessorSDPA:
# Deprecated classes for backward compatibility
class AttnProcessor:
def __new__(cls, *args, **kwargs):
deprecation_message = "`AttnProcessor` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`"
deprecate("AttnProcessor", "1.0.0", deprecation_message)
return AttnProcessorSDPA(*args, **kwargs)
class AttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`AttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`"
deprecate("AttnProcessor2_0", "1.0.0", deprecation_message)
return AttnProcessorSDPA(*args, **kwargs)
class AttnAddedKVProcessor:
def __new__(cls, *args, **kwargs):
deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`"
deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message)
return AttnAddedKVProcessorSDPA(*args, **kwargs)
class AttnAddedKVProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`"
deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message)
return AttnAddedKVProcessorSDPA(*args, **kwargs)
class AllegroAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`AllegroAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AllegroAttnProcessorSDPA`"
deprecate("AllegroAttnProcessor2_0", "1.0.0", deprecation_message)
return AllegroAttnProcessorSDPA(*args, **kwargs)
class AuraFlowAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`AuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessorSDPA`"
deprecate("AuraFlowAttnProcessor2_0", "1.0.0", deprecation_message)
return AuraFlowAttnProcessorSDPA(*args, **kwargs)
class MochiAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`MochiAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `MochiAttnProcessorSDPA`"
@@ -5150,10 +5198,13 @@ class FluxSingleAttnProcessor2_0:
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
def __new__(cls, *args, **kwargs):
deprecation_message = "`FluxSingleAttnProcessorSDPA` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
super().__init__()
deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_allegro import FluxAttnProcessorSDPA
return FluxAttnProcessorSDPA(*args, **kwargs)
class FusedAttnProcessor2_0:
@@ -5164,26 +5215,11 @@ class FusedAttnProcessor2_0:
return AttnProcessorSDPA(*args, **kwargs)
class AttnAddedKVProcessor:
def __new__(cls, *args, **kwargs):
deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`"
deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message)
return AttnAddedKVProcessorSDPA(*args, **kwargs)
class AttnAddedKVProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`"
deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message)
return AttnAddedKVProcessorSDPA(*args, **kwargs)
class JointAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`JointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `JointAttnProcessorSDPA`"
deprecate("JointAttnProcessor2_0", "1.0.0", deprecation_message)
return JointAttnProcessorSDPA(*args, **kwargs)
@@ -5191,6 +5227,7 @@ class PAGJointAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`PAGJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGJointAttnProcessorSDPA`"
deprecate("PAGJointAttnProcessor2_0", "1.0.0", deprecation_message)
return PAGJointAttnProcessorSDPA(*args, **kwargs)
@@ -5198,6 +5235,7 @@ class PAGCFGJointAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`PAGCFGJointAttnProcessor2_0 is deprecated and this will be removed in a future version. Please use `PAGCFGJointAttnProcessorSDPA`"
deprecate("PAGCFGJointAttnProcessor2_0", "1.0.0", deprecation_message)
return PAGCFGJointAttnProcessorSDPA(*args, **kwargs)
@@ -5205,27 +5243,15 @@ class FusedJointAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`FusedJointAttnProcessor2_0 is deprecated and this will be removed in a future version. Please use `JointAttnProcessorSDPA`"
deprecate("FusedJointAttnProcessor2_0", "1.0.0", deprecation_message)
return JointAttnProcessorSDPA(*args, **kwargs)
class AllegroAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`AllegroAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AllegroAttnProcessorSDPA`"
deprecate("AllegroAttnProcessor2_0", "1.0.0", deprecation_message)
return AllegroAttnProcessorSDPA(*args, **kwargs)
class AuraFlowAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`AuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessorSDPA`"
deprecate("AuraFlowAttnProcessor2_0", "1.0.0", deprecation_message)
return AuraFlowAttnProcessorSDPA(*args, **kwargs)
class FusedAuraFlowAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`FusedAuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessorSDPA`"
deprecate("FusedAuraFlowAttnProcessor2_0", "1.0.0", deprecation_message)
return AuraFlowAttnProcessorSDPA(*args, **kwargs)
@@ -5256,22 +5282,6 @@ class FusedCogVideoXAttnProcessor2_0:
return CogVideoXAttnProcessorSDPA(*args, **kwargs)
class AttnProcessor:
def __new__(cls, *args, **kwargs):
deprecation_message = "`AttnProcessor` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`"
deprecate("AttnProcessor", "1.0.0", deprecation_message)
return AttnProcessorSDPA(*args, **kwargs)
class AttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`AttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`"
deprecate("AttnProcessor2_0", "1.0.0", deprecation_message)
return AttnProcessorSDPA(*args, **kwargs)
class XLAFlashAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`XLAFlashAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFlashAttnProcessorSDPA`"
@@ -5281,7 +5291,7 @@ class XLAFlashAttnProcessor2_0:
class XLAFluxFlashAttnProcessor2_0:
def __init__(cls, *args, **kwargs):
def __new__(cls, *args, **kwargs):
deprecation_message = "`XLAFluxFlashAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFluxFlashAttnProcessorSDPA`"
deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message)
@@ -5290,18 +5300,18 @@ class XLAFluxFlashAttnProcessor2_0:
return FluxAttnProcessorXLA(*args, **kwargs)
class StableAudioAttnProcessor2_0(StableAudioAttnProcessorSDPA):
def __init__(self, *args, **kwargs):
class StableAudioAttnProcessor2_0:
def __new__(self, *args, **kwargs):
deprecation_message = "`StableAudioAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `StableAudioAttnProcessorSDPA`"
deprecate("StableAudioAttnProcessor2_0", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
return StableAudioAttnProcessorSDPA(*args, **kwargs)
class HunyuanAttnProcessor2_0(HunyuanAttnProcessorSDPA):
def __init__(self, *args, **kwargs):
def __new__(cls, *args, **kwargs):
deprecation_message = "`HunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `HunyuanAttnProcessorSDPA`"
deprecate("HunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
return HunyuanAttnProcessorSDPA(*args, **kwargs)
class FusedHunyuanAttnProcessor2_0(FusedHunyuanAttnProcessorSDPA):

View File

@@ -23,7 +23,12 @@ from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..attention_processor import (
AttentionModuleMixin,
AttentionProcessor,
CogVideoXAttnProcessor2_0,
FusedCogVideoXAttnProcessor2_0,
)
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
@@ -35,6 +40,168 @@ from .modeling_common import FeedForward
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class BaseCogVideoXAttnProcessor:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
compatible_backends = []
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"CogVideoXAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def get_projections(self, attn, hidden_states, encoder_hidden_states=None):
"""Public method to get projections based on whether we're using fused mode or not."""
if self.is_fused and hasattr(attn, "to_qkv"):
return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
return self._get_projections(attn, hidden_states, encoder_hidden_states)
def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
"""Get projections using standard separate projection matrices."""
# Standard separate projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# Handle encoder projections if present
encoder_projections = None
if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_projections = (encoder_query, encoder_key, encoder_value)
return query, key, value, encoder_projections
def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
"""Get projections using fused QKV projection matrices."""
# Fused QKV projection
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
# Handle encoder projections if present
encoder_projections = None
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
split_size = encoder_qkv.shape[-1] // 3
encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
encoder_projections = (encoder_query, encoder_key, encoder_value)
return query, key, value, encoder_projections
def _compute_attention(self, query, key, value, attention_mask=None):
"""Computes the attention. Can be overridden by hardware-specific implementations."""
return F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = hidden_states.shape
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query, key, value, _ = self.get_projections(attn, hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
hidden_states = self._compute_attention(query, key, value, attention_mask=attention_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
class CogVideoXAttnProcessorSDPA(BaseCogVideoXAttnProcessor):
compatible_backends = ["cuda", "cpu", "xpu"]
class CogVideoXAttention(nn.Module, AttentionModuleMixin):
default_processor_cls = CogVideoXAttnProcessorSDPA
_available_processors = [CogVideoXAttnProcessorSDPA]
def __init__(
self, query_dim, dim_head, heads, dropout=0.0, qk_norm=None, eps=1e-6, bias=False, out_bias=False
) -> None:
self.query_dim = query_dim
self.out_dim = query_dim
self.inner_dim = dim_head * heads
self.use_bias = bias
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
if qk_norm is None:
self.norm_q = None
self.norm_k = None
elif qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
self.set_processor(self.default_processor_cls())
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
)
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
r"""

View File

@@ -41,7 +41,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class BaseFluxAttnProcessor:
"""Base attention processor for Flux models with common functionality."""
is_fused = False
compatible_backends = []
def __init__(self):
@@ -377,13 +376,6 @@ class FluxIPAdapterAttnProcessorSDPA(torch.nn.Module):
@maybe_allow_in_graph
class FluxAttention(nn.Module, AttentionModuleMixin):
"""
Specialized attention implementation for Flux models.
This attention module provides optimized implementation for Flux models,
with support for RMSNorm, rotary embeddings, and added key/value projections.
"""
_default_processor_cls = FluxAttnProcessorSDPA
_available_processors = [
FluxAttnProcessorSDPA,