diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 647ad41a6b..2df37dbe0e 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -46,6 +46,63 @@ else:
     XLA_AVAILABLE = False
 
 
+class AttnProcessorMixin:
+    """Mixin providing shared QKV projection helpers and a default attention function for attention processors."""
+
+    def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
+        """Get projections using standard separate projection matrices."""
+        # Standard separate projections
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        # Handle encoder projections if present
+        encoder_projections = None
+        if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
+            encoder_query = attn.add_q_proj(encoder_hidden_states)
+            encoder_key = attn.add_k_proj(encoder_hidden_states)
+            encoder_value = attn.add_v_proj(encoder_hidden_states)
+            encoder_projections = (encoder_query, encoder_key, encoder_value)
+
+        return query, key, value, encoder_projections
+
+    def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
+        """Get projections using fused QKV projection matrices."""
+        # Fused QKV projection
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        # Handle encoder projections if present
+        encoder_projections = None
+        if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
+            encoder_projections = (encoder_query, encoder_key, encoder_value)
+
+        return query, key, value, encoder_projections
+
+    def get_projections(self, attn, hidden_states, encoder_hidden_states=None):
+        """Public method to get projections based on whether we're using fused mode or not."""
+        if getattr(self, "is_fused", False) and hasattr(attn, "to_qkv"):
+            return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
+
+        return self._get_projections(attn, hidden_states, encoder_hidden_states)
+
+    def attention_fn(self, query, key, value, scale=None, attention_mask=None):
+        """Computes the attention. Can be overridden by hardware-specific implementations."""
+        return F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, scale=scale, dropout_p=0.0, is_causal=False
+        )
+
+
 class Attention(nn.Module, AttentionModuleMixin):
     default_processor_class = AttnProcessorSDPA
     _available_processors = []
@@ -1292,99 +1349,6 @@ class AllegroAttnProcessorSDPA:
         return hidden_states
 
 
-class AuraFlowAttnProcessorSDPA:
-    """Attention processor used typically in processing Aura Flow."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
-            raise ImportError(
-                "AuraFlowAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. 
" - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - *args, - **kwargs, - ) -> torch.FloatTensor: - batch_size = hidden_states.shape[0] - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - # `context` projections. - if encoder_hidden_states is not None: - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - # Reshape. - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim) - key = key.view(batch_size, -1, attn.heads, head_dim) - value = value.view(batch_size, -1, attn.heads, head_dim) - - # Apply QK norm. - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Concatenate the projections. - if encoder_hidden_states is not None: - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) - - query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - # Attention. - hidden_states = F.scaled_dot_product_attention( - query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # Split the attention outputs. - if encoder_hidden_states is not None: - hidden_states, encoder_hidden_states = ( - hidden_states[:, encoder_hidden_states.shape[1] :], - hidden_states[:, : encoder_hidden_states.shape[1]], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - if encoder_hidden_states is not None: - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - if encoder_hidden_states is not None: - return hidden_states, encoder_hidden_states - else: - return hidden_states - - class FusedAuraFlowAttnProcessorSDPA: """Attention processor used typically in processing Aura Flow with fused projections.""" @@ -2335,104 +2299,6 @@ class StableAudioAttnProcessorSDPA: return hidden_states -class HunyuanAttnProcessorSDPA: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector. 
- """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - from .embeddings import apply_rotary_emb - - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - if not attn.is_cross_attention: - key = apply_rotary_emb(key, image_rotary_emb) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - class FusedHunyuanAttnProcessorSDPA: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 1b742463aa..b91ea32412 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -12,13 +12,13 @@ # See 
the License for the specific language governing permissions and # limitations under the License. -import importlib import os from typing import Optional, Union -from huggingface_hub.utils import validate_hf_hub_args +from huggingface_hub.utils import EntryNotFoundError, validate_hf_hub_args from ..configuration_utils import ConfigMixin +from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates class AutoModel(ConfigMixin): @@ -153,17 +153,39 @@ class AutoModel(ConfigMixin): "token": token, "local_files_only": local_files_only, "revision": revision, - "subfolder": subfolder, } - config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) - orig_class_name = config["_class_name"] + library = None + orig_class_name = None + from diffusers import pipelines - library = importlib.import_module("diffusers") + # Always attempt to fetch model_index.json first + try: + cls.config_name = "model_index.json" + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) - model_cls = getattr(library, orig_class_name, None) - if model_cls is None: - raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.") + if subfolder is not None and subfolder in config: + library, orig_class_name = config[subfolder] + + except (OSError, EntryNotFoundError) as e: + logger.debug(e) + + # Unable to load from model_index.json so fallback to loading from config + if library is None and orig_class_name is None: + cls.config_name = "config.json" + load_config_kwargs.update({"subfolder": subfolder}) + + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) + orig_class_name = config["_class_name"] + library = "diffusers" + + model_cls, _ = get_class_obj_and_candidates( + library_name=library, + class_name=orig_class_name, + importable_classes=ALL_IMPORTABLE_CLASSES, + pipelines=pipelines, + is_pipeline_module=hasattr(pipelines, library), + ) kwargs = {**load_config_kwargs, **kwargs} return model_cls.from_pretrained(pretrained_model_or_path, **kwargs) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index e78f77ebda..8a55f1a25e 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -23,8 +23,9 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention, AttentionMixin +from ..attention import Attention, AttentionMixin, AttentionModuleMixin from ..attention_processor import ( + AttnProcessorMixin, AuraFlowAttnProcessor2_0, ) from ..embeddings import TimestepEmbedding, Timesteps @@ -133,6 +134,98 @@ class AuraFlowPreFinalBlock(nn.Module): return x +class AuraFlowAttnProcessorSDPA(AttnProcessorMixin): + """Attention processor used typically in processing Aura Flow.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): + raise ImportError( + "AuraFlowAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. 
" + ) + + def __call__( + self, + attn: AttentionModuleMixin, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + batch_size = hidden_states.shape[0] + query, key, value, encoder_projections = self.get_projections(attn, hidden_states, encoder_hidden_states) + + # `context` projections. + if encoder_projections is not None: + encoder_hidden_states_query_proj, encoder_hidden_states_key_proj, encoder_hidden_states_value_proj = ( + encoder_projections + ) + + # Reshape. + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + # Apply QK norm. + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Concatenate the projections. + if encoder_projections is not None: + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) + + query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Attention. + hidden_states = self.attention_fn(query, key, value, scale=attn.scale) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. 
+ if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, encoder_hidden_states.shape[1] :], + hidden_states[:, : encoder_hidden_states.shape[1]], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class AuraFlowAttention(Attention): + default_processor_cls = AuraFlowAttnProcessorSDPA + _available_processors = [AuraFlowAttnProcessorSDPA] + + @maybe_allow_in_graph class AuraFlowSingleTransformerBlock(nn.Module): """Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT.""" @@ -143,7 +236,7 @@ class AuraFlowSingleTransformerBlock(nn.Module): self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") processor = AuraFlowAttnProcessor2_0() - self.attn = Attention( + self.attn = AuraFlowAttention( query_dim=dim, cross_attention_dim=None, dim_head=attention_head_dim, @@ -206,7 +299,7 @@ class AuraFlowJointTransformerBlock(nn.Module): self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") processor = AuraFlowAttnProcessor2_0() - self.attn = Attention( + self.attn = AuraFlowAttention( query_dim=dim, cross_attention_dim=None, added_kv_proj_dim=dim, diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 0e7da957c5..699e8c9a95 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -22,10 +22,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionMixin -from ..attention_processor import ( - AttentionModuleMixin, -) +from ..attention import AttentionMixin, AttentionModuleMixin +from ..attention_processor import Attention, AttnProcessorMixin from ..cache_utils import CacheMixin from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput @@ -37,13 +35,13 @@ from .modeling_common import FeedForward logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class BaseCogVideoXAttnProcessor: +class CogVideoXAttnProcessorSDPA(AttnProcessorMixin): r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on query and key vectors, but does not include spatial normalization. """ - compatible_backends = [] + compatible_backends = ["cuda", "cpu", "xpu"] def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -51,56 +49,9 @@ class BaseCogVideoXAttnProcessor: "CogVideoXAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
) - def get_projections(self, attn, hidden_states, encoder_hidden_states=None): - """Public method to get projections based on whether we're using fused mode or not.""" - if self.is_fused and hasattr(attn, "to_qkv"): - return self._get_fused_projections(attn, hidden_states, encoder_hidden_states) - - return self._get_projections(attn, hidden_states, encoder_hidden_states) - - def _get_projections(self, attn, hidden_states, encoder_hidden_states=None): - """Get projections using standard separate projection matrices.""" - # Standard separate projections - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - # Handle encoder projections if present - encoder_projections = None - if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"): - encoder_query = attn.add_q_proj(encoder_hidden_states) - encoder_key = attn.add_k_proj(encoder_hidden_states) - encoder_value = attn.add_v_proj(encoder_hidden_states) - encoder_projections = (encoder_query, encoder_key, encoder_value) - - return query, key, value, encoder_projections - - def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None): - """Get projections using fused QKV projection matrices.""" - # Fused QKV projection - qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - # Handle encoder projections if present - encoder_projections = None - if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - split_size = encoder_qkv.shape[-1] // 3 - encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1) - encoder_projections = (encoder_query, encoder_key, encoder_value) - - return query, key, value, encoder_projections - - def _compute_attention(self, query, key, value, attention_mask=None): - """Computes the attention. 
Can be overridden by hardware-specific implementations.""" - return F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - def __call__( self, - attn: CogVideoXAttention, + attn: AttentionModuleMixin, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -138,7 +89,7 @@ class BaseCogVideoXAttnProcessor: if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - hidden_states = self._compute_attention(query, key, value, attention_mask=attention_mask) + hidden_states = self.attention_fn(query, key, value, attention_mask=attention_mask) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) # linear proj @@ -152,52 +103,10 @@ class BaseCogVideoXAttnProcessor: return hidden_states, encoder_hidden_states -class CogVideoXAttnProcessorSDPA(BaseCogVideoXAttnProcessor): - compatible_backends = ["cuda", "cpu", "xpu"] - - -class CogVideoXAttention(nn.Module, AttentionModuleMixin): +class CogVideoXAttention(Attention): default_processor_cls = CogVideoXAttnProcessorSDPA _available_processors = [CogVideoXAttnProcessorSDPA] - def __init__( - self, query_dim, dim_head, heads, dropout=0.0, qk_norm=None, eps=1e-6, bias=False, out_bias=False - ) -> None: - self.query_dim = query_dim - self.out_dim = query_dim - self.inner_dim = dim_head * heads - self.use_bias = bias - - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) - - if qk_norm is None: - self.norm_q = None - self.norm_k = None - elif qk_norm == "layer_norm": - self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) - self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) - - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - - self.set_processor(self.default_processor_cls()) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - ) - @maybe_allow_in_graph class CogVideoXBlock(nn.Module): diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py index 8fdad47838..f7bff1f608 100644 --- a/src/diffusers/models/transformers/consisid_transformer_3d.py +++ b/src/diffusers/models/transformers/consisid_transformer_3d.py @@ -22,8 +22,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention, AttentionMixin -from ..attention_processor import CogVideoXAttnProcessor2_0 +from ..attention import Attention +from ..attention_processor import CogVideoXAttnProcessor2_0, CogVideoXAttnProcessorSDPA from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -229,6 +229,11 @@ class 
PerceiverCrossAttention(nn.Module): return self.to_out(out) +class ConsisIDAttention(Attention): + default_processor_cls = CogVideoXAttnProcessorSDPA + _available_processors = [CogVideoXAttnProcessorSDPA] + + @maybe_allow_in_graph class ConsisIDBlock(nn.Module): r""" @@ -287,7 +292,7 @@ class ConsisIDBlock(nn.Module): # 1. Self Attention self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - self.attn1 = Attention( + self.attn1 = ConsisIDAttention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 78e929b485..b69ad2cad6 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -19,8 +19,8 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention, AttentionMixin -from ..attention_processor import HunyuanAttnProcessor2_0 +from ..attention import Attention, AttentionMixin, AttnProcessorMixin +from ..attention_processor import HunyuanAttnProcessor2_0, HunyuanAttnProcessorSDPA from ..embeddings import ( HunyuanCombinedTimestepTextSizeStyleEmbedding, PatchEmbed, @@ -56,6 +56,98 @@ class AdaLayerNormShift(nn.Module): return x +class HunyuanAttnProcessorSDPA(AttnProcessorMixin): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: AttentionModuleMixin, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query, key, value = self.get_projections(attn, hidden_states, encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = 
value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = self.attention_fn(query, key, value, attention_mask=attention_mask)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class HunyuanDiTAttention(Attention):
+    default_processor_cls = HunyuanAttnProcessorSDPA
+    _available_processors = [HunyuanAttnProcessorSDPA]
+
+
 @maybe_allow_in_graph
 class HunyuanDiTBlock(nn.Module):
     r"""
@@ -111,7 +203,7 @@ class HunyuanDiTBlock(nn.Module):
         # 1. Self-Attn
         self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
 
-        self.attn1 = Attention(
+        self.attn1 = HunyuanDiTAttention(
             query_dim=dim,
             cross_attention_dim=None,
             dim_head=dim // num_attention_heads,
@@ -119,7 +211,6 @@ class HunyuanDiTBlock(nn.Module):
             qk_norm="layer_norm" if qk_norm else None,
             eps=1e-6,
             bias=True,
-            processor=HunyuanAttnProcessor2_0(),
         )
 
         # 2. 
Cross-Attn diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py index 320950866c..cde6283321 100644 --- a/src/diffusers/models/transformers/lumina_nextdit2d.py +++ b/src/diffusers/models/transformers/lumina_nextdit2d.py @@ -20,7 +20,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ..attention import LuminaFeedForward -from ..attention_processor import Attention, LuminaAttnProcessor2_0 +from ..attention_processor import Attention, AttentionMixin, AttnProcessorMixin from ..embeddings import ( LuminaCombinedTimestepCaptionEmbedding, LuminaPatchEmbed, @@ -33,6 +33,101 @@ from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNor logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class LuminaAttnProcessorSDPA(AttnProcessorMixin): + compatible_backends = ["cuda", "cpu", "xpu"] + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[torch.Tensor] = None, + key_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape + + query, key, value, _ = self.get_projections(attn, hidden_states, encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + # Apply Query-Key Norm if needed + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + query = query.view(batch_size, -1, attn.heads, head_dim) + + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply RoPE if needed + if query_rotary_emb is not None: + query = apply_rotary_emb(query, query_rotary_emb, use_real=False) + if key_rotary_emb is not None: + key = apply_rotary_emb(key, key_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Apply proportional attention if true + if key_rotary_emb is None: + softmax_scale = None + else: + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + # perform Grouped-qurey Attention (GQA) + n_rep = attn.heads // kv_heads + if n_rep >= 1: + key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) + attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # the output of sdp = (batch, 
num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, scale=softmax_scale + ) + hidden_states = hidden_states.transpose(1, 2).to(dtype) + + return hidden_states + + +class LuminaNextAttention(Attention): + default_processor_cls = LuminaAttnProcessorSDPA + _available_processors = [LuminaAttnProcessorSDPA] + + class LuminaNextDiTBlock(nn.Module): """ A LuminaNextDiTBlock for LuminaNextDiT2DModel. @@ -68,7 +163,7 @@ class LuminaNextDiTBlock(nn.Module): self.gate = nn.Parameter(torch.zeros([num_attention_heads])) # Self-attention - self.attn1 = Attention( + self.attn1 = LuminaNextAttention( query_dim=dim, cross_attention_dim=None, dim_head=dim // num_attention_heads, @@ -78,12 +173,11 @@ class LuminaNextDiTBlock(nn.Module): eps=1e-5, bias=False, out_bias=False, - processor=LuminaAttnProcessor2_0(), ) self.attn1.to_out = nn.Identity() # Cross-attention - self.attn2 = Attention( + self.attn2 = LuminaNextAttention( query_dim=dim, cross_attention_dim=cross_attention_dim, dim_head=dim // num_attention_heads, @@ -93,7 +187,6 @@ class LuminaNextDiTBlock(nn.Module): eps=1e-5, bias=False, out_bias=False, - processor=LuminaAttnProcessor2_0(), ) self.feed_forward = LuminaFeedForward( @@ -175,7 +268,7 @@ class LuminaNextDiTBlock(nn.Module): return hidden_states -class LuminaNextDiT2DModel(ModelMixin, ConfigMixin): +class LuminaNextDiT2DModel(ModelMixin, ConfigMixin, AttentionMixin): """ LuminaNextDiT: Diffusion model with a Transformer backbone.
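Taken together, these changes move each model's SDPA processor into its own transformer module and pin it through a model-specific `Attention` subclass (`default_processor_cls` / `_available_processors`), with `AttnProcessorMixin` supplying the shared projection helpers and the overridable `attention_fn`. Below is a minimal sketch of how a custom backend could plug into this pattern; it assumes the module layout introduced in this patch, relies on the existing `Attention.set_processor()` API, and uses an illustrative checkpoint id.

import torch.nn.functional as F

from diffusers import AuraFlowTransformer2DModel
from diffusers.models.transformers.auraflow_transformer_2d import (
    AuraFlowAttention,
    AuraFlowAttnProcessorSDPA,
)


class CustomKernelAuraFlowProcessor(AuraFlowAttnProcessorSDPA):
    """Reuses the stock projection/reshape logic; only the attention kernel is swapped."""

    # The mixin's get_projections() consults this flag to pick fused vs. separate projections.
    is_fused = False

    def attention_fn(self, query, key, value, scale=None, attention_mask=None):
        # A hardware-specific kernel would go here; plain SDPA keeps the sketch runnable.
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=scale, dropout_p=0.0, is_causal=False
        )


# Illustrative checkpoint; any AuraFlow transformer is handled the same way.
transformer = AuraFlowTransformer2DModel.from_pretrained("fal/AuraFlow", subfolder="transformer")
for module in transformer.modules():
    if isinstance(module, AuraFlowAttention):
        module.set_processor(CustomKernelAuraFlowProcessor())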