diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5cf608eff8..80aed8d123 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5115,6 +5115,54 @@ class PAGIdentitySanaLinearAttnProcessorSDPA: # Deprecated classes for backward compatibility +class AttnProcessor: + def __new__(cls, *args, **kwargs): + deprecation_message = "`AttnProcessor` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`" + deprecate("AttnProcessor", "1.0.0", deprecation_message) + + return AttnProcessorSDPA(*args, **kwargs) + + +class AttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = "`AttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`" + deprecate("AttnProcessor2_0", "1.0.0", deprecation_message) + + return AttnProcessorSDPA(*args, **kwargs) + + +class AttnAddedKVProcessor: + def __new__(cls, *args, **kwargs): + deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`" + deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message) + + return AttnAddedKVProcessorSDPA(*args, **kwargs) + + +class AttnAddedKVProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`" + deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message) + + return AttnAddedKVProcessorSDPA(*args, **kwargs) + + +class AllegroAttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = "`AllegroAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AllegroAttnProcessorSDPA`" + deprecate("AllegroAttnProcessor2_0", "1.0.0", deprecation_message) + + return AllegroAttnProcessorSDPA(*args, **kwargs) + + +class AuraFlowAttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = "`AuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessorSDPA`" + deprecate("AuraFlowAttnProcessor2_0", "1.0.0", deprecation_message) + + return AuraFlowAttnProcessorSDPA(*args, **kwargs) + + class MochiAttnProcessor2_0: def __new__(cls, *args, **kwargs): deprecation_message = "`MochiAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `MochiAttnProcessorSDPA`" @@ -5150,10 +5198,13 @@ class FluxSingleAttnProcessor2_0: Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ - def __init__(self): + def __new__(cls, *args, **kwargs): deprecation_message = "`FluxSingleAttnProcessorSDPA` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead." - deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message) - super().__init__() + deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message) + + from .transformers.transformer_allegro import FluxAttnProcessorSDPA + + return FluxAttnProcessorSDPA(*args, **kwargs) class FusedAttnProcessor2_0: @@ -5164,26 +5215,11 @@ class FusedAttnProcessor2_0: return AttnProcessorSDPA(*args, **kwargs) -class AttnAddedKVProcessor: - def __new__(cls, *args, **kwargs): - deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`" - deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message) - - return AttnAddedKVProcessorSDPA(*args, **kwargs) - - -class AttnAddedKVProcessor2_0: - def __new__(cls, *args, **kwargs): - deprecation_message = "`AttnAddedKVAttentionProcessor` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`" - deprecate("AttnAddedKVAttentionProcessor", "1.0.0", deprecation_message) - - return AttnAddedKVProcessorSDPA(*args, **kwargs) - - class JointAttnProcessor2_0: def __new__(cls, *args, **kwargs): deprecation_message = "`JointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `JointAttnProcessorSDPA`" deprecate("JointAttnProcessor2_0", "1.0.0", deprecation_message) + return JointAttnProcessorSDPA(*args, **kwargs) @@ -5191,6 +5227,7 @@ class PAGJointAttnProcessor2_0: def __new__(cls, *args, **kwargs): deprecation_message = "`PAGJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGJointAttnProcessorSDPA`" deprecate("PAGJointAttnProcessor2_0", "1.0.0", deprecation_message) + return PAGJointAttnProcessorSDPA(*args, **kwargs) @@ -5198,6 +5235,7 @@ class PAGCFGJointAttnProcessor2_0: def __new__(cls, *args, **kwargs): deprecation_message = "`PAGCFGJointAttnProcessor2_0 is deprecated and this will be removed in a future version. Please use `PAGCFGJointAttnProcessorSDPA`" deprecate("PAGCFGJointAttnProcessor2_0", "1.0.0", deprecation_message) + return PAGCFGJointAttnProcessorSDPA(*args, **kwargs) @@ -5205,27 +5243,15 @@ class FusedJointAttnProcessor2_0: def __new__(cls, *args, **kwargs): deprecation_message = "`FusedJointAttnProcessor2_0 is deprecated and this will be removed in a future version. Please use `JointAttnProcessorSDPA`" deprecate("FusedJointAttnProcessor2_0", "1.0.0", deprecation_message) + return JointAttnProcessorSDPA(*args, **kwargs) -class AllegroAttnProcessor2_0: - def __new__(cls, *args, **kwargs): - deprecation_message = "`AllegroAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AllegroAttnProcessorSDPA`" - deprecate("AllegroAttnProcessor2_0", "1.0.0", deprecation_message) - return AllegroAttnProcessorSDPA(*args, **kwargs) - - -class AuraFlowAttnProcessor2_0: - def __new__(cls, *args, **kwargs): - deprecation_message = "`AuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessorSDPA`" - deprecate("AuraFlowAttnProcessor2_0", "1.0.0", deprecation_message) - return AuraFlowAttnProcessorSDPA(*args, **kwargs) - - class FusedAuraFlowAttnProcessor2_0: def __new__(cls, *args, **kwargs): deprecation_message = "`FusedAuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessorSDPA`" deprecate("FusedAuraFlowAttnProcessor2_0", "1.0.0", deprecation_message) + return AuraFlowAttnProcessorSDPA(*args, **kwargs) @@ -5256,22 +5282,6 @@ class FusedCogVideoXAttnProcessor2_0: return CogVideoXAttnProcessorSDPA(*args, **kwargs) -class AttnProcessor: - def __new__(cls, *args, **kwargs): - deprecation_message = "`AttnProcessor` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`" - deprecate("AttnProcessor", "1.0.0", deprecation_message) - - return AttnProcessorSDPA(*args, **kwargs) - - -class AttnProcessor2_0: - def __new__(cls, *args, **kwargs): - deprecation_message = "`AttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`" - deprecate("AttnProcessor2_0", "1.0.0", deprecation_message) - - return AttnProcessorSDPA(*args, **kwargs) - - class XLAFlashAttnProcessor2_0: def __new__(cls, *args, **kwargs): deprecation_message = "`XLAFlashAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFlashAttnProcessorSDPA`" @@ -5281,7 +5291,7 @@ class XLAFlashAttnProcessor2_0: class XLAFluxFlashAttnProcessor2_0: - def __init__(cls, *args, **kwargs): + def __new__(cls, *args, **kwargs): deprecation_message = "`XLAFluxFlashAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFluxFlashAttnProcessorSDPA`" deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message) @@ -5290,18 +5300,18 @@ class XLAFluxFlashAttnProcessor2_0: return FluxAttnProcessorXLA(*args, **kwargs) -class StableAudioAttnProcessor2_0(StableAudioAttnProcessorSDPA): - def __init__(self, *args, **kwargs): +class StableAudioAttnProcessor2_0: + def __new__(self, *args, **kwargs): deprecation_message = "`StableAudioAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `StableAudioAttnProcessorSDPA`" deprecate("StableAudioAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + return StableAudioAttnProcessorSDPA(*args, **kwargs) class HunyuanAttnProcessor2_0(HunyuanAttnProcessorSDPA): - def __init__(self, *args, **kwargs): + def __new__(cls, *args, **kwargs): deprecation_message = "`HunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `HunyuanAttnProcessorSDPA`" deprecate("HunyuanAttnProcessor2_0", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + return HunyuanAttnProcessorSDPA(*args, **kwargs) class FusedHunyuanAttnProcessor2_0(FusedHunyuanAttnProcessorSDPA): diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index e31419c511..5e0965572d 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -23,7 +23,12 @@ from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention -from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 +from ..attention_processor import ( + AttentionModuleMixin, + AttentionProcessor, + CogVideoXAttnProcessor2_0, + FusedCogVideoXAttnProcessor2_0, +) from ..cache_utils import CacheMixin from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput @@ -35,6 +40,168 @@ from .modeling_common import FeedForward logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class BaseCogVideoXAttnProcessor: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + compatible_backends = [] + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "CogVideoXAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def get_projections(self, attn, hidden_states, encoder_hidden_states=None): + """Public method to get projections based on whether we're using fused mode or not.""" + if self.is_fused and hasattr(attn, "to_qkv"): + return self._get_fused_projections(attn, hidden_states, encoder_hidden_states) + + return self._get_projections(attn, hidden_states, encoder_hidden_states) + + def _get_projections(self, attn, hidden_states, encoder_hidden_states=None): + """Get projections using standard separate projection matrices.""" + # Standard separate projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + # Handle encoder projections if present + encoder_projections = None + if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"): + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_projections = (encoder_query, encoder_key, encoder_value) + + return query, key, value, encoder_projections + + def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None): + """Get projections using fused QKV projection matrices.""" + # Fused QKV projection + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + # Handle encoder projections if present + encoder_projections = None + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1) + encoder_projections = (encoder_query, encoder_key, encoder_value) + + return query, key, value, encoder_projections + + def _compute_attention(self, query, key, value, attention_mask=None): + """Computes the attention. Can be overridden by hardware-specific implementations.""" + return F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = hidden_states.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query, key, value, _ = self.get_projections(attn, hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + hidden_states = self._compute_attention(query, key, value, attention_mask=attention_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +class CogVideoXAttnProcessorSDPA(BaseCogVideoXAttnProcessor): + compatible_backends = ["cuda", "cpu", "xpu"] + + +class CogVideoXAttention(nn.Module, AttentionModuleMixin): + default_processor_cls = CogVideoXAttnProcessorSDPA + _available_processors = [CogVideoXAttnProcessorSDPA] + + def __init__( + self, query_dim, dim_head, heads, dropout=0.0, qk_norm=None, eps=1e-6, bias=False, out_bias=False + ) -> None: + self.query_dim = query_dim + self.out_dim = query_dim + self.inner_dim = dim_head * heads + self.use_bias = bias + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + self.set_processor(self.default_processor_cls()) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + + @maybe_allow_in_graph class CogVideoXBlock(nn.Module): r""" diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index f5e79beb51..7abd0d6d10 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -41,7 +41,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name class BaseFluxAttnProcessor: """Base attention processor for Flux models with common functionality.""" - is_fused = False compatible_backends = [] def __init__(self): @@ -377,13 +376,6 @@ class FluxIPAdapterAttnProcessorSDPA(torch.nn.Module): @maybe_allow_in_graph class FluxAttention(nn.Module, AttentionModuleMixin): - """ - Specialized attention implementation for Flux models. - - This attention module provides optimized implementation for Flux models, - with support for RMSNorm, rotary embeddings, and added key/value projections. - """ - _default_processor_cls = FluxAttnProcessorSDPA _available_processors = [ FluxAttnProcessorSDPA,