diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 2557cea678..73b225e9b9 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -516,19 +516,46 @@ def _chunked_feed_forward(*args, **kwargs): return _actual_chunked_feed_forward(*args, **kwargs) -class GatedSelfAttentionDense: +class GatedSelfAttentionDense(nn.Module): r""" - Backward compatibility stub. Use transformers.modeling_common.GatedSelfAttentionDense instead. + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. """ - def __new__(cls, *args, **kwargs): - logger.warning( - "Importing `GatedSelfAttentionDense` from `diffusers.models.attention` is deprecated and will be removed in a future version. " - "Please use `from diffusers.models.transformers.modeling_common import GatedSelfAttentionDense` instead." - ) - from .transformers.modeling_common import GatedSelfAttentionDense + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + from .transformers.modeling_common import FeedForward - return GatedSelfAttentionDense(*args, **kwargs) + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x class JointTransformerBlock: diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index 989bf4787a..78ba8fb287 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -19,15 +19,218 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..activations import GEGLU, GELU, ApproximateGELU +from ..attention_processor import Attention from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from .modeling_common import BasicTransformerBlock +from ..normalization import AdaLayerNormZero logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class DiTFeedForward(nn.Module): + r""" + A feed-forward layer for DiT. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 
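+        inner_dim (`int`, *optional*): If given, overrides `dim * mult` as the hidden dimension of the feed-forward layer.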
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        mult: int = 4,
+        dropout: float = 0.0,
+        activation_fn: str = "geglu",
+        final_dropout: bool = False,
+        inner_dim=None,
+        bias: bool = True,
+    ):
+        super().__init__()
+        if inner_dim is None:
+            inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim, bias=bias)
+        elif activation_fn == "gelu-approximate":
+            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim, bias=bias)
+        elif activation_fn == "geglu-approximate":
+            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(act_fn)
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+        if final_dropout:
+            self.net.append(nn.Dropout(dropout))
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
+
+
+@maybe_allow_in_graph
+class DiTTransformerBlock(nn.Module):
+    r"""
+    A basic Transformer block for DiT.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (`int`, *optional*): The number of diffusion steps used during training.
+        attention_bias (`bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+        upcast_attention (`bool`, *optional*): Whether to upcast the attention computation to float32.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, *optional*, defaults to `"ada_norm_zero"`): The normalization layer to use.
+        norm_eps (`float`, *optional*, defaults to 1e-5): A small constant added to the denominator in normalization layers.
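+
+    Examples:
+        A minimal sketch of running the block on random inputs. The sizes are illustrative (they
+        happen to match DiT-XL/2: 16 heads of 72 channels and 1000 class labels), not read from a
+        checkpoint:
+
+        ```python
+        >>> import torch
+
+        >>> block = DiTTransformerBlock(
+        ...     dim=1152, num_attention_heads=16, attention_head_dim=72, num_embeds_ada_norm=1000
+        ... )
+        >>> hidden_states = torch.randn(2, 256, 1152)
+        >>> timestep = torch.randint(0, 1000, (2,))
+        >>> class_labels = torch.randint(0, 1000, (2,))
+        >>> output = block(hidden_states, timestep=timestep, class_labels=class_labels)
+        >>> output.shape
+        torch.Size([2, 256, 1152])
+        ```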
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "ada_norm_zero", + norm_eps: float = 1e-5, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.norm_elementwise_affine = norm_elementwise_affine + + # Only support ada_norm_zero for DiT + if norm_type != "ada_norm_zero": + raise ValueError(f"DiTTransformerBlock only supports norm_type='ada_norm_zero', got {norm_type}") + + if num_embeds_ada_norm is None: + raise ValueError("num_embeds_ada_norm must be provided for ada_norm_zero") + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + # 1. Self-Attention with AdaLayerNormZero + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + out_bias=True, + ) + + # 2. Feed-forward + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = DiTFeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=False, + inner_dim=None, + bias=True, + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # 1. Self-Attention with adaptive norm + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}." 
+ ) + num_chunks = hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + ff_output = gate_mlp.unsqueeze(1) * ff_output + hidden_states = ff_output + hidden_states + + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + class DiTTransformer2DModel(ModelMixin, ConfigMixin): r""" A 2D Transformer model as introduced in DiT (https://huggingface.co/papers/2212.09748). @@ -121,7 +324,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin): self.transformer_blocks = nn.ModuleList( [ - BasicTransformerBlock( + DiTTransformerBlock( self.inner_dim, self.config.num_attention_heads, self.config.attention_head_dim, diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index b623b1f395..432b951c2e 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -12,18 +12,362 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, Optional import torch from torch import nn from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import GatedSelfAttentionDense +from ..attention_processor import Attention from ..cache_utils import CacheMixin -from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid +from ..embeddings import ( + PatchEmbed, + PixArtAlphaTextProjection, + SinusoidalPositionalEmbedding, + get_1d_sincos_pos_embed_from_grid, +) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle -from .modeling_common import BasicTransformerBlock +from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero +from .modeling_common import FeedForward, _chunked_feed_forward + + +logger = logging.get_logger(__name__) + + +@maybe_allow_in_graph +class LatteTransformerBlock(nn.Module): + r""" + A Transformer block for Latte. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. 
+ upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." 
+ ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + if norm_type == "ada_norm_single": # For Latte + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. 
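+        # For `ada_norm_single` (the norm type Latte uses), this learnable (6, dim) table provides the
+        # base shift/scale/gate values for the attention and feed-forward branches; `forward` adds the
+        # projected timestep embedding to it and splits the sum into six modulation chunks.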
+ if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. 
Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): @@ -110,7 +454,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): # 2. Define spatial transformers blocks self.transformer_blocks = nn.ModuleList( [ - BasicTransformerBlock( + LatteTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, @@ -130,7 +474,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): # 3. 
Define temporal transformers blocks self.temporal_transformer_blocks = nn.ModuleList( [ - BasicTransformerBlock( + LatteTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, diff --git a/src/diffusers/models/transformers/modeling_common.py b/src/diffusers/models/transformers/modeling_common.py index f4baff19b1..1354665628 100644 --- a/src/diffusers/models/transformers/modeling_common.py +++ b/src/diffusers/models/transformers/modeling_common.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from ...utils import deprecate, logging from ...utils.torch_utils import maybe_allow_in_graph from ..activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU +from ..attention import GatedSelfAttentionDense from ..attention_processor import Attention, JointAttnProcessor2_0 from ..embeddings import SinusoidalPositionalEmbedding from ..normalization import ( @@ -50,48 +51,6 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: return ff_output -@maybe_allow_in_graph -class GatedSelfAttentionDense(nn.Module): - r""" - A gated self-attention dense layer that combines visual features and object features. - - Parameters: - query_dim (`int`): The number of channels in the query. - context_dim (`int`): The number of channels in the context. - n_heads (`int`): The number of heads to use for attention. - d_head (`int`): The number of channels in each head. - """ - - def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): - super().__init__() - - # we need a linear projection since we need cat visual feature and obj feature - self.linear = nn.Linear(context_dim, query_dim) - - self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) - self.ff = FeedForward(query_dim, activation_fn="geglu") - - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) - - self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) - self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) - - self.enabled = True - - def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: - if not self.enabled: - return x - - n_visual = x.shape[1] - objs = self.linear(objs) - - x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] - x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) - - return x - - @maybe_allow_in_graph class JointTransformerBlock(nn.Module): r""" diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 5b13c24c69..2b1d4962de 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -18,12 +18,346 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import GatedSelfAttentionDense from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0 -from ..embeddings import PatchEmbed, PixArtAlphaTextProjection +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, SinusoidalPositionalEmbedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle -from .modeling_common import BasicTransformerBlock +from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, 
AdaLayerNormSingle, AdaLayerNormZero +from .modeling_common import FeedForward, _chunked_feed_forward + + +@maybe_allow_in_graph +class PixArtTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. 
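+        norm_eps (`float`, *optional*, defaults to 1e-5):
+            A small constant added to the denominator in normalization layers.
+        ada_norm_continous_conditioning_embedding_dim (`int`, *optional*):
+            The dimension of the conditioning embedding when `norm_type` is `"ada_norm_continuous"`.
+        ada_norm_bias (`int`, *optional*):
+            The bias setting forwarded to the `ada_norm_continuous` normalization layers.
+        ff_inner_dim (`int`, *optional*):
+            The hidden dimension of the feed-forward layer; defaults to `4 * dim` when not set.
+        ff_bias (`bool`, *optional*, defaults to `True`):
+            Whether the feed-forward linear layers use a bias.
+        attention_out_bias (`bool`, *optional*, defaults to `True`):
+            Whether the attention output projections use a bias.
+
+    Examples:
+        A minimal sketch with illustrative sizes. PixArt conditions through `ada_norm_single`, so the
+        timestep embedding must carry the six modulation chunks (`6 * dim` values per sample):
+
+        ```python
+        >>> import torch
+
+        >>> block = PixArtTransformerBlock(
+        ...     dim=1152, num_attention_heads=16, attention_head_dim=72, norm_type="ada_norm_single"
+        ... )
+        >>> hidden_states = torch.randn(2, 256, 1152)
+        >>> timestep = torch.randn(2, 6 * 1152)  # reshaped to (2, 6, 1152) inside `forward`
+        >>> output = block(hidden_states, timestep=timestep)
+        >>> output.shape
+        torch.Size([2, 256, 1152])
+        ```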
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. 
Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + if norm_type == "ada_norm_single": # For Latte + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. 
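+        # PixArt-Alpha conditions through `ada_norm_single`: the branch input becomes
+        # `norm(x) * (1 + scale) + shift`, and `gate` rescales the branch output before the residual
+        # add. This table holds the learnable base values for all six modulation chunks.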
+ if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. 
Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -78,7 +412,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] + _no_split_modules = ["PixArtTransformerBlock", "PatchEmbed"] _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"] @register_to_config @@ -151,7 +485,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): self.transformer_blocks = nn.ModuleList( [ - BasicTransformerBlock( + PixArtTransformerBlock( self.inner_dim, self.config.num_attention_heads, self.config.attention_head_dim, diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py index 3c2a5afd6b..f9a47c98b4 100644 --- a/src/diffusers/models/transformers/prior_transformer.py +++ b/src/diffusers/models/transformers/prior_transformer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union import torch import torch.nn.functional as F @@ -7,17 +7,356 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin -from ...utils import BaseOutput +from 
...utils import BaseOutput, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import GatedSelfAttentionDense from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, + Attention, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) -from ..embeddings import TimestepEmbedding, Timesteps +from ..embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin -from .modeling_common import BasicTransformerBlock +from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero +from .modeling_common import FeedForward, _chunked_feed_forward + + +@maybe_allow_in_graph +class PriorTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. 
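+        norm_eps (`float`, *optional*, defaults to 1e-5):
+            A small constant added to the denominator in normalization layers.
+        ada_norm_continous_conditioning_embedding_dim (`int`, *optional*):
+            The dimension of the conditioning embedding when `norm_type` is `"ada_norm_continuous"`.
+        ada_norm_bias (`int`, *optional*):
+            The bias setting forwarded to the `ada_norm_continuous` normalization layers.
+        ff_inner_dim (`int`, *optional*):
+            The hidden dimension of the feed-forward layer; defaults to `4 * dim` when not set.
+        ff_bias (`bool`, *optional*, defaults to `True`):
+            Whether the feed-forward linear layers use a bias.
+        attention_out_bias (`bool`, *optional*, defaults to `True`):
+            Whether the attention output projections use a bias.
+
+    Examples:
+        A minimal sketch using the default `layer_norm` path (self-attention only, no conditioning);
+        the sizes are illustrative:
+
+        ```python
+        >>> import torch
+
+        >>> block = PriorTransformerBlock(dim=2048, num_attention_heads=32, attention_head_dim=64)
+        >>> hidden_states = torch.randn(1, 81, 2048)
+        >>> output = block(hidden_states)
+        >>> output.shape
+        torch.Size([1, 81, 2048])
+        ```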
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. 
Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + if norm_type == "ada_norm_single": # For Latte + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. 
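+        # Only created for `ada_norm_single`: `forward` computes
+        # `self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)` and chunks the result
+        # into shift/scale/gate for the attention branch and the feed-forward branch.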
+ if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. 
Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +logger = logging.get_logger(__name__) @dataclass @@ -133,7 +472,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef self.transformer_blocks = nn.ModuleList( [ - BasicTransformerBlock( + PriorTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 97376e9415..c294e4cb18 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -19,16 +19,356 @@ from torch import nn from ...configuration_utils import LegacyConfigMixin, register_to_config from ...utils import deprecate, logging -from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import GatedSelfAttentionDense +from ..attention_processor import Attention +from ..embeddings import ( + ImagePositionalEmbeddings, + PatchEmbed, + PixArtAlphaTextProjection, + SinusoidalPositionalEmbedding, +) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import LegacyModelMixin -from ..normalization import AdaLayerNormSingle -from .modeling_common import BasicTransformerBlock +from ..normalization import AdaLayerNorm, 
+from .modeling_common import FeedForward, _chunked_feed_forward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (`int`, *optional*):
+            The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Configure if the attentions should contain a bias parameter.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, *optional*):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        upcast_attention (`bool`, *optional*):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"`, `"ada_norm_zero"`,
+            `"ada_norm_single"`, `"ada_norm_continuous"` or `"layer_norm_i2vgen"`.
+        final_dropout (`bool`, *optional*, defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
+        attention_type (`str`, *optional*, defaults to `"default"`):
+            The type of attention to use. Can be `"default"`, `"gated"` or `"gated-text-image"`.
+        positional_embeddings (`str`, *optional*, defaults to `None`):
+            The type of positional embeddings to apply. Currently only `"sinusoidal"` is supported.
+        num_positional_embeddings (`int`, *optional*, defaults to `None`):
+            The maximum number of positional embeddings to apply.
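+
+    Examples:
+        A minimal usage sketch (shapes are illustrative; in practice the block is built and driven by
+        `Transformer2DModel`):
+
+        ```py
+        >>> import torch
+
+        >>> block = BasicTransformerBlock(
+        ...     dim=320, num_attention_heads=8, attention_head_dim=40, cross_attention_dim=768
+        ... )
+        >>> hidden_states = torch.randn(2, 64, 320)  # (batch_size, seq_len, dim)
+        >>> encoder_hidden_states = torch.randn(2, 77, 768)  # text conditioning
+        >>> out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
+        >>> out.shape
+        torch.Size([2, 64, 320])
+        ```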
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. 
+        if cross_attention_dim is not None or double_self_attention:
+            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+            # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned
+            # during the second cross attention block.
+            if norm_type == "ada_norm":
+                self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+            elif norm_type == "ada_norm_continuous":
+                self.norm2 = AdaLayerNormContinuous(
+                    dim,
+                    ada_norm_continous_conditioning_embedding_dim,
+                    norm_elementwise_affine,
+                    norm_eps,
+                    ada_norm_bias,
+                    "rms_norm",
+                )
+            else:
+                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+                out_bias=attention_out_bias,
+            )  # is self-attn if encoder_hidden_states is None
+        else:
+            if norm_type == "ada_norm_single":  # For Latte
+                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+            else:
+                self.norm2 = None
+            self.attn2 = None
+
+        # 3. Feed-forward
+        if norm_type == "ada_norm_continuous":
+            self.norm3 = AdaLayerNormContinuous(
+                dim,
+                ada_norm_continous_conditioning_embedding_dim,
+                norm_elementwise_affine,
+                norm_eps,
+                ada_norm_bias,
+                "layer_norm",
+            )
+
+        elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
+            self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+        elif norm_type == "layer_norm_i2vgen":
+            self.norm3 = None
+
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+            inner_dim=ff_inner_dim,
+            bias=ff_bias,
+        )
+
+        # 4. Fuser
+        if attention_type == "gated" or attention_type == "gated-text-image":
+            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+        # 5. Scale-shift for PixArt-Alpha.
+        if norm_type == "ada_norm_single":
+            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+        # Notice that normalization is always applied before the real computation in the following blocks.
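+        # All adaptive-norm branches below follow the same schematic pattern; only the source of the
+        # shift/scale/gate tensors differs (a class embedding for "ada_norm_zero", a learned table plus
+        # the projected timestep for "ada_norm_single"; plain "layer_norm" uses none of them):
+        #     x_mod = norm(x) * (1 + scale) + shift
+        #     x = x + gate * sublayer(x_mod)
+        # For "ada_norm_single", `self.scale_shift_table` has shape (6, dim); adding the projected
+        # timestep reshaped to (batch_size, 6, dim) and chunking along dim=1 yields six
+        # (batch_size, 1, dim) tensors: shift/scale/gate for attention, then for the feed-forward.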
+        # 0. Self-Attention
+        batch_size = hidden_states.shape[0]
+
+        if self.norm_type == "ada_norm":
+            norm_hidden_states = self.norm1(hidden_states, timestep)
+        elif self.norm_type == "ada_norm_zero":
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
+            norm_hidden_states = self.norm1(hidden_states)
+        elif self.norm_type == "ada_norm_continuous":
+            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
+        elif self.norm_type == "ada_norm_single":
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+            ).chunk(6, dim=1)
+            norm_hidden_states = self.norm1(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+        else:
+            raise ValueError("Incorrect norm used")
+
+        if self.pos_embed is not None:
+            norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+        # 1. Prepare GLIGEN inputs
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
+
+        if self.norm_type == "ada_norm_zero":
+            attn_output = gate_msa.unsqueeze(1) * attn_output
+        elif self.norm_type == "ada_norm_single":
+            attn_output = gate_msa * attn_output
+
+        hidden_states = attn_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        # 1.2 GLIGEN Control
+        if gligen_kwargs is not None:
+            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+        # 3. Cross-Attention
+        if self.attn2 is not None:
+            if self.norm_type == "ada_norm":
+                norm_hidden_states = self.norm2(hidden_states, timestep)
+            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
+                norm_hidden_states = self.norm2(hidden_states)
+            elif self.norm_type == "ada_norm_single":
+                # For PixArt norm2 isn't applied here:
+                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+                norm_hidden_states = hidden_states
+            elif self.norm_type == "ada_norm_continuous":
+                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+            else:
+                raise ValueError("Incorrect norm")
+
+            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+                norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+            attn_output = self.attn2(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                **cross_attention_kwargs,
+            )
+            hidden_states = attn_output + hidden_states
+
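+        # `_chunked_feed_forward` (from `modeling_common`) bounds peak activation memory by running
+        # the feed-forward over slices of the chunked axis and concatenating the results, roughly:
+        #     torch.cat([ff(c) for c in x.chunk(x.shape[dim] // chunk_size, dim)], dim)
+        # The output matches a single full-sequence feed-forward pass; only memory use changes.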
+        # 4. Feed-forward
+        # i2vgen doesn't have this norm 🤷‍♂️
+        if self.norm_type == "ada_norm_continuous":
+            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
+        elif not self.norm_type == "ada_norm_single":
+            norm_hidden_states = self.norm3(hidden_states)
+
+        if self.norm_type == "ada_norm_zero":
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        if self.norm_type == "ada_norm_single":
+            norm_hidden_states = self.norm2(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        if self.norm_type == "ada_norm_zero":
+            ff_output = gate_mlp.unsqueeze(1) * ff_output
+        elif self.norm_type == "ada_norm_single":
+            ff_output = gate_mlp * ff_output
+
+        hidden_states = ff_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        return hidden_states
+
+
 class Transformer2DModelOutput(Transformer2DModelOutput):
     def __init__(self, *args, **kwargs):
         deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 26ab1faf64..a7376d458b 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
@@ -29,13 +30,185 @@ from ..attention_processor import (
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
-from .modeling_common import FeedForward, JointTransformerBlock
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, SD35AdaLayerNormZeroX
+from .modeling_common import FeedForward, _chunked_feed_forward
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+@maybe_allow_in_graph
+class SD3JointTransformerBlock(nn.Module):
+    r"""
+    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+    Reference: https://huggingface.co/papers/2403.03206
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        context_pre_only (`bool`): Whether the `context` (text) stream is only pre-processed: when `True`, the
+            final context attention and feed-forward blocks are skipped and the block returns
+            `encoder_hidden_states` as `None`.
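+
+    Examples:
+        A minimal smoke test of the block in isolation (illustrative shapes; requires PyTorch >= 2.0 for
+        `scaled_dot_product_attention`). In practice the block is instantiated by `SD3Transformer2DModel`:
+
+        ```py
+        >>> import torch
+
+        >>> block = SD3JointTransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=64)
+        >>> hidden_states = torch.randn(2, 1024, 1536)  # image tokens
+        >>> encoder_hidden_states = torch.randn(2, 154, 1536)  # text tokens
+        >>> temb = torch.randn(2, 1536)  # combined timestep / pooled-text embedding
+        >>> encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, temb)
+        >>> hidden_states.shape
+        torch.Size([2, 1024, 1536])
+        ```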
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + context_pre_only: bool = False, + qk_norm: Optional[str] = None, + use_dual_attention: bool = False, + ): + super().__init__() + + self.use_dual_attention = use_dual_attention + self.context_pre_only = context_pre_only + context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" + + if use_dual_attention: + self.norm1 = SD35AdaLayerNormZeroX(dim) + else: + self.norm1 = AdaLayerNormZero(dim) + + if context_norm_type == "ada_norm_continous": + self.norm1_context = AdaLayerNormContinuous( + dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" + ) + elif context_norm_type == "ada_norm_zero": + self.norm1_context = AdaLayerNormZero(dim) + else: + raise ValueError( + f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" + ) + + if hasattr(F, "scaled_dot_product_attention"): + processor = JointAttnProcessor2_0() + else: + raise ValueError( + "The current PyTorch version does not support the `scaled_dot_product_attention` function." + ) + + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=context_pre_only, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=1e-6, + ) + + if use_dual_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=1e-6, + ) + else: + self.attn2 = None + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + if not context_pre_only: + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + else: + self.norm2_context = None + self.ff_context = None + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + joint_attention_kwargs = joint_attention_kwargs or {} + if self.use_dual_attention: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( + hidden_states, emb=temb + ) + else: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + if self.context_pre_only: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + else: + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + **joint_attention_kwargs, + ) + + # Process attention outputs for the `hidden_states`. 
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        if self.use_dual_attention:
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
+            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+            hidden_states = hidden_states + attn_output2
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        # Process attention outputs for the `encoder_hidden_states`.
+        if self.context_pre_only:
+            encoder_hidden_states = None
+        else:
+            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+            encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+            if self._chunk_size is not None:
+                # "feed_forward_chunk_size" can be used to save memory
+                context_ff_output = _chunked_feed_forward(
+                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+                )
+            else:
+                context_ff_output = self.ff_context(norm_encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return encoder_hidden_states, hidden_states
+
+
 @maybe_allow_in_graph
 class SD3SingleTransformerBlock(nn.Module):
     def __init__(
@@ -114,7 +287,7 @@ class SD3Transformer2DModel(
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["JointTransformerBlock"]
+    _no_split_modules = ["SD3JointTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
 
     @register_to_config
@@ -155,7 +328,7 @@ class SD3Transformer2DModel(
         self.transformer_blocks = nn.ModuleList(
             [
-                JointTransformerBlock(
+                SD3JointTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py
index 0116744291..45f14cf973 100644
--- a/src/diffusers/models/transformers/transformer_temporal.py
+++ b/src/diffusers/models/transformers/transformer_temporal.py
@@ -18,11 +18,147 @@ import torch
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import BaseOutput
+from ...utils import BaseOutput, logging
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention_processor import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
 from ..resnet import AlphaBlender
-from .modeling_common import BasicTransformerBlock, TemporalBasicTransformerBlock
+from .modeling_common import FeedForward, _chunked_feed_forward
+from .transformer_2d import BasicTransformerBlock
+
+
+logger = logging.get_logger(__name__)
+
+
+@maybe_allow_in_graph
+class TemporalTransformerBlock(nn.Module):
+    r"""
+    A basic Transformer block for video-like data.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        time_mix_inner_dim (`int`): The number of channels for temporal attention.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
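+
+    Examples:
+        A minimal sketch of driving the block directly (illustrative shapes; the block is normally created by
+        `TransformerTemporalModel`). Note that the input packs the frame axis into the batch axis:
+
+        ```py
+        >>> import torch
+
+        >>> block = TemporalTransformerBlock(
+        ...     dim=320, time_mix_inner_dim=320, num_attention_heads=8, attention_head_dim=40
+        ... )
+        >>> hidden_states = torch.randn(2 * 16, 64, 320)  # (batch_size * num_frames, seq_len, channels)
+        >>> out = block(hidden_states, num_frames=16)
+        >>> out.shape
+        torch.Size([32, 64, 320])
+        ```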
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        time_mix_inner_dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        cross_attention_dim: Optional[int] = None,
+    ):
+        super().__init__()
+        self.is_res = dim == time_mix_inner_dim
+
+        self.norm_in = nn.LayerNorm(dim)
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        self.ff_in = FeedForward(
+            dim,
+            dim_out=time_mix_inner_dim,
+            activation_fn="geglu",
+        )
+
+        self.norm1 = nn.LayerNorm(time_mix_inner_dim)
+        self.attn1 = Attention(
+            query_dim=time_mix_inner_dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            cross_attention_dim=None,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None:
+            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+            # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned
+            # during the second cross attention block.
+            self.norm2 = nn.LayerNorm(time_mix_inner_dim)
+            self.attn2 = Attention(
+                query_dim=time_mix_inner_dim,
+                cross_attention_dim=cross_attention_dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+            )  # is self-attn if encoder_hidden_states is None
+        else:
+            self.norm2 = None
+            self.attn2 = None
+
+        # 3. Feed-forward
+        self.norm3 = nn.LayerNorm(time_mix_inner_dim)
+        self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = None
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
+        self._chunk_dim = 1
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        num_frames: int,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 0. Self-Attention
+        batch_frames, seq_length, channels = hidden_states.shape
+        batch_size = batch_frames // num_frames
+
+        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
+        hidden_states = hidden_states.permute(0, 2, 1, 3)
+        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
+
+        residual = hidden_states
+        hidden_states = self.norm_in(hidden_states)
+
+        if self._chunk_size is not None:
+            hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            hidden_states = self.ff_in(hidden_states)
+
+        if self.is_res:
+            hidden_states = hidden_states + residual
+
+        norm_hidden_states = self.norm1(hidden_states)
+        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
+        hidden_states = attn_output + hidden_states
+
+        # 3. Cross-Attention
+        if self.attn2 is not None:
+            norm_hidden_states = self.norm2(hidden_states)
+            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+            hidden_states = attn_output + hidden_states
+
+        # 4. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self._chunk_size is not None:
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        if self.is_res:
+            hidden_states = ff_output + hidden_states
+        else:
+            hidden_states = ff_output
+
+        hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
+        hidden_states = hidden_states.permute(0, 2, 1, 3)
+        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
+
+        return hidden_states
 
 
 @dataclass
@@ -101,7 +237,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
         # 3. Define transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
-                BasicTransformerBlock(
+                TemporalTransformerBlock(
                     inner_dim,
                     num_attention_heads,
                     attention_head_dim,
@@ -254,7 +390,7 @@ class TransformerSpatioTemporalModel(nn.Module):
         time_mix_inner_dim = inner_dim
         self.temporal_transformer_blocks = nn.ModuleList(
             [
-                TemporalBasicTransformerBlock(
+                TemporalTransformerBlock(
                     inner_dim,
                     time_mix_inner_dim,
                     num_attention_heads,