mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
DN6
2025-05-05 23:23:44 +05:30
parent be84828840
commit ad4e8be19a
7 changed files with 393 additions and 314 deletions


@@ -46,6 +46,63 @@ else:
XLA_AVAILABLE = False
class AttnProcessorMixin:
"""Attention processor used typically in processing Aura Flow."""
def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
"""Get projections using standard separate projection matrices."""
# Standard separate projections
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# Handle encoder projections if present
encoder_projections = None
if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_projections = (encoder_query, encoder_key, encoder_value)
return query, key, value, encoder_projections
def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
"""Get projections using fused QKV projection matrices."""
# Fused QKV projection
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
# Handle encoder projections if present
encoder_projections = None
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
split_size = encoder_qkv.shape[-1] // 3
encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
encoder_projections = (encoder_query, encoder_key, encoder_value)
return query, key, value, encoder_projections
def get_projections(self, attn, hidden_states, encoder_hidden_states=None):
"""Public method to get projections based on whether we're using fused mode or not."""
if self.is_fused and hasattr(attn, "to_qkv"):
return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
return self._get_projections(attn, hidden_states, encoder_hidden_states)
def attention_fn(self, query, key, value, scale=None, attention_mask=None):
"""Computes the attention. Can be overridden by hardware-specific implementations."""
return F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, scale=scale, dropout_p=0.0, is_causal=False
)
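The fused path above assumes `to_qkv` stores the three projection weights concatenated row-wise, so a single matmul followed by a three-way split reproduces the separate `to_q`/`to_k`/`to_v` outputs (this is what diffusers' `fuse_projections` builds). A self-contained check with made-up sizes, showing why `torch.split(qkv, qkv.shape[-1] // 3, dim=-1)` recovers the same tensors:

import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8
x = torch.randn(2, 5, dim)  # (batch, seq_len, dim)

# Separate projections, as in _get_projections
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))
q, k, v = to_q(x), to_k(x), to_v(x)

# Fused projection, as in _get_fused_projections: weights concatenated along the output dim
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))

qkv = to_qkv(x)
q2, k2, v2 = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
assert torch.allclose(q, q2, atol=1e-6) and torch.allclose(k, k2, atol=1e-6) and torch.allclose(v, v2, atol=1e-6)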
class Attention(nn.Module, AttentionModuleMixin):
default_processor_class = AttnProcessorSDPA
_available_processors = []
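A concrete backend-specific processor only needs to override `attention_fn`; projection handling comes from `AttnProcessorMixin` above. A minimal, hypothetical sketch (standard SDPA shown where a backend kernel would go), attached through `set_processor`:

import torch.nn.functional as F

class MyBackendAttnProcessor(AttnProcessorMixin):
    # Assumption: processors advertise supported backends this way, mirroring
    # `compatible_backends` on the processors later in this diff.
    compatible_backends = ["cuda"]

    def attention_fn(self, query, key, value, scale=None, attention_mask=None):
        # Swap in the backend kernel here; SDPA is the reference behaviour.
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=scale, dropout_p=0.0, is_causal=False
        )

# attn.set_processor(MyBackendAttnProcessor()) would attach it to an Attention module,
# while `default_processor_class` keeps the stock SDPA processor as the fallback.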
@@ -1292,99 +1349,6 @@ class AllegroAttnProcessorSDPA:
return hidden_states
class AuraFlowAttnProcessorSDPA:
"""Attention processor used typically in processing Aura Flow."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
raise ImportError(
"AuraFlowAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
*args,
**kwargs,
) -> torch.FloatTensor:
batch_size = hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# `context` projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# Reshape.
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)
# Apply QK norm.
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Concatenate the projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Attention.
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Split the attention outputs.
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, encoder_hidden_states.shape[1] :],
hidden_states[:, : encoder_hidden_states.shape[1]],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FusedAuraFlowAttnProcessorSDPA:
"""Attention processor used typically in processing Aura Flow with fused projections."""
@@ -2335,104 +2299,6 @@ class StableAudioAttnProcessorSDPA:
return hidden_states
class HunyuanAttnProcessorSDPA:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class FusedHunyuanAttnProcessorSDPA:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused


@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from typing import Optional, Union
from huggingface_hub.utils import validate_hf_hub_args
from huggingface_hub.utils import EntryNotFoundError, validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
class AutoModel(ConfigMixin):
@@ -153,17 +153,39 @@ class AutoModel(ConfigMixin):
"token": token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
library = None
orig_class_name = None
from diffusers import pipelines
library = importlib.import_module("diffusers")
# Always attempt to fetch model_index.json first
try:
cls.config_name = "model_index.json"
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
model_cls = getattr(library, orig_class_name, None)
if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
if subfolder is not None and subfolder in config:
library, orig_class_name = config[subfolder]
except (OSError, EntryNotFoundError) as e:
logger.debug(e)
# Unable to load from model_index.json so fallback to loading from config
if library is None and orig_class_name is None:
cls.config_name = "config.json"
load_config_kwargs.update({"subfolder": subfolder})
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
library = "diffusers"
model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=orig_class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=pipelines,
is_pipeline_module=hasattr(pipelines, library),
)
kwargs = {**load_config_kwargs, **kwargs}
return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
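The new resolution order tries the pipeline-level `model_index.json` first, so a `subfolder` can be mapped to the right library and class, and only falls back to the component's own `config.json` when that file is absent. A typical call, with a placeholder repository id:

import torch
from diffusers import AutoModel

# "some-org/some-pipeline" is a placeholder; any repo with a model_index.json resolves the same way.
transformer = AutoModel.from_pretrained(
    "some-org/some-pipeline", subfolder="transformer", torch_dtype=torch.bfloat16
)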


@@ -23,8 +23,9 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin
from ..attention import Attention, AttentionMixin, AttentionModuleMixin
from ..attention_processor import (
AttnProcessorMixin,
AuraFlowAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps
@@ -133,6 +134,98 @@ class AuraFlowPreFinalBlock(nn.Module):
return x
class AuraFlowAttnProcessorSDPA(AttnProcessorMixin):
"""Attention processor used typically in processing Aura Flow."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
raise ImportError(
"AuraFlowAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
)
def __call__(
self,
attn: AttentionModuleMixin,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
*args,
**kwargs,
) -> torch.FloatTensor:
batch_size = hidden_states.shape[0]
query, key, value, encoder_projections = self.get_projections(attn, hidden_states, encoder_hidden_states)
# `context` projections.
if encoder_projections is not None:
encoder_hidden_states_query_proj, encoder_hidden_states_key_proj, encoder_hidden_states_value_proj = (
encoder_projections
)
# Reshape.
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)
# Apply QK norm.
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Concatenate the projections.
if encoder_projections is not None:
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Attention.
hidden_states = self.attention_fn(query, key, value, scale=attn.scale)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Split the attention outputs.
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, encoder_hidden_states.shape[1] :],
hidden_states[:, : encoder_hidden_states.shape[1]],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
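The ordering of the concatenation is what makes the final split work: context (text) projections go in front of the sample (image) projections, so after attention the first `encoder_hidden_states.shape[1]` tokens belong to the text stream and the rest to the image stream. A shape walkthrough with made-up sizes:

import torch

batch, text_len, image_len, heads, head_dim = 2, 77, 1024, 12, 64
text = torch.randn(batch, text_len, heads * head_dim)
image = torch.randn(batch, image_len, heads * head_dim)

joint = torch.cat([text, image], dim=1)      # (2, 1101, 768), text tokens first
out = joint                                  # stand-in for the attention output
image_out = out[:, text_len:]                # -> returned as hidden_states
text_out = out[:, :text_len]                 # -> returned as encoder_hidden_states
assert image_out.shape == (batch, image_len, heads * head_dim)
assert text_out.shape == (batch, text_len, heads * head_dim)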
class AuraFlowAttention(Attention):
default_processor_cls = AuraFlowAttnProcessorSDPA
_available_processors = [AuraFlowAttnProcessorSDPA]
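Because `AuraFlowAttention` pins `default_processor_cls`, constructing it without an explicit processor should attach `AuraFlowAttnProcessorSDPA` automatically. A hedged construction sketch with placeholder sizes, mirroring the block constructors below (which still build an `AuraFlowAttnProcessor2_0` explicitly):

dim, attention_head_dim, num_attention_heads = 3072, 256, 12  # placeholders
attn = AuraFlowAttention(
    query_dim=dim,
    cross_attention_dim=None,
    dim_head=attention_head_dim,
    heads=num_attention_heads,
)
# attn.processor is expected to be an AuraFlowAttnProcessorSDPA instance by default.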
@maybe_allow_in_graph
class AuraFlowSingleTransformerBlock(nn.Module):
"""Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT."""
@@ -143,7 +236,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
processor = AuraFlowAttnProcessor2_0()
self.attn = Attention(
self.attn = AuraFlowAttention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
@@ -206,7 +299,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
processor = AuraFlowAttnProcessor2_0()
self.attn = Attention(
self.attn = AuraFlowAttention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,


@@ -22,10 +22,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin
from ..attention_processor import (
AttentionModuleMixin,
)
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_processor import Attention, AttnProcessorMixin
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
@@ -37,13 +35,13 @@ from .modeling_common import FeedForward
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class BaseCogVideoXAttnProcessor:
class CogVideoXAttnProcessorSDPA(AttnProcessorMixin):
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
compatible_backends = []
compatible_backends = ["cuda", "cpu", "xpu"]
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -51,56 +49,9 @@ class BaseCogVideoXAttnProcessor:
"CogVideoXAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def get_projections(self, attn, hidden_states, encoder_hidden_states=None):
"""Public method to get projections based on whether we're using fused mode or not."""
if self.is_fused and hasattr(attn, "to_qkv"):
return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
return self._get_projections(attn, hidden_states, encoder_hidden_states)
def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
"""Get projections using standard separate projection matrices."""
# Standard separate projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# Handle encoder projections if present
encoder_projections = None
if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_projections = (encoder_query, encoder_key, encoder_value)
return query, key, value, encoder_projections
def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
"""Get projections using fused QKV projection matrices."""
# Fused QKV projection
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
# Handle encoder projections if present
encoder_projections = None
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
split_size = encoder_qkv.shape[-1] // 3
encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
encoder_projections = (encoder_query, encoder_key, encoder_value)
return query, key, value, encoder_projections
def _compute_attention(self, query, key, value, attention_mask=None):
"""Computes the attention. Can be overridden by hardware-specific implementations."""
return F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
def __call__(
self,
attn: CogVideoXAttention,
attn: AttentionModuleMixin,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -138,7 +89,7 @@ class BaseCogVideoXAttnProcessor:
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
hidden_states = self._compute_attention(query, key, value, attention_mask=attention_mask)
hidden_states = self.attention_fn(query, key, value, attention_mask=attention_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
@@ -152,52 +103,10 @@ class BaseCogVideoXAttnProcessor:
return hidden_states, encoder_hidden_states
class CogVideoXAttnProcessorSDPA(BaseCogVideoXAttnProcessor):
compatible_backends = ["cuda", "cpu", "xpu"]
class CogVideoXAttention(nn.Module, AttentionModuleMixin):
class CogVideoXAttention(Attention):
default_processor_cls = CogVideoXAttnProcessorSDPA
_available_processors = [CogVideoXAttnProcessorSDPA]
def __init__(
self, query_dim, dim_head, heads, dropout=0.0, qk_norm=None, eps=1e-6, bias=False, out_bias=False
) -> None:
self.query_dim = query_dim
self.out_dim = query_dim
self.inner_dim = dim_head * heads
self.use_bias = bias
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
if qk_norm is None:
self.norm_q = None
self.norm_k = None
elif qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
self.set_processor(self.default_processor_cls())
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
)
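The hand-rolled module above is dropped in favour of subclassing `Attention`, which already provides the q/k/v/out projections, optional QK layer norm, dropout, and processor plumbing. A hedged equivalent construction (these argument names exist on diffusers' `Attention`; the sizes are placeholders):

attn = CogVideoXAttention(
    query_dim=3072,
    dim_head=64,
    heads=48,
    qk_norm="layer_norm",
    eps=1e-6,
    bias=True,
    out_bias=True,
    dropout=0.0,
)
# The default CogVideoXAttnProcessorSDPA is attached via default_processor_cls,
# replacing the explicit set_processor call in the removed __init__.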
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):


@@ -22,8 +22,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin
from ..attention_processor import CogVideoXAttnProcessor2_0
from ..attention import Attention
from ..attention_processor import CogVideoXAttnProcessor2_0, CogVideoXAttnProcessorSDPA
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -229,6 +229,11 @@ class PerceiverCrossAttention(nn.Module):
return self.to_out(out)
class ConsisIDAttention(Attention):
default_processor_cls = CogVideoXAttnProcessorSDPA
_available_processors = [CogVideoXAttnProcessorSDPA]
@maybe_allow_in_graph
class ConsisIDBlock(nn.Module):
r"""
@@ -287,7 +292,7 @@ class ConsisIDBlock(nn.Module):
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
self.attn1 = ConsisIDAttention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,


@@ -19,8 +19,8 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin
from ..attention_processor import HunyuanAttnProcessor2_0
from ..attention import Attention, AttentionMixin, AttnProcessorMixin
from ..attention_processor import HunyuanAttnProcessor2_0, HunyuanAttnProcessorSDPA
from ..embeddings import (
HunyuanCombinedTimestepTextSizeStyleEmbedding,
PatchEmbed,
@@ -56,6 +56,98 @@ class AdaLayerNormShift(nn.Module):
return x
class HunyuanAttnProcessorSDPA(AttnProcessorMixin):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: AttentionModuleMixin,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query, key, value, _ = self.get_projections(attn, hidden_states, encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = self.attention_fn(query, key, value, attention_mask=attention_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class HunyuanDiTAttention(Attention):
default_processor_cls = HunyuanAttnProcessorSDPA
_available_processors = [HunyuanAttnProcessorSDPA]
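With the processor bound through `default_processor_cls`, the block below can drop the explicit `processor=HunyuanAttnProcessor2_0()` argument (the removed line in the constructor diff); constructing the attention module is enough to get SDPA with QK norm and RoPE handling. A hedged sketch with placeholder sizes, following the `HunyuanDiTBlock` constructor:

dim, num_attention_heads = 1408, 16  # placeholders
attn1 = HunyuanDiTAttention(
    query_dim=dim,
    cross_attention_dim=None,
    dim_head=dim // num_attention_heads,
    heads=num_attention_heads,
    qk_norm="layer_norm",
    eps=1e-6,
    bias=True,
)
# attn1.processor is expected to be a HunyuanAttnProcessorSDPA instance.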
@maybe_allow_in_graph
class HunyuanDiTBlock(nn.Module):
r"""
@@ -111,7 +203,7 @@ class HunyuanDiTBlock(nn.Module):
# 1. Self-Attn
self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
self.attn1 = HunyuanDiTAttention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // num_attention_heads,
@@ -119,7 +211,6 @@ class HunyuanDiTBlock(nn.Module):
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=True,
processor=HunyuanAttnProcessor2_0(),
)
# 2. Cross-Attn


@@ -20,7 +20,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import LuminaFeedForward
from ..attention_processor import Attention, LuminaAttnProcessor2_0
from ..attention_processor import Attention, AttentionMixin, AttnProcessorMixin
from ..embeddings import (
LuminaCombinedTimestepCaptionEmbedding,
LuminaPatchEmbed,
@@ -33,6 +33,101 @@ from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LuminaAttnProcessorSDPA(AttnProcessorMixin):
compatible_backends = ["cuda", "cpu", "xpu"]
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
query_rotary_emb: Optional[torch.Tensor] = None,
key_rotary_emb: Optional[torch.Tensor] = None,
base_sequence_length: Optional[int] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
query, key, value, _ = self.get_projections(attn, hidden_states, encoder_hidden_states)
query_dim = query.shape[-1]
inner_dim = key.shape[-1]
head_dim = query_dim // attn.heads
dtype = query.dtype
# Get key-value heads
kv_heads = inner_dim // head_dim
# Apply Query-Key Norm if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, kv_heads, head_dim)
value = value.view(batch_size, -1, kv_heads, head_dim)
# Apply RoPE if needed
if query_rotary_emb is not None:
query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
if key_rotary_emb is not None:
key = apply_rotary_emb(key, key_rotary_emb, use_real=False)
query, key = query.to(dtype), key.to(dtype)
# Apply proportional attention if enabled
if key_rotary_emb is None:
softmax_scale = None
else:
if base_sequence_length is not None:
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
else:
softmax_scale = attn.scale
# perform Grouped-query Attention (GQA)
n_rep = attn.heads // kv_heads
if n_rep >= 1:
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, scale=softmax_scale
)
hidden_states = hidden_states.transpose(1, 2).to(dtype)
return hidden_states
class LuminaNextAttention(Attention):
default_processor_cls = LuminaAttnProcessorSDPA
_available_processors = [LuminaAttnProcessorSDPA]
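Two details in this processor are worth a worked example: key/value heads are repeated `n_rep = heads // kv_heads` times to match the query heads (grouped-query attention), and when `base_sequence_length` is set, the softmax scale grows with sequence length as `sqrt(log_base(seq_len)) * attn.scale` (proportional attention; `attn.scale` is `1 / sqrt(head_dim)` on diffusers' `Attention`). A self-contained sketch with made-up sizes:

import math
import torch

heads, kv_heads, head_dim, seq_len, base_len = 16, 4, 64, 4096, 256
key = torch.randn(2, seq_len, kv_heads, head_dim)

# GQA: repeat each key/value head n_rep times along the head axis
n_rep = heads // kv_heads  # 4
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
assert key.shape == (2, seq_len, heads, head_dim)

# Proportional attention: the scale grows once seq_len exceeds the base length
attn_scale = 1 / math.sqrt(head_dim)
softmax_scale = math.sqrt(math.log(seq_len, base_len)) * attn_scale
# At seq_len == base_len this reduces to attn_scale exactly, since log_base(base) == 1.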
class LuminaNextDiTBlock(nn.Module):
"""
A LuminaNextDiTBlock for LuminaNextDiT2DModel.
@@ -68,7 +163,7 @@ class LuminaNextDiTBlock(nn.Module):
self.gate = nn.Parameter(torch.zeros([num_attention_heads]))
# Self-attention
self.attn1 = Attention(
self.attn1 = LuminaNextAttention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // num_attention_heads,
@@ -78,12 +173,11 @@ class LuminaNextDiTBlock(nn.Module):
eps=1e-5,
bias=False,
out_bias=False,
processor=LuminaAttnProcessor2_0(),
)
self.attn1.to_out = nn.Identity()
# Cross-attention
self.attn2 = Attention(
self.attn2 = LuminaNextAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
dim_head=dim // num_attention_heads,
@@ -93,7 +187,6 @@ class LuminaNextDiTBlock(nn.Module):
eps=1e-5,
bias=False,
out_bias=False,
processor=LuminaAttnProcessor2_0(),
)
self.feed_forward = LuminaFeedForward(
@@ -175,7 +268,7 @@ class LuminaNextDiTBlock(nn.Module):
return hidden_states
class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
class LuminaNextDiT2DModel(ModelMixin, ConfigMixin, AttentionMixin):
"""
LuminaNextDiT: Diffusion model with a Transformer backbone.
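Adding `AttentionMixin` to the model's bases is what exposes model-level processor management. Assuming the mixin keeps the familiar diffusers surface (an `attn_processors` mapping and `set_attn_processor`), swapping every block's processor would look like this hypothetical snippet:

# Placeholder repo id; assumes AttentionMixin provides attn_processors / set_attn_processor.
model = LuminaNextDiT2DModel.from_pretrained("some-org/lumina-next", subfolder="transformer")
print({type(proc).__name__ for proc in model.attn_processors.values()})
model.set_attn_processor(LuminaAttnProcessorSDPA())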