mirror of https://github.com/huggingface/diffusers.git
update
@@ -5115,6 +5115,54 @@ class PAGIdentitySanaLinearAttnProcessorSDPA:
# Deprecated classes for backward compatibility


class AttnProcessor:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`AttnProcessor` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`"
        deprecate("AttnProcessor", "1.0.0", deprecation_message)

        return AttnProcessorSDPA(*args, **kwargs)


class AttnProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`AttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`"
        deprecate("AttnProcessor2_0", "1.0.0", deprecation_message)

        return AttnProcessorSDPA(*args, **kwargs)
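These shims all follow the same `__new__` redirection pattern: constructing the deprecated class warns once and hands back an instance of the renamed processor, so existing call sites keep working. A minimal self-contained sketch of the pattern, using `warnings` directly instead of diffusers' `deprecate` helper:

import warnings


class NewProcessor:
    """Stand-in for a renamed `*SDPA` processor."""


class OldProcessor:
    """Deprecated name kept for backward compatibility."""

    def __new__(cls, *args, **kwargs):
        warnings.warn("`OldProcessor` is deprecated; use `NewProcessor`.", FutureWarning)
        # Returning an instance of a *different* class means Python never calls
        # OldProcessor.__init__ on it: callers transparently receive the new type.
        return NewProcessor(*args, **kwargs)


proc = OldProcessor()
assert isinstance(proc, NewProcessor) and not isinstance(proc, OldProcessor)

The same reasoning applies to every deprecated alias below; only the names and the replacement target differ.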
class AttnAddedKVProcessor:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`"
        deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message)

        return AttnAddedKVProcessorSDPA(*args, **kwargs)


class AttnAddedKVProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`"
        deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message)

        return AttnAddedKVProcessorSDPA(*args, **kwargs)
class AllegroAttnProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`AllegroAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AllegroAttnProcessorSDPA`"
        deprecate("AllegroAttnProcessor2_0", "1.0.0", deprecation_message)

        return AllegroAttnProcessorSDPA(*args, **kwargs)


class AuraFlowAttnProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`AuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessorSDPA`"
        deprecate("AuraFlowAttnProcessor2_0", "1.0.0", deprecation_message)

        return AuraFlowAttnProcessorSDPA(*args, **kwargs)


class MochiAttnProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`MochiAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `MochiAttnProcessorSDPA`"
@@ -5150,10 +5198,13 @@ class FluxSingleAttnProcessor2_0:
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

-    def __init__(self):
+    def __new__(cls, *args, **kwargs):
        deprecation_message = "`FluxSingleAttnProcessorSDPA` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
-        deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
-        super().__init__()
+        deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_allegro import FluxAttnProcessorSDPA
+
+        return FluxAttnProcessorSDPA(*args, **kwargs)


class FusedAttnProcessor2_0:
@@ -5164,26 +5215,11 @@ class FusedAttnProcessor2_0:
        return AttnProcessorSDPA(*args, **kwargs)


-class AttnAddedKVProcessor:
-    def __new__(cls, *args, **kwargs):
-        deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`"
-        deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message)
-
-        return AttnAddedKVProcessorSDPA(*args, **kwargs)
-
-
-class AttnAddedKVProcessor2_0:
-    def __new__(cls, *args, **kwargs):
-        deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`"
-        deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message)
-
-        return AttnAddedKVProcessorSDPA(*args, **kwargs)


class JointAttnProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`JointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `JointAttnProcessorSDPA`"
        deprecate("JointAttnProcessor2_0", "1.0.0", deprecation_message)

        return JointAttnProcessorSDPA(*args, **kwargs)
@@ -5191,6 +5227,7 @@ class PAGJointAttnProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`PAGJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGJointAttnProcessorSDPA`"
        deprecate("PAGJointAttnProcessor2_0", "1.0.0", deprecation_message)

        return PAGJointAttnProcessorSDPA(*args, **kwargs)
@@ -5198,6 +5235,7 @@ class PAGCFGJointAttnProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`PAGCFGJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGJointAttnProcessorSDPA`"
        deprecate("PAGCFGJointAttnProcessor2_0", "1.0.0", deprecation_message)

        return PAGCFGJointAttnProcessorSDPA(*args, **kwargs)
@@ -5205,27 +5243,15 @@ class FusedJointAttnProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`FusedJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `JointAttnProcessorSDPA`"
        deprecate("FusedJointAttnProcessor2_0", "1.0.0", deprecation_message)

        return JointAttnProcessorSDPA(*args, **kwargs)


-class AllegroAttnProcessor2_0:
-    def __new__(cls, *args, **kwargs):
-        deprecation_message = "`AllegroAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AllegroAttnProcessorSDPA`"
-        deprecate("AllegroAttnProcessor2_0", "1.0.0", deprecation_message)
-        return AllegroAttnProcessorSDPA(*args, **kwargs)
-
-
-class AuraFlowAttnProcessor2_0:
-    def __new__(cls, *args, **kwargs):
-        deprecation_message = "`AuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessorSDPA`"
-        deprecate("AuraFlowAttnProcessor2_0", "1.0.0", deprecation_message)
-        return AuraFlowAttnProcessorSDPA(*args, **kwargs)


class FusedAuraFlowAttnProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`FusedAuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessorSDPA`"
        deprecate("FusedAuraFlowAttnProcessor2_0", "1.0.0", deprecation_message)

        return AuraFlowAttnProcessorSDPA(*args, **kwargs)
@@ -5256,22 +5282,6 @@ class FusedCogVideoXAttnProcessor2_0:
        return CogVideoXAttnProcessorSDPA(*args, **kwargs)


-class AttnProcessor:
-    def __new__(cls, *args, **kwargs):
-        deprecation_message = "`AttnProcessor` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`"
-        deprecate("AttnProcessor", "1.0.0", deprecation_message)
-
-        return AttnProcessorSDPA(*args, **kwargs)
-
-
-class AttnProcessor2_0:
-    def __new__(cls, *args, **kwargs):
-        deprecation_message = "`AttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`"
-        deprecate("AttnProcessor2_0", "1.0.0", deprecation_message)
-
-        return AttnProcessorSDPA(*args, **kwargs)


class XLAFlashAttnProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`XLAFlashAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFlashAttnProcessorSDPA`"
@@ -5281,7 +5291,7 @@ class XLAFlashAttnProcessor2_0:
class XLAFluxFlashAttnProcessor2_0:
-    def __init__(cls, *args, **kwargs):
+    def __new__(cls, *args, **kwargs):
        deprecation_message = "`XLAFluxFlashAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFluxFlashAttnProcessorSDPA`"
        deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message)
@@ -5290,18 +5300,18 @@ class XLAFluxFlashAttnProcessor2_0:
        return FluxAttnProcessorXLA(*args, **kwargs)


-class StableAudioAttnProcessor2_0(StableAudioAttnProcessorSDPA):
-    def __init__(self, *args, **kwargs):
+class StableAudioAttnProcessor2_0:
+    def __new__(self, *args, **kwargs):
        deprecation_message = "`StableAudioAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `StableAudioAttnProcessorSDPA`"
        deprecate("StableAudioAttnProcessor2_0", "1.0.0", deprecation_message)
-        super().__init__(*args, **kwargs)
+        return StableAudioAttnProcessorSDPA(*args, **kwargs)


class HunyuanAttnProcessor2_0(HunyuanAttnProcessorSDPA):
-    def __init__(self, *args, **kwargs):
+    def __new__(cls, *args, **kwargs):
        deprecation_message = "`HunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `HunyuanAttnProcessorSDPA`"
        deprecate("HunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
-        super().__init__(*args, **kwargs)
+        return HunyuanAttnProcessorSDPA(*args, **kwargs)


class FusedHunyuanAttnProcessor2_0(FusedHunyuanAttnProcessorSDPA):
@@ -23,7 +23,12 @@ from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention
-from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from ..attention_processor import (
+    AttentionModuleMixin,
+    AttentionProcessor,
+    CogVideoXAttnProcessor2_0,
+    FusedCogVideoXAttnProcessor2_0,
+)
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
@@ -35,6 +40,168 @@ from .modeling_common import FeedForward
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class BaseCogVideoXAttnProcessor:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    is_fused = False  # whether fused QKV projections are in use; checked in `get_projections`
    compatible_backends = []

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "CogVideoXAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
    def get_projections(self, attn, hidden_states, encoder_hidden_states=None):
        """Public method to get projections based on whether we're using fused mode or not."""
        if self.is_fused and hasattr(attn, "to_qkv"):
            return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)

        return self._get_projections(attn, hidden_states, encoder_hidden_states)

    def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
        """Get projections using standard separate projection matrices."""
        # Standard separate projections
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # Handle encoder projections if present
        encoder_projections = None
        if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
            encoder_query = attn.add_q_proj(encoder_hidden_states)
            encoder_key = attn.add_k_proj(encoder_hidden_states)
            encoder_value = attn.add_v_proj(encoder_hidden_states)
            encoder_projections = (encoder_query, encoder_key, encoder_value)

        return query, key, value, encoder_projections

    def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
        """Get projections using fused QKV projection matrices."""
        # Fused QKV projection
        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        # Handle encoder projections if present
        encoder_projections = None
        if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
            split_size = encoder_qkv.shape[-1] // 3
            encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
            encoder_projections = (encoder_query, encoder_key, encoder_value)

        return query, key, value, encoder_projections
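As a hedged illustration of why the fused path can simply split one projection's output: if the fused weight is the row-wise concatenation of the separate `to_q`, `to_k`, `to_v` weights (which is what a fused `to_qkv` projection is assumed to hold), splitting recovers the individual projections exactly. A small standalone check:

import torch
from torch import nn

dim = 8
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))

# Build a fused projection whose weight stacks the three separate weights.
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))

x = torch.randn(2, 5, dim)
qkv = to_qkv(x)
q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)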
    def _compute_attention(self, query, key, value, attention_mask=None):
        """Computes the attention. Can be overridden by hardware-specific implementations."""
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query, key, value, _ = self.get_projections(attn, hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from ..embeddings import apply_rotary_emb

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        hidden_states = self._compute_attention(query, key, value, attention_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
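In miniature, the sequence bookkeeping the processor relies on: text tokens are prepended, RoPE only touches the image/video positions, and the joint output is split back in the same order. A toy, self-contained shape check (dimensions are arbitrary):

import torch

text_len, image_len, dim = 4, 6, 8
encoder_hidden_states = torch.randn(1, text_len, dim)  # text tokens
hidden_states = torch.randn(1, image_len, dim)         # video tokens

joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)  # text first
assert joint.size(1) == text_len + image_len

# image_rotary_emb would only touch joint[:, text_len:]; the text block is left as-is.
text_out, image_out = joint.split([text_len, joint.size(1) - text_len], dim=1)
assert text_out.size(1) == text_len and image_out.size(1) == image_len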
class CogVideoXAttnProcessorSDPA(BaseCogVideoXAttnProcessor):
    compatible_backends = ["cuda", "cpu", "xpu"]
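Since `_compute_attention` isolates the actual SDPA call, backend-specific variants like the class above only need to declare their `compatible_backends` or override that single hook. A hypothetical subclass (not part of this commit), assuming `BaseCogVideoXAttnProcessor` above is in scope:

import math


class NaiveCogVideoXAttnProcessor(BaseCogVideoXAttnProcessor):
    """Hypothetical subclass: swaps SDPA for an explicit softmax(QK^T / sqrt(d)) V."""

    compatible_backends = ["cpu"]

    def _compute_attention(self, query, key, value, attention_mask=None):
        scores = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
        if attention_mask is not None:
            scores = scores + attention_mask  # additive mask, same convention as SDPA
        return scores.softmax(dim=-1) @ value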
class CogVideoXAttention(nn.Module, AttentionModuleMixin):
    default_processor_cls = CogVideoXAttnProcessorSDPA
    _available_processors = [CogVideoXAttnProcessorSDPA]

    def __init__(
        self, query_dim, dim_head, heads, dropout=0.0, qk_norm=None, eps=1e-6, bias=False, out_bias=False
    ) -> None:
        super().__init__()  # required before registering submodules on nn.Module
        self.query_dim = query_dim
        self.out_dim = query_dim
        self.inner_dim = dim_head * heads
        self.heads = heads  # read by the processor as `attn.heads`
        self.use_bias = bias

        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

        if qk_norm is None:
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "layer_norm":
            self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=True)
            self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=True)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        self.set_processor(self.default_processor_cls())

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
        )
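A hedged usage sketch for the new attention module, assuming the definitions above are in scope along with this branch's `AttentionModuleMixin` (whose `set_processor`/`processor` attributes are relied on here); all names and shapes are illustrative only:

import torch

attn = CogVideoXAttention(query_dim=64, dim_head=16, heads=4, qk_norm="layer_norm")
hidden_states = torch.randn(2, 330, 64)         # video tokens
encoder_hidden_states = torch.randn(2, 22, 64)  # text tokens

video_out, text_out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
assert video_out.shape == hidden_states.shape
assert text_out.shape == encoder_hidden_states.shape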
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
@@ -41,7 +41,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class BaseFluxAttnProcessor:
    """Base attention processor for Flux models with common functionality."""

    is_fused = False
    compatible_backends = []

    def __init__(self):
@@ -377,13 +376,6 @@ class FluxIPAdapterAttnProcessorSDPA(torch.nn.Module):
@maybe_allow_in_graph
class FluxAttention(nn.Module, AttentionModuleMixin):
    """
    Specialized attention implementation for Flux models.

    This attention module provides an optimized implementation for Flux models,
    with support for RMSNorm, rotary embeddings, and added key/value projections.
    """

    _default_processor_cls = FluxAttnProcessorSDPA
    _available_processors = [
        FluxAttnProcessorSDPA,