From e488d09df1b57f21241f67dd0fb7cb2750b2e100 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 10:26:08 +0200 Subject: [PATCH 01/19] update --- src/diffusers/models/embeddings.py | 84 ++ .../models/transformers/transformer_mochi.py | 118 +++ .../transformer_mochi_original.py | 961 ++++++++++++++++++ 3 files changed, 1163 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_mochi.py create mode 100644 src/diffusers/models/transformers/transformer_mochi_original.py diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 44f01c46eb..4ccddbbaf4 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1430,6 +1430,90 @@ class AttentionPooling(nn.Module): return a[:, 0, :] # cls_token +class MochiAttentionPool(nn.Module): + def __init__( + self, + num_attention_heads: int, + embed_dim: int, + output_dim: Optional[int] = None, + ) -> None: + super().__init__() + + self.output_dim = output_dim or embed_dim + self.num_attention_heads = num_attention_heads + + self.to_kv = nn.Linear(embed_dim, 2 * embed_dim) + self.to_q = nn.Linear(embed_dim, embed_dim) + self.to_out = nn.Linear(embed_dim, self.output_dim) + + @staticmethod + def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor: + """ + Pool tokens in x using mask. + + NOTE: We assume x does not require gradients. + + Args: + x: (B, L, D) tensor of tokens. + mask: (B, L) boolean tensor indicating which tokens are not padding. + + Returns: + pooled: (B, D) tensor of pooled tokens. + """ + assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens. + assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens. + mask = mask[:, :, None].to(dtype=x.dtype) + mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1) + pooled = (x * mask).sum(dim=1, keepdim=keepdim) + return pooled + + def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + r""" + Args: + x (`torch.Tensor`): + Tensor of shape `(B, S, D)` of input tokens. + mask (`torch.Tensor`): + Boolean ensor of shape `(B, S)` indicating which tokens are not padding. + + Returns: + `torch.Tensor`: + `(B, D)` tensor of pooled tokens. + """ + D = x.size(2) + + # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L). + attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L). + attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L). + + # Average non-padding token features. These will be used as the query. + x_pool = self.pool_tokens(x, mask, keepdim=True) # (B, 1, D) + + # Concat pooled features to input sequence. + x = torch.cat([x_pool, x], dim=1) # (B, L+1, D) + + # Compute queries, keys, values. Only the mean token is used to create a query. + kv = self.to_kv(x) # (B, L+1, 2 * D) + q = self.to_q(x[:, 0]) # (B, D) + + # Extract heads. + head_dim = D // self.num_attention_heads + kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim)) # (B, 1+L, 2, H, head_dim) + kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim) + k, v = kv.unbind(2) # (B, H, 1+L, head_dim) + q = q.unflatten(1, (self.num_attention_heads, head_dim)) # (B, H, head_dim) + q = q.unsqueeze(2) # (B, H, 1, head_dim) + + # Compute attention. + x = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=0.0 + ) # (B, H, 1, head_dim) + + # Concatenate heads and run output. 
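+        # For example, with the values used for T5 pooling later in this PR (embed_dim=4096, 8 heads),
+        # head_dim is 512, so the (B, 8, 1, 512) attention output is flattened back to (B, 4096)
+        # before to_out projects it to output_dim.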
+ x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) + x = self.to_out(x) + return x + + def get_fourier_embeds_from_boundingbox(embed_dim, box): """ Args: diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py new file mode 100644 index 0000000000..c56a7845cb --- /dev/null +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -0,0 +1,118 @@ +# Copyright 2024 The Genmo team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import Attention, FeedForward +from ..embeddings import PatchEmbed, MochiAttentionPool, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class MochiTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + caption_dim: int, + update_captions: bool = True, + ) -> None: + super().__init__() + + # TODO: Replace this with norm + self.mod_x = nn.Linear(dim, 4 * dim) + if self.update_y: + self.mod_y = nn.Linear(dim, 4 * caption_dim) + else: + self.mod_y = nn.Linear(dim, caption_dim) + + # TODO(aryan): attention class does not look compatible + self.attn1 = Attention(...) 
+ # norms go in attention + # self.q_norm_x = RMSNorm(attention_head_dim) + # self.k_norm_x = RMSNorm(attention_head_dim) + # self.q_norm_y = RMSNorm(attention_head_dim) + # self.k_norm_y = RMSNorm(attention_head_dim) + + self.proj_x = nn.Linear(dim, dim) + + self.proj_y = nn.Linear(dim, caption_dim) if update_captions else None + + def forward(self): + pass + + +@maybe_allow_in_graph +class MochiTransformer3D(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 2, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 48, + caption_dim=1536, + mlp_ratio_x=4.0, + mlp_ratio_y=4.0, + in_channels=12, + qk_norm=True, + qkv_bias=False, + out_bias=True, + timestep_mlp_bias=True, + timestep_scale=1000.0, + text_embed_dim=4096, + max_sequence_length=256, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + self.caption_embedder = MochiAttentionPool(num_attention_heads=8, embed_dim=text_embed_dim, output_dim=inner_dim) + self.caption_proj = nn.Linear(text_embed_dim, caption_dim) + + self.pos_frequencies = nn.Parameter( + torch.empty(3, num_attention_heads, attention_head_dim // 2) + ) + + self.transformer_blocks = nn.ModuleList([ + MochiTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + caption_dim=caption_dim, + update_captions=i < num_layers - 1, + ) + for i in range(num_layers) + ]) diff --git a/src/diffusers/models/transformers/transformer_mochi_original.py b/src/diffusers/models/transformers/transformer_mochi_original.py new file mode 100644 index 0000000000..52bdfa0710 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_mochi_original.py @@ -0,0 +1,961 @@ +import collections +import functools +import itertools +import math +from typing import Any, Callable, Dict, Optional, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(itertools.repeat(x, n)) + + return parse + +to_2tuple = _ntuple(2) + +def centers(start: float, stop, num, dtype=None, device=None): + """linspace through bin centers. + + Args: + start (float): Start of the range. + stop (float): End of the range. + num (int): Number of points. + dtype (torch.dtype): Data type of the points. + device (torch.device): Device of the points. + + Returns: + centers (Tensor): Centers of the bins. Shape: (num,). + """ + edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device) + return (edges[:-1] + edges[1:]) / 2 + + +@functools.lru_cache(maxsize=1) +def create_position_matrix( + T: int, + pH: int, + pW: int, + device: torch.device, + dtype: torch.dtype, + *, + target_area: float = 36864, +): + """ + Args: + T: int - Temporal dimension + pH: int - Height dimension after patchify + pW: int - Width dimension after patchify + + Returns: + pos: [T * pH * pW, 3] - position matrix + """ + with torch.no_grad(): + # Create 1D tensors for each dimension + t = torch.arange(T, dtype=dtype) + + # Positionally interpolate to area 36864. + # (3072x3072 frame with 16x16 patches = 192x192 latents). + # This automatically scales rope positions when the resolution changes. 
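+        # For a rough sense of scale: an 848x480 frame at the 16x16 effective patch size noted above
+        # gives pW=53, pH=30, so scale = sqrt(36864 / (53 * 30)) ~= 4.8 and the width positions span
+        # roughly +/-125 rather than +/-26.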
+ # We use a large target area so the model is more sensitive + # to changes in the learned pos_frequencies matrix. + scale = math.sqrt(target_area / (pW * pH)) + w = centers(-pW * scale / 2, pW * scale / 2, pW) + h = centers(-pH * scale / 2, pH * scale / 2, pH) + + # Use meshgrid to create 3D grids + grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") + + # Stack and reshape the grids. + pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3] + pos = pos.view(-1, 3) # [T * pH * pW, 3] + pos = pos.to(dtype=dtype, device=device) + + return pos + + +def compute_mixed_rotation( + freqs: torch.Tensor, + pos: torch.Tensor, +): + """ + Project each 3-dim position into per-head, per-head-dim 1D frequencies. + + Args: + freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position + pos: [N, 3] - position of each token + num_heads: int + + Returns: + freqs_cos: [N, num_heads, num_freqs] - cosine components + freqs_sin: [N, num_heads, num_freqs] - sine components + """ + with torch.autocast("cuda", enabled=False): + assert freqs.ndim == 3 + freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs) + freqs_cos = torch.cos(freqs_sum) + freqs_sin = torch.sin(freqs_sum) + return freqs_cos, freqs_sin + + +class TimestepEmbedder(nn.Module): + def __init__( + self, + hidden_size: int, + frequency_embedding_size: int = 256, + *, + bias: bool = True, + timestep_scale: Optional[float] = None, + device: Optional[torch.device] = None, + ): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=bias, device=device), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=bias, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + self.timestep_scale = timestep_scale + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) + freqs.mul_(-math.log(max_period) / half).exp_() + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + if self.timestep_scale is not None: + t = t * self.timestep_scale + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class PooledCaptionEmbedder(nn.Module): + def __init__( + self, + caption_feature_dim: int, + hidden_size: int, + *, + bias: bool = True, + device: Optional[torch.device] = None, + ): + super().__init__() + self.caption_feature_dim = caption_feature_dim + self.hidden_size = hidden_size + self.mlp = nn.Sequential( + nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=bias, device=device), + ) + + def forward(self, x): + return self.mlp(x) + + +class FeedForward(nn.Module): + def __init__( + self, + in_features: int, + hidden_size: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + device: Optional[torch.device] = None, + ): + super().__init__() + # keep parameter count and computation constant compared to standard FFN + hidden_size = int(2 * hidden_size / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_size = int(ffn_dim_multiplier * hidden_size) + hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of) + + 
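+        # e.g. for the visual stream (in_features=3072, hidden_size=12288, the value asserted in
+        # AsymmetricJointBlock below): 2/3 of 12288 is 8192, already a multiple of 256, so
+        # w1 maps 3072 -> 16384 (value and gate) and w2 maps 8192 -> 3072.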
self.hidden_dim = hidden_size + self.w1 = nn.Linear(in_features, 2 * hidden_size, bias=False, device=device) + self.w2 = nn.Linear(hidden_size, in_features, bias=False, device=device) + + def forward(self, x): + x, gate = self.w1(x).chunk(2, dim=-1) + x = self.w2(F.silu(x) * gate) + return x + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = True, + bias: bool = True, + dynamic_img_pad: bool = False, + device: Optional[torch.device] = None, + ): + super().__init__() + self.patch_size = to_2tuple(patch_size) + self.flatten = flatten + self.dynamic_img_pad = dynamic_img_pad + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + device=device, + ) + assert norm_layer is None + self.norm = ( + norm_layer(embed_dim, device=device) if norm_layer else nn.Identity() + ) + + def forward(self, x): + B, _C, T, H, W = x.shape + if not self.dynamic_img_pad: + assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." + assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." + else: + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) + + x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T) + x = self.proj(x) + + # Flatten temporal and spatial dimensions. + if not self.flatten: + raise NotImplementedError("Must flatten output.") + x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T) + + x = self.norm(x) + return x + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5, device=None): + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device)) + self.register_parameter("bias", None) + + def forward(self, x): + x_fp32 = x.float() + x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps) + return (x_normed * self.weight).type_as(x) + + +def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor: + """ + Pool tokens in x using mask. + + NOTE: We assume x does not require gradients. + + Args: + x: (B, L, D) tensor of tokens. + mask: (B, L) boolean tensor indicating which tokens are not padding. + + Returns: + pooled: (B, D) tensor of pooled tokens. + """ + assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens. + assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens. + mask = mask[:, :, None].to(dtype=x.dtype) + mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1) + pooled = (x * mask).sum(dim=1, keepdim=keepdim) + return pooled + + +class AttentionPool(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + output_dim: int = None, + device: Optional[torch.device] = None, + ): + """ + Args: + spatial_dim (int): Number of tokens in sequence length. + embed_dim (int): Dimensionality of input tokens. + num_heads (int): Number of attention heads. + output_dim (int): Dimensionality of output tokens. Defaults to embed_dim. 
+ """ + super().__init__() + self.num_heads = num_heads + self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device) + self.to_q = nn.Linear(embed_dim, embed_dim, device=device) + self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device) + + def forward(self, x, mask): + """ + Args: + x (torch.Tensor): (B, L, D) tensor of input tokens. + mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding. + + NOTE: We assume x does not require gradients. + + Returns: + x (torch.Tensor): (B, D) tensor of pooled tokens. + """ + D = x.size(2) + + # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L). + attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L). + attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L). + + # Average non-padding token features. These will be used as the query. + x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D) + + # Concat pooled features to input sequence. + x = torch.cat([x_pool, x], dim=1) # (B, L+1, D) + + # Compute queries, keys, values. Only the mean token is used to create a query. + kv = self.to_kv(x) # (B, L+1, 2 * D) + q = self.to_q(x[:, 0]) # (B, D) + + # Extract heads. + head_dim = D // self.num_heads + kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim) + kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim) + k, v = kv.unbind(2) # (B, H, 1+L, head_dim) + q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim) + q = q.unsqueeze(2) # (B, H, 1, head_dim) + + # Compute attention. + x = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=0.0 + ) # (B, H, 1, head_dim) + + # Concatenate heads and run output. + x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) + x = self.to_out(x) + return x + + +class ResidualTanhGatedRMSNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, x, x_res, gate, eps=1e-6): + # Convert to fp32 for precision + x_res_fp32 = x_res.float() + + # Compute RMS + mean_square = x_res_fp32.pow(2).mean(-1, keepdim=True) + scale = torch.rsqrt(mean_square + eps) + + # Apply tanh to gate + tanh_gate = torch.tanh(gate).unsqueeze(1) + + # Normalize and apply gated scaling + x_normed = x_res_fp32 * scale * tanh_gate + + # Apply residual connection + output = x + x_normed.type_as(x) + + return output + + +def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6): + return ResidualTanhGatedRMSNorm.apply(x, x_res, gate, eps) + + +class ModulatedRMSNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale, eps=1e-6): + # Convert to fp32 for precision + x_fp32 = x.float() + scale_fp32 = scale.float() + + # Compute RMS + mean_square = x_fp32.pow(2).mean(-1, keepdim=True) + inv_rms = torch.rsqrt(mean_square + eps) + + # Normalize and modulate + x_normed = x_fp32 * inv_rms + x_modulated = x_normed * (1 + scale_fp32.unsqueeze(1)) + + return x_modulated.type_as(x) + + +def modulated_rmsnorm(x, scale, eps=1e-6): + return ModulatedRMSNorm.apply(x, scale, eps) + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class AsymmetricJointBlock(nn.Module): + def __init__( + self, + hidden_size_x: int, + hidden_size_y: int, + num_heads: int, + *, + mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens. + mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens. + update_y: bool = True, # Whether to update text tokens in this block. 
+ device: Optional[torch.device] = None, + **block_kwargs, + ): + super().__init__() + self.update_y = update_y + self.hidden_size_x = hidden_size_x + self.hidden_size_y = hidden_size_y + self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device) + if self.update_y: + self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device) + else: + self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device) + + # Self-attention: + self.attn = AsymmetricAttention( + hidden_size_x, + hidden_size_y, + num_heads=num_heads, + update_y=update_y, + device=device, + **block_kwargs, + ) + + # MLP. + mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x) + assert mlp_hidden_dim_x == int(1536 * 8) + self.mlp_x = FeedForward( + in_features=hidden_size_x, + hidden_size=mlp_hidden_dim_x, + multiple_of=256, + ffn_dim_multiplier=None, + device=device, + ) + + # MLP for text not needed in last block. + if self.update_y: + mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y) + self.mlp_y = FeedForward( + in_features=hidden_size_y, + hidden_size=mlp_hidden_dim_y, + multiple_of=256, + ffn_dim_multiplier=None, + device=device, + ) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, + y: torch.Tensor, + **attn_kwargs, + ): + """Forward pass of a block. + + Args: + x: (B, N, dim) tensor of visual tokens + c: (B, dim) tensor of conditioned features + y: (B, L, dim) tensor of text tokens + num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens + + Returns: + x: (B, N, dim) tensor of visual tokens after block + y: (B, L, dim) tensor of text tokens after block + """ + N = x.size(1) + + c = F.silu(c) + mod_x = self.mod_x(c) + scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1) + + mod_y = self.mod_y(c) + if self.update_y: + scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1) + else: + scale_msa_y = mod_y + + # Self-attention block. + x_attn, y_attn = self.attn( + x, + y, + scale_x=scale_msa_x, + scale_y=scale_msa_y, + **attn_kwargs, + ) + + assert x_attn.size(1) == N + x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x) + if self.update_y: + y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y) + + # MLP block. + x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x) + if self.update_y: + y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y) + + return x, y + + def ff_block_x(self, x, scale_x, gate_x): + x_mod = modulated_rmsnorm(x, scale_x) + x_res = self.mlp_x(x_mod) + x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm + return x + + def ff_block_y(self, y, scale_y, gate_y): + y_mod = modulated_rmsnorm(y, scale_y) + y_res = self.mlp_y(y_mod) + y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm + return y + + +class AsymmetricAttention(nn.Module): + def __init__( + self, + dim_x: int, + dim_y: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + update_y: bool = True, + out_bias: bool = True, + softmax_scale: Optional[float] = None, + device: Optional[torch.device] = None, + ): + super().__init__() + self.dim_x = dim_x + self.dim_y = dim_y + self.num_heads = num_heads + self.head_dim = dim_x // num_heads + self.update_y = update_y + self.softmax_scale = softmax_scale + if dim_x % num_heads != 0: + raise ValueError( + f"dim_x={dim_x} should be divisible by num_heads={num_heads}" + ) + + # Input layers. 
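+        # e.g. with dim_x=3072 and dim_y=1536: qkv_x maps 3072 -> 9216 while qkv_y maps 1536 -> 9216,
+        # i.e. the text stream is projected up to the visual width before joint attention.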
+ self.qkv_bias = qkv_bias + self.qkv_x = nn.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device) + # Project text features to match visual features (dim_y -> dim_x) + self.qkv_y = nn.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device) + + # Query and key normalization for stability. + assert qk_norm + self.q_norm_x = RMSNorm(self.head_dim, device=device) + self.k_norm_x = RMSNorm(self.head_dim, device=device) + self.q_norm_y = RMSNorm(self.head_dim, device=device) + self.k_norm_y = RMSNorm(self.head_dim, device=device) + + # Output layers. y features go back down from dim_x -> dim_y. + self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device) + self.proj_y = ( + nn.Linear(dim_x, dim_y, bias=out_bias, device=device) + if update_y + else nn.Identity() + ) + + # def run_qkv_y(self, y): + # cp_rank, cp_size = cp.get_cp_rank_size() + # local_heads = self.num_heads // cp_size + + # if cp.is_cp_active(): + # # Only predict local heads. + # assert not self.qkv_bias + # W_qkv_y = self.qkv_y.weight.view( + # 3, self.num_heads, self.head_dim, self.dim_y + # ) + # W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads) + # W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y) + # qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim) + # else: + # qkv_y = self.qkv_y(y) # (B, L, 3 * dim) + + # qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim) + # q_y, k_y, v_y = qkv_y.unbind(2) + # return q_y, k_y, v_y + + # def prepare_qkv( + # self, + # x: torch.Tensor, # (B, N, dim_x) + # y: torch.Tensor, # (B, L, dim_y) + # *, + # scale_x: torch.Tensor, + # scale_y: torch.Tensor, + # rope_cos: torch.Tensor, + # rope_sin: torch.Tensor, + # valid_token_indices: torch.Tensor, + # ): + # # Pre-norm for visual features + # x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size + + # # Process visual features + # qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x) + # assert qkv_x.dtype == torch.bfloat16 + # qkv_x = cp.all_to_all_collect_tokens( + # qkv_x, self.num_heads + # ) # (3, B, N, local_h, head_dim) + + # # Process text features + # y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y) + # q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim) + # q_y = self.q_norm_y(q_y) + # k_y = self.k_norm_y(k_y) + + # # Split qkv_x into q, k, v + # q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim) + # q_x = self.q_norm_x(q_x) + # q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin) + # k_x = self.k_norm_x(k_x) + # k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin) + + # # Unite streams + # qkv = unify_streams( + # q_x, + # k_x, + # v_x, + # q_y, + # k_y, + # v_y, + # valid_token_indices, + # ) + + # return qkv + + # @torch.compiler.disable() + # def run_attention( + # self, + # qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim) + # *, + # B: int, + # L: int, + # M: int, + # cu_seqlens: torch.Tensor, + # max_seqlen_in_batch: int, + # valid_token_indices: torch.Tensor, + # ): + # with torch.autocast("cuda", enabled=False): + # out: torch.Tensor = flash_attn_varlen_qkvpacked_func( + # qkv, + # cu_seqlens=cu_seqlens, + # max_seqlen=max_seqlen_in_batch, + # dropout_p=0.0, + # softmax_scale=self.softmax_scale, + # ) # (total, local_heads, head_dim) + # out = out.view(total, local_dim) + + # x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype) + # assert x.size() == (B, N, local_dim) + # assert y.size() == (B, L, local_dim) + + # x = x.view(B, N, local_heads, self.head_dim) + 
# x = self.proj_x(x) # (B, M, dim_x) + + # y = self.proj_y(y) # (B, L, dim_y) + # return x, y + + # def forward( + # self, + # x: torch.Tensor, # (B, N, dim_x) + # y: torch.Tensor, # (B, L, dim_y) + # *, + # scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. + # scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. + # packed_indices: Dict[str, torch.Tensor] = None, + # **rope_rotation, + # ) -> Tuple[torch.Tensor, torch.Tensor]: + # """Forward pass of asymmetric multi-modal attention. + + # Args: + # x: (B, N, dim_x) tensor for visual tokens + # y: (B, L, dim_y) tensor of text token features + # packed_indices: Dict with keys for Flash Attention + # num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens + + # Returns: + # x: (B, N, dim_x) tensor of visual tokens after multi-modal attention + # y: (B, L, dim_y) tensor of text token features after multi-modal attention + # """ + # B, L, _ = y.shape + # _, M, _ = x.shape + + # # Predict a packed QKV tensor from visual and text features. + # # Don't checkpoint the all_to_all. + # qkv = self.prepare_qkv( + # x=x, + # y=y, + # scale_x=scale_x, + # scale_y=scale_y, + # rope_cos=rope_rotation.get("rope_cos"), + # rope_sin=rope_rotation.get("rope_sin"), + # valid_token_indices=packed_indices["valid_token_indices_kv"], + # ) # (total <= B * (N + L), 3, local_heads, head_dim) + + # x, y = self.run_attention( + # qkv, + # B=B, + # L=L, + # M=M, + # cu_seqlens=packed_indices["cu_seqlens_kv"], + # max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"], + # valid_token_indices=packed_indices["valid_token_indices_kv"], + # ) + # return x, y + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + + def __init__( + self, + hidden_size, + patch_size, + out_channels, + device: Optional[torch.device] = None, + ): + super().__init__() + self.norm_final = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, device=device + ) + self.mod = nn.Linear(hidden_size, 2 * hidden_size, device=device) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, device=device + ) + + def forward(self, x, c): + c = F.silu(c) + shift, scale = self.mod(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class MochiTransformer3DModel(nn.Module): + """ + Diffusion model with a Transformer backbone. + + Ingests text embeddings instead of a label. + """ + + def __init__( + self, + *, + patch_size=2, + in_channels=4, + hidden_size_x=1152, + hidden_size_y=1152, + depth=48, + num_heads=16, + mlp_ratio_x=8.0, + mlp_ratio_y=4.0, + t5_feat_dim: int = 4096, + t5_token_length: int = 256, + patch_embed_bias: bool = True, + timestep_mlp_bias: bool = True, + timestep_scale: Optional[float] = None, + use_extended_posenc: bool = False, + rope_theta: float = 10000.0, + device: Optional[torch.device] = None, + **block_kwargs, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.hidden_size_x = hidden_size_x + self.hidden_size_y = hidden_size_y + self.head_dim = ( + hidden_size_x // num_heads + ) # Head dimension and count is determined by visual. + self.use_extended_posenc = use_extended_posenc + self.t5_token_length = t5_token_length + self.t5_feat_dim = t5_feat_dim + self.rope_theta = ( + rope_theta # Scaling factor for frequency computation for temporal RoPE. 
+ ) + + self.x_embedder = PatchEmbed( + patch_size=patch_size, + in_chans=in_channels, + embed_dim=hidden_size_x, + bias=patch_embed_bias, + device=device, + ) + # Conditionings + # Timestep + self.t_embedder = TimestepEmbedder( + hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale + ) + + # Caption Pooling (T5) + self.t5_y_embedder = AttentionPool( + t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device + ) + + # Dense Embedding Projection (T5) + self.t5_yproj = nn.Linear( + t5_feat_dim, hidden_size_y, bias=True, device=device + ) + + # Initialize pos_frequencies as an empty parameter. + self.pos_frequencies = nn.Parameter( + torch.empty(3, self.num_heads, self.head_dim // 2, device=device) + ) + + # for depth 48: + # b = 0: AsymmetricJointBlock, update_y=True + # b = 1: AsymmetricJointBlock, update_y=True + # ... + # b = 46: AsymmetricJointBlock, update_y=True + # b = 47: AsymmetricJointBlock, update_y=False. No need to update text features. + blocks = [] + for b in range(depth): + # Joint multi-modal block + update_y = b < depth - 1 + block = AsymmetricJointBlock( + hidden_size_x, + hidden_size_y, + num_heads, + mlp_ratio_x=mlp_ratio_x, + mlp_ratio_y=mlp_ratio_y, + update_y=update_y, + device=device, + **block_kwargs, + ) + + blocks.append(block) + self.blocks = nn.ModuleList(blocks) + + self.final_layer = FinalLayer( + hidden_size_x, patch_size, self.out_channels, device=device + ) + + def embed_x(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (B, C=12, T, H, W) tensor of visual tokens + + Returns: + x: (B, C=3072, N) tensor of visual tokens with positional embedding. + """ + return self.x_embedder(x) # Convert BcTHW to BCN + + def prepare( + self, + x: torch.Tensor, + sigma: torch.Tensor, + t5_feat: torch.Tensor, + t5_mask: torch.Tensor, + ): + """Prepare input and conditioning embeddings.""" + + with torch.profiler.record_function("x_emb_pe"): + # Visual patch embeddings with positional encoding. + T, H, W = x.shape[-3:] + pH, pW = H // self.patch_size, W // self.patch_size + x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2 + assert x.ndim == 3 + B = x.size(0) + + with torch.profiler.record_function("rope_cis"): + # Construct position array of size [N, 3]. + # pos[:, 0] is the frame index for each location, + # pos[:, 1] is the row index for each location, and + # pos[:, 2] is the column index for each location. + pH, pW = H // self.patch_size, W // self.patch_size + N = T * pH * pW + assert x.size(1) == N + pos = create_position_matrix( + T, pH=pH, pW=pW, device=x.device, dtype=torch.float32 + ) # (N, 3) + rope_cos, rope_sin = compute_mixed_rotation( + freqs=self.pos_frequencies, pos=pos + ) # Each are (N, num_heads, dim // 2) + + with torch.profiler.record_function("t_emb"): + # Global vector embedding for conditionings. + c_t = self.t_embedder(1 - sigma) # (B, D) + + with torch.profiler.record_function("t5_pool"): + # Pool T5 tokens using attention pooler + # Note y_feat[1] contains T5 token features. + assert ( + t5_feat.size(1) == self.t5_token_length + ), f"Expected L={self.t5_token_length}, got {t5_feat.shape} for y_feat." + t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D) + assert ( + t5_y_pool.size(0) == B + ), f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool." 
+ + c = c_t + t5_y_pool + + y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D) + + return x, c, y_feat, rope_cos, rope_sin + + def forward( + self, + x: torch.Tensor, + sigma: torch.Tensor, + y_feat: List[torch.Tensor], + y_mask: List[torch.Tensor], + packed_indices: Dict[str, torch.Tensor] = None, + rope_cos: torch.Tensor = None, + rope_sin: torch.Tensor = None, + ): + """Forward pass of DiT. + + Args: + x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images) + sigma: (B,) tensor of noise standard deviations + y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048) + y_mask: List((B, L) boolean tensor indicating which tokens are not padding) + packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices. + """ + B, _, T, H, W = x.shape + + x, c, y_feat, rope_cos, rope_sin = self.prepare( + x, sigma, y_feat[0], y_mask[0] + ) + del y_mask + + for i, block in enumerate(self.blocks): + x, y_feat = block( + x, + c, + y_feat, + rope_cos=rope_cos, + rope_sin=rope_sin, + packed_indices=packed_indices, + ) # (B, M, D), (B, L, D) + del y_feat # Final layers don't use dense text features. + + x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels) + + patch = x.size(2) + x = rearrange(x, "(G B) M P -> B (G M) P", G=1, P=patch) + x = rearrange( + x, + "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)", + T=T, + hp=H // self.patch_size, + wp=W // self.patch_size, + p1=self.patch_size, + p2=self.patch_size, + c=self.out_channels, + ) + + return x From 64275b0e66b867475afe996ef5c11271de069d21 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 00:17:16 +0200 Subject: [PATCH 02/19] udpate --- src/diffusers/models/attention_processor.py | 138 ++++++++++++++++++ src/diffusers/models/normalization.py | 23 +++ .../models/transformers/transformer_mochi.py | 67 ++++++--- 3 files changed, 210 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e735c4ee7d..7635026d3e 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -717,6 +717,144 @@ class Attention(nn.Module): self.fused_projections = fuse +class AsymmetricAttention(nn.Module): + def __init__( + self, + query_dim: int, + query_context_dim: int, + num_attentions_heads: int = 8, + attention_head_dim: int = 64, + bias: bool = False, + context_bias: bool = False, + out_dim: Optional[int] = None, + out_context_dim: Optional[int] = None, + qk_norm: Optional[str] = None, + eps: float = 1e-5, + elementwise_affine: bool = True, + out_bias: bool = True, + processor: Optional["AttnProcessor"] = None, + ) -> None: + from .normalization import RMSNorm + + self.query_dim = query_dim + self.query_context_dim = query_context_dim + self.inner_dim = out_dim if out_dim is not None else num_attentions_heads * attention_head_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_context_dim + + self.scale = attention_head_dim ** -0.5 + self.num_attention_heads = out_dim // attention_head_dim if out_dim is not None else num_attentions_heads + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + self.norm_context_q = None + self.norm_context_k = None + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = 
RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) + self.norm_context_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) + self.norm_context_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) + else: + raise ValueError((f"Unknown qk_norm: {qk_norm}. Should be None or `rms_norm`.")) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + + self.to_context_q = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias) + self.to_context_k = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias) + self.to_context_v = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias) + + # TODO(aryan): Take care of dropouts for training purpose in future + self.to_out = nn.ModuleList([ + nn.Linear(self.inner_dim, self.out_dim) + ]) + self.to_out = nn.ModuleList([ + nn.Linear(self.inner_dim, self.out_context_dim) + ]) + + if processor is None: + processor = AsymmetricAttnProcessor2_0() + + self.set_processor(processor) + + +# Similar to SD3 +# class AsymmetricAttnProcessor2_0: +# r""" +# Processor for implementing Asymmetric SDPA as described in Genmo/Mochi (TODO(aryan) add link). +# """ + +# def __init__(self): +# if not hasattr(F, "scaled_dot_product_attention"): +# raise ImportError("AsymmetricAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + +# def __call__( +# self, +# attn: AsymmetricAttention, +# hidden_states: torch.Tensor, +# encoder_hidden_states: torch.Tensor, +# temb: torch.Tensor, +# image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, +# ) -> torch.Tensor: +# batch_size = hidden_states.size(0) + +# query = attn.to_q(hidden_states) +# key = attn.to_k(hidden_states) +# value = attn.to_v(hidden_states) + +# query_context = attn.to_context_q(encoder_hidden_states) +# key_context = attn.to_context_k(encoder_hidden_states) +# value_context = attn.to_context_v(encoder_hidden_states) + +# inner_dim = key.shape[-1] +# head_dim = inner_dim / attn.num_attention_heads + +# query = query.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) +# key = key.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) +# value = value.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) + +# query_context = query_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) +# key_context = key_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) +# value_context = value_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) + +# if attn.norm_q is not None: +# query = attn.norm_q(query) +# if attn.norm_k is not None: +# key = attn.norm_k(key) + +# if attn.norm_context_q is not None: +# query_context = attn.norm_context_q(query_context) +# key_context = attn.norm_context_k(key_context) + +# if image_rotary_emb is not None: +# from .embeddings import apply_rotary_emb + +# query = apply_rotary_emb(query, image_rotary_emb) +# key = apply_rotary_emb(key, image_rotary_emb) + +# sequence_length = query.size(1) +# context_sequence_length = query_context.size(1) +# query = torch.cat([query, query_context], dim=1) +# key = torch.cat([key, key_context], dim=1) +# value = torch.cat([value, value_context], dim=1) + +# hidden_states = F.scaled_dot_product_attention( +# query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False +# ) + +# 
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) +# hidden_states = hidden_states.to(query.dtype) + +# hidden_states, encoder_hidden_states = hidden_states.split_with_sizes([sequence_length, context_sequence_length], dim=1) + +# hidden_states = attn.to_out[0](hidden_states) +# encoder_hidden_states = attn.to_context_out[0](encoder_hidden_states) + +# return hidden_states, encoder_hidden_states + + class AttnProcessor: r""" Default processor for performing attention-related computations. diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 029c147fcb..03e03b0c19 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -237,6 +237,29 @@ class LuminaRMSNormZero(nn.Module): return x, gate_msa, scale_mlp, gate_mlp +class MochiRMSNormZero(nn.Module): + r""" + Adaptive RMS Norm used in Mochi. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + def __init__(self, embedding_dim: int, hidden_dim: int, norm_eps: float = 1e-5, elementwise_affine: bool = False) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, hidden_dim) + self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=elementwise_affine) + + def forward(self, hidden_states: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + + return hidden_states, gate_msa, scale_mlp, gate_mlp + + class AdaLayerNormSingle(nn.Module): r""" Norm layer adaptive layer norm single (adaLN-single). diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index c56a7845cb..6ac5d1a49d 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -25,7 +25,7 @@ from ..attention import Attention, FeedForward from ..embeddings import PatchEmbed, MochiAttentionPool, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNorm +from ..normalization import MochiRMSNormZero, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -39,31 +39,60 @@ class MochiTransformerBlock(nn.Module): num_attention_heads: int, attention_head_dim: int, caption_dim: int, + activation_fn: str = "swiglu", update_captions: bool = True, ) -> None: super().__init__() - # TODO: Replace this with norm - self.mod_x = nn.Linear(dim, 4 * dim) - if self.update_y: - self.mod_y = nn.Linear(dim, 4 * caption_dim) - else: - self.mod_y = nn.Linear(dim, caption_dim) + self.update_captions = update_captions - # TODO(aryan): attention class does not look compatible - self.attn1 = Attention(...) 
- # norms go in attention - # self.q_norm_x = RMSNorm(attention_head_dim) - # self.k_norm_x = RMSNorm(attention_head_dim) - # self.q_norm_y = RMSNorm(attention_head_dim) - # self.k_norm_y = RMSNorm(attention_head_dim) + self.norm1 = MochiRMSNormZero(dim, 4 * dim) - self.proj_x = nn.Linear(dim, dim) + if update_captions: + self.norm_context1 = MochiRMSNormZero(dim, 4 * caption_dim) + else: + self.norm_context1 = RMSNorm(caption_dim, eps=1e-5, elementwise_affine=False) + + self.attn = Attention( + query_dim=dim, + heads=num_attention_heads, + attention_head_dim=attention_head_dim, + out_dim=4 * dim, + qk_norm="rms_norm", + eps=1e-5, + elementwise_affine=False, + ) + self.attn_context = Attention( + query_dim=dim, + heads=num_attention_heads, + attention_head_dim=attention_head_dim, + out_dim=4 * caption_dim if update_captions else caption_dim, + qk_norm="rms_norm", + eps=1e-5, + elementwise_affine=False, + ) - self.proj_y = nn.Linear(dim, caption_dim) if update_captions else None + self.ff = FeedForward(dim, mult=4, activation_fn=activation_fn) + self.ff_context = FeedForward(caption_dim, mult=4, activation_fn=activation_fn) - def forward(self): - pass + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + + if self.update_captions: + norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm_context1(encoder_hidden_states, temb) + else: + norm_encoder_hidden_states = self.norm_context1(encoder_hidden_states) + + attn_hidden_states = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + image_rotary_emb=image_rotary_emb, + ) + attn_encoder_hidden_states = self.attn_context( + hidden_states=norm_encoder_hidden_states, + encoder_hidden_states=None, + image_rotary_emb=None, + ) @maybe_allow_in_graph @@ -87,6 +116,7 @@ class MochiTransformer3D(ModelMixin, ConfigMixin): timestep_mlp_bias=True, timestep_scale=1000.0, text_embed_dim=4096, + activation_fn: str = "swiglu", max_sequence_length=256, ) -> None: super().__init__() @@ -112,6 +142,7 @@ class MochiTransformer3D(ModelMixin, ConfigMixin): num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, caption_dim=caption_dim, + activation_fn=activation_fn, update_captions=i < num_layers - 1, ) for i in range(num_layers) From da48940b56cf54d37ad8a1473badc1da387ca782 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 01:27:25 +0200 Subject: [PATCH 03/19] update transformer --- src/diffusers/models/attention_processor.py | 138 ---------------- src/diffusers/models/embeddings.py | 22 +++ .../models/transformers/transformer_mochi.py | 147 ++++++++++++------ 3 files changed, 125 insertions(+), 182 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7635026d3e..e735c4ee7d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -717,144 +717,6 @@ class Attention(nn.Module): self.fused_projections = fuse -class AsymmetricAttention(nn.Module): - def __init__( - self, - query_dim: int, - query_context_dim: int, - num_attentions_heads: int = 8, - attention_head_dim: int = 64, - bias: bool = False, - context_bias: bool = False, - out_dim: Optional[int] = None, - out_context_dim: Optional[int] = None, - qk_norm: Optional[str] = None, - eps: 
float = 1e-5, - elementwise_affine: bool = True, - out_bias: bool = True, - processor: Optional["AttnProcessor"] = None, - ) -> None: - from .normalization import RMSNorm - - self.query_dim = query_dim - self.query_context_dim = query_context_dim - self.inner_dim = out_dim if out_dim is not None else num_attentions_heads * attention_head_dim - self.out_dim = out_dim if out_dim is not None else query_dim - self.out_context_dim = out_context_dim if out_context_dim is not None else query_context_dim - - self.scale = attention_head_dim ** -0.5 - self.num_attention_heads = out_dim // attention_head_dim if out_dim is not None else num_attentions_heads - - if qk_norm is None: - self.norm_q = None - self.norm_k = None - self.norm_context_q = None - self.norm_context_k = None - elif qk_norm == "rms_norm": - self.norm_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) - self.norm_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) - self.norm_context_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) - self.norm_context_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) - else: - raise ValueError((f"Unknown qk_norm: {qk_norm}. Should be None or `rms_norm`.")) - - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) - - self.to_context_q = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias) - self.to_context_k = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias) - self.to_context_v = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias) - - # TODO(aryan): Take care of dropouts for training purpose in future - self.to_out = nn.ModuleList([ - nn.Linear(self.inner_dim, self.out_dim) - ]) - self.to_out = nn.ModuleList([ - nn.Linear(self.inner_dim, self.out_context_dim) - ]) - - if processor is None: - processor = AsymmetricAttnProcessor2_0() - - self.set_processor(processor) - - -# Similar to SD3 -# class AsymmetricAttnProcessor2_0: -# r""" -# Processor for implementing Asymmetric SDPA as described in Genmo/Mochi (TODO(aryan) add link). 
-# """ - -# def __init__(self): -# if not hasattr(F, "scaled_dot_product_attention"): -# raise ImportError("AsymmetricAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - -# def __call__( -# self, -# attn: AsymmetricAttention, -# hidden_states: torch.Tensor, -# encoder_hidden_states: torch.Tensor, -# temb: torch.Tensor, -# image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, -# ) -> torch.Tensor: -# batch_size = hidden_states.size(0) - -# query = attn.to_q(hidden_states) -# key = attn.to_k(hidden_states) -# value = attn.to_v(hidden_states) - -# query_context = attn.to_context_q(encoder_hidden_states) -# key_context = attn.to_context_k(encoder_hidden_states) -# value_context = attn.to_context_v(encoder_hidden_states) - -# inner_dim = key.shape[-1] -# head_dim = inner_dim / attn.num_attention_heads - -# query = query.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) -# key = key.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) -# value = value.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) - -# query_context = query_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) -# key_context = key_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) -# value_context = value_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) - -# if attn.norm_q is not None: -# query = attn.norm_q(query) -# if attn.norm_k is not None: -# key = attn.norm_k(key) - -# if attn.norm_context_q is not None: -# query_context = attn.norm_context_q(query_context) -# key_context = attn.norm_context_k(key_context) - -# if image_rotary_emb is not None: -# from .embeddings import apply_rotary_emb - -# query = apply_rotary_emb(query, image_rotary_emb) -# key = apply_rotary_emb(key, image_rotary_emb) - -# sequence_length = query.size(1) -# context_sequence_length = query_context.size(1) -# query = torch.cat([query, query_context], dim=1) -# key = torch.cat([key, key_context], dim=1) -# value = torch.cat([value, value_context], dim=1) - -# hidden_states = F.scaled_dot_product_attention( -# query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False -# ) - -# hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) -# hidden_states = hidden_states.to(query.dtype) - -# hidden_states, encoder_hidden_states = hidden_states.split_with_sizes([sequence_length, context_sequence_length], dim=1) - -# hidden_states = attn.to_out[0](hidden_states) -# encoder_hidden_states = attn.to_context_out[0](encoder_hidden_states) - -# return hidden_states, encoder_hidden_states - - class AttnProcessor: r""" Default processor for performing attention-related computations. 
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 4ccddbbaf4..02bf8c460d 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1302,6 +1302,28 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module): return conditioning +class MochiCombinedTimestepCaptionEmbedding(nn.Module): + def __init__(self, embedding_dim: int, pooled_projection_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8) -> None: + super().__init__() + + self.time_proj = Timesteps( + num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0 + ) + self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim) + self.pooler = MochiAttentionPool(num_attention_heads=num_attention_heads, embed_dim=pooled_projection_dim, output_dim=embedding_dim) + self.caption_proj = nn.Linear(embedding_dim, pooled_projection_dim) + + def forward(self, timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, hidden_dtype: Optional[torch.dtype] = None): + time_proj = self.time_proj(timestep) + time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype)) + + pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask) + caption_proj = self.caption_proj(encoder_hidden_states) + + conditioning = time_emb + pooled_projections + return conditioning, caption_proj + + class TextTimeEmbedding(nn.Module): def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): super().__init__() diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 6ac5d1a49d..9ede9c2849 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -21,11 +21,11 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention, FeedForward -from ..embeddings import PatchEmbed, MochiAttentionPool, TimestepEmbedding, Timesteps +from ..attention import Attention, FeedForward, JointAttnProcessor2_0 +from ..embeddings import PatchEmbed, MochiCombinedTimestepCaptionEmbedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import MochiRMSNormZero, RMSNorm +from ..normalization import AdaLayerNormContinuous, MochiRMSNormZero, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -38,61 +38,73 @@ class MochiTransformerBlock(nn.Module): dim: int, num_attention_heads: int, attention_head_dim: int, - caption_dim: int, + pooled_projection_dim: int, + qk_norm: str = "rms_norm", activation_fn: str = "swiglu", - update_captions: bool = True, + context_pre_only: bool = True, ) -> None: super().__init__() - self.update_captions = update_captions + self.context_pre_only = context_pre_only self.norm1 = MochiRMSNormZero(dim, 4 * dim) - if update_captions: - self.norm_context1 = MochiRMSNormZero(dim, 4 * caption_dim) + if context_pre_only: + self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim) else: - self.norm_context1 = RMSNorm(caption_dim, eps=1e-5, elementwise_affine=False) + self.norm1_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False) self.attn = Attention( query_dim=dim, heads=num_attention_heads, attention_head_dim=attention_head_dim, 
out_dim=4 * dim, - qk_norm="rms_norm", - eps=1e-5, - elementwise_affine=False, - ) - self.attn_context = Attention( - query_dim=dim, - heads=num_attention_heads, - attention_head_dim=attention_head_dim, - out_dim=4 * caption_dim if update_captions else caption_dim, - qk_norm="rms_norm", - eps=1e-5, + qk_norm=qk_norm, + eps=1e-6, elementwise_affine=False, + processor=JointAttnProcessor2_0(), ) + self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) + self.norm2_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False) + + self.norm3 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) + self.norm3_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False) + self.ff = FeedForward(dim, mult=4, activation_fn=activation_fn) - self.ff_context = FeedForward(caption_dim, mult=4, activation_fn=activation_fn) + self.ff_context = FeedForward(pooled_projection_dim, mult=4, activation_fn=activation_fn) + + self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) + self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False) def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) - if self.update_captions: - norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm_context1(encoder_hidden_states, temb) + if self.context_pre_only: + norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(encoder_hidden_states, temb) else: - norm_encoder_hidden_states = self.norm_context1(encoder_hidden_states) + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) - attn_hidden_states = self.attn( + attn_hidden_states, context_attn_hidden_states = self.attn( hidden_states=norm_hidden_states, - encoder_hidden_states=None, + encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, ) - attn_encoder_hidden_states = self.attn_context( - hidden_states=norm_encoder_hidden_states, - encoder_hidden_states=None, - image_rotary_emb=None, - ) + + hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1) + hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + if not self.context_pre_only: + encoder_hidden_states = encoder_hidden_states + self.norm2_context(context_attn_hidden_states) * torch.tanh(enc_gate_msa).unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1)) + + ff_output = self.ff(hidden_states) + context_ff_output = self.ff_context(encoder_hidden_states) + + hidden_states = hidden_states + ff_output * torch.tanh(gate_mlp).unsqueeze(1) + if not self.context_pre_only: + encoder_hidden_states = encoder_hidden_states + context_ff_output * torch.tanh(enc_gate_mlp).unsqueeze(0) + + return hidden_states, encoder_hidden_states @maybe_allow_in_graph @@ -106,22 +118,28 @@ class MochiTransformer3D(ModelMixin, ConfigMixin): num_attention_heads: int = 24, attention_head_dim: int = 128, num_layers: int = 48, - caption_dim=1536, - mlp_ratio_x=4.0, - mlp_ratio_y=4.0, + pooled_projection_dim: int = 1536, in_channels=12, - qk_norm=True, - qkv_bias=False, - out_bias=True, + out_channels: Optional[int] = None, + qk_norm: str = "rms_norm", timestep_mlp_bias=True, timestep_scale=1000.0, - text_embed_dim=4096, 
+ text_embed_dim: int = 4096, + time_embed_dim: int = 256, activation_fn: str = "swiglu", - max_sequence_length=256, + max_sequence_length: int = 256, ) -> None: super().__init__() inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + self.time_embed = MochiCombinedTimestepCaptionEmbedding( + embedding_dim=text_embed_dim, + pooled_projection_dim=pooled_projection_dim, + time_embed_dim=time_embed_dim, + num_attention_heads=8, + ) self.patch_embed = PatchEmbed( patch_size=patch_size, @@ -129,9 +147,6 @@ class MochiTransformer3D(ModelMixin, ConfigMixin): embed_dim=inner_dim, ) - self.caption_embedder = MochiAttentionPool(num_attention_heads=8, embed_dim=text_embed_dim, output_dim=inner_dim) - self.caption_proj = nn.Linear(text_embed_dim, caption_dim) - self.pos_frequencies = nn.Parameter( torch.empty(3, num_attention_heads, attention_head_dim // 2) ) @@ -141,9 +156,53 @@ class MochiTransformer3D(ModelMixin, ConfigMixin): dim=inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, - caption_dim=caption_dim, + pooled_projection_dim=pooled_projection_dim, + qk_norm=qk_norm, activation_fn=activation_fn, - update_captions=i < num_layers - 1, + context_pre_only=i < num_layers - 1, ) for i in range(num_layers) ]) + + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm") + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + return_dict: bool = True, + ) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p = self.config.patch_size + + post_patch_height = height // p + post_patch_width = width // p + + temb, caption_proj = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask) + + hidden_states = self.patch_embed(hidden_states) + + for i, block in enumerate(self.transformer_blocks): + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + # TODO(aryan): do something with self.pos_frequencies + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_height, p, p, -1) + hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) + output = hidden_states.reshape(batch_size, -1, num_frames, height, width) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) From 05ebd6cd8296856199cc6961f1e3ef35878a3e08 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 01:27:51 +0200 Subject: [PATCH 04/19] make style --- src/diffusers/models/embeddings.py | 28 +++-- src/diffusers/models/normalization.py | 8 +- .../models/transformers/transformer_mochi.py | 82 ++++++++------ .../transformer_mochi_original.py | 106 ++++++------------ 4 files changed, 107 insertions(+), 117 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 02bf8c460d..896f479139 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1303,17 +1303,25 @@ class 
LuminaCombinedTimestepCaptionEmbedding(nn.Module): class MochiCombinedTimestepCaptionEmbedding(nn.Module): - def __init__(self, embedding_dim: int, pooled_projection_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8) -> None: + def __init__( + self, embedding_dim: int, pooled_projection_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8 + ) -> None: super().__init__() - - self.time_proj = Timesteps( - num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0 - ) + + self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0) self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim) - self.pooler = MochiAttentionPool(num_attention_heads=num_attention_heads, embed_dim=pooled_projection_dim, output_dim=embedding_dim) + self.pooler = MochiAttentionPool( + num_attention_heads=num_attention_heads, embed_dim=pooled_projection_dim, output_dim=embedding_dim + ) self.caption_proj = nn.Linear(embedding_dim, pooled_projection_dim) - def forward(self, timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, hidden_dtype: Optional[torch.dtype] = None): + def forward( + self, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + hidden_dtype: Optional[torch.dtype] = None, + ): time_proj = self.time_proj(timestep) time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype)) @@ -1467,7 +1475,7 @@ class MochiAttentionPool(nn.Module): self.to_kv = nn.Linear(embed_dim, 2 * embed_dim) self.to_q = nn.Linear(embed_dim, embed_dim) self.to_out = nn.Linear(embed_dim, self.output_dim) - + @staticmethod def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor: """ @@ -1526,9 +1534,7 @@ class MochiAttentionPool(nn.Module): q = q.unsqueeze(2) # (B, H, 1, head_dim) # Compute attention. - x = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, dropout_p=0.0 - ) # (B, H, 1, head_dim) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim) # Concatenate heads and run output. x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 03e03b0c19..e11faee490 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -245,14 +245,18 @@ class MochiRMSNormZero(nn.Module): embedding_dim (`int`): The size of each embedding vector. 
""" - def __init__(self, embedding_dim: int, hidden_dim: int, norm_eps: float = 1e-5, elementwise_affine: bool = False) -> None: + def __init__( + self, embedding_dim: int, hidden_dim: int, norm_eps: float = 1e-5, elementwise_affine: bool = False + ) -> None: super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, hidden_dim) self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=elementwise_affine) - def forward(self, hidden_states: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def forward( + self, hidden_states: torch.Tensor, emb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: emb = self.linear(self.silu(emb)) scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 9ede9c2849..92227e215c 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Optional, Tuple import torch import torch.nn as nn @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward, JointAttnProcessor2_0 -from ..embeddings import PatchEmbed, MochiCombinedTimestepCaptionEmbedding +from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, MochiRMSNormZero, RMSNorm @@ -46,14 +46,14 @@ class MochiTransformerBlock(nn.Module): super().__init__() self.context_pre_only = context_pre_only - + self.norm1 = MochiRMSNormZero(dim, 4 * dim) if context_pre_only: self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim) else: self.norm1_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False) - + self.attn = Attention( query_dim=dim, heads=num_attention_heads, @@ -67,7 +67,7 @@ class MochiTransformerBlock(nn.Module): self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) self.norm2_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False) - + self.norm3 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) self.norm3_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False) @@ -76,15 +76,23 @@ class MochiTransformerBlock(nn.Module): self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False) - - def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) if self.context_pre_only: - norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp 
= self.norm1_context(encoder_hidden_states, temb) + norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context( + encoder_hidden_states, temb + ) else: norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) - + attn_hidden_states, context_attn_hidden_states = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, @@ -94,16 +102,20 @@ class MochiTransformerBlock(nn.Module): hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1) hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1)) if not self.context_pre_only: - encoder_hidden_states = encoder_hidden_states + self.norm2_context(context_attn_hidden_states) * torch.tanh(enc_gate_msa).unsqueeze(1) - encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1)) - + encoder_hidden_states = encoder_hidden_states + self.norm2_context( + context_attn_hidden_states + ) * torch.tanh(enc_gate_msa).unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * ( + 1 + enc_scale_mlp.unsqueeze(1) + ) + ff_output = self.ff(hidden_states) context_ff_output = self.ff_context(encoder_hidden_states) - + hidden_states = hidden_states + ff_output * torch.tanh(gate_mlp).unsqueeze(1) if not self.context_pre_only: encoder_hidden_states = encoder_hidden_states + context_ff_output * torch.tanh(enc_gate_mlp).unsqueeze(0) - + return hidden_states, encoder_hidden_states @@ -140,33 +152,35 @@ class MochiTransformer3D(ModelMixin, ConfigMixin): time_embed_dim=time_embed_dim, num_attention_heads=8, ) - + self.patch_embed = PatchEmbed( patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, ) - self.pos_frequencies = nn.Parameter( - torch.empty(3, num_attention_heads, attention_head_dim // 2) + self.pos_frequencies = nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2)) + + self.transformer_blocks = nn.ModuleList( + [ + MochiTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + pooled_projection_dim=pooled_projection_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + context_pre_only=i < num_layers - 1, + ) + for i in range(num_layers) + ] ) - self.transformer_blocks = nn.ModuleList([ - MochiTransformerBlock( - dim=inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - pooled_projection_dim=pooled_projection_dim, - qk_norm=qk_norm, - activation_fn=activation_fn, - context_pre_only=i < num_layers - 1, - ) - for i in range(num_layers) - ]) - - self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm") + self.norm_out = AdaLayerNormContinuous( + inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm" + ) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) - + def forward( self, hidden_states: torch.Tensor, @@ -193,13 +207,13 @@ class MochiTransformer3D(ModelMixin, ConfigMixin): temb=temb, image_rotary_emb=image_rotary_emb, ) - + # TODO(aryan): do something with self.pos_frequencies hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_height, p, p, -1) + hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, 
post_patch_width, p, p, -1) hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) output = hidden_states.reshape(batch_size, -1, num_frames, height, width) diff --git a/src/diffusers/models/transformers/transformer_mochi_original.py b/src/diffusers/models/transformers/transformer_mochi_original.py index 52bdfa0710..022492c6ce 100644 --- a/src/diffusers/models/transformers/transformer_mochi_original.py +++ b/src/diffusers/models/transformers/transformer_mochi_original.py @@ -2,7 +2,7 @@ import collections import functools import itertools import math -from typing import Any, Callable, Dict, Optional, List +from typing import Callable, Dict, List, Optional import torch import torch.nn as nn @@ -19,8 +19,10 @@ def _ntuple(n): return parse + to_2tuple = _ntuple(2) + def centers(start: float, stop, num, dtype=None, device=None): """linspace through bin centers. @@ -94,8 +96,7 @@ def compute_mixed_rotation( num_heads: int Returns: - freqs_cos: [N, num_heads, num_freqs] - cosine components - freqs_sin: [N, num_heads, num_freqs] - sine components + freqs_cos: [N, num_heads, num_freqs] - cosine components freqs_sin: [N, num_heads, num_freqs] - sine components """ with torch.autocast("cuda", enabled=False): assert freqs.ndim == 3 @@ -132,9 +133,7 @@ class TimestepEmbedder(nn.Module): args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t): @@ -220,15 +219,17 @@ class PatchEmbed(nn.Module): device=device, ) assert norm_layer is None - self.norm = ( - norm_layer(embed_dim, device=device) if norm_layer else nn.Identity() - ) + self.norm = norm_layer(embed_dim, device=device) if norm_layer else nn.Identity() def forward(self, x): B, _C, T, H, W = x.shape if not self.dynamic_img_pad: - assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." - assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." + assert ( + H % self.patch_size[0] == 0 + ), f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." + assert ( + W % self.patch_size[1] == 0 + ), f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." else: pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] @@ -337,9 +338,7 @@ class AttentionPool(nn.Module): q = q.unsqueeze(2) # (B, H, 1, head_dim) # Compute attention. - x = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, dropout_p=0.0 - ) # (B, H, 1, head_dim) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim) # Concatenate heads and run output. x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) @@ -470,8 +469,7 @@ class AsymmetricJointBlock(nn.Module): num_frames: Number of frames in the video. 
N = num_frames * num_spatial_tokens Returns: - x: (B, N, dim) tensor of visual tokens after block - y: (B, L, dim) tensor of text tokens after block + x: (B, N, dim) tensor of visual tokens after block y: (B, L, dim) tensor of text tokens after block """ N = x.size(1) @@ -540,9 +538,7 @@ class AsymmetricAttention(nn.Module): self.update_y = update_y self.softmax_scale = softmax_scale if dim_x % num_heads != 0: - raise ValueError( - f"dim_x={dim_x} should be divisible by num_heads={num_heads}" - ) + raise ValueError(f"dim_x={dim_x} should be divisible by num_heads={num_heads}") # Input layers. self.qkv_bias = qkv_bias @@ -559,11 +555,7 @@ class AsymmetricAttention(nn.Module): # Output layers. y features go back down from dim_x -> dim_y. self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device) - self.proj_y = ( - nn.Linear(dim_x, dim_y, bias=out_bias, device=device) - if update_y - else nn.Identity() - ) + self.proj_y = nn.Linear(dim_x, dim_y, bias=out_bias, device=device) if update_y else nn.Identity() # def run_qkv_y(self, y): # cp_rank, cp_size = cp.get_cp_rank_size() @@ -676,16 +668,12 @@ class AsymmetricAttention(nn.Module): # ) -> Tuple[torch.Tensor, torch.Tensor]: # """Forward pass of asymmetric multi-modal attention. - # Args: - # x: (B, N, dim_x) tensor for visual tokens - # y: (B, L, dim_y) tensor of text token features - # packed_indices: Dict with keys for Flash Attention - # num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens + # Args: # x: (B, N, dim_x) tensor for visual tokens # y: (B, L, dim_y) tensor of text token features # + packed_indices: Dict with keys for Flash Attention # num_frames: Number of frames in the video. N = num_frames * + num_spatial_tokens - # Returns: - # x: (B, N, dim_x) tensor of visual tokens after multi-modal attention - # y: (B, L, dim_y) tensor of text token features after multi-modal attention - # """ + # Returns: # x: (B, N, dim_x) tensor of visual tokens after multi-modal attention # y: (B, L, dim_y) tensor of text + token features after multi-modal attention #""" # B, L, _ = y.shape # _, M, _ = x.shape @@ -726,13 +714,9 @@ class FinalLayer(nn.Module): device: Optional[torch.device] = None, ): super().__init__() - self.norm_final = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6, device=device - ) + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, device=device) self.mod = nn.Linear(hidden_size, 2 * hidden_size, device=device) - self.linear = nn.Linear( - hidden_size, patch_size * patch_size * out_channels, device=device - ) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, device=device) def forward(self, x, c): c = F.silu(c) @@ -777,15 +761,11 @@ class MochiTransformer3DModel(nn.Module): self.num_heads = num_heads self.hidden_size_x = hidden_size_x self.hidden_size_y = hidden_size_y - self.head_dim = ( - hidden_size_x // num_heads - ) # Head dimension and count is determined by visual. + self.head_dim = hidden_size_x // num_heads # Head dimension and count is determined by visual. self.use_extended_posenc = use_extended_posenc self.t5_token_length = t5_token_length self.t5_feat_dim = t5_feat_dim - self.rope_theta = ( - rope_theta # Scaling factor for frequency computation for temporal RoPE. - ) + self.rope_theta = rope_theta # Scaling factor for frequency computation for temporal RoPE. 
self.x_embedder = PatchEmbed( patch_size=patch_size, @@ -796,24 +776,16 @@ class MochiTransformer3DModel(nn.Module): ) # Conditionings # Timestep - self.t_embedder = TimestepEmbedder( - hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale - ) + self.t_embedder = TimestepEmbedder(hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale) # Caption Pooling (T5) - self.t5_y_embedder = AttentionPool( - t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device - ) + self.t5_y_embedder = AttentionPool(t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device) # Dense Embedding Projection (T5) - self.t5_yproj = nn.Linear( - t5_feat_dim, hidden_size_y, bias=True, device=device - ) + self.t5_yproj = nn.Linear(t5_feat_dim, hidden_size_y, bias=True, device=device) # Initialize pos_frequencies as an empty parameter. - self.pos_frequencies = nn.Parameter( - torch.empty(3, self.num_heads, self.head_dim // 2, device=device) - ) + self.pos_frequencies = nn.Parameter(torch.empty(3, self.num_heads, self.head_dim // 2, device=device)) # for depth 48: # b = 0: AsymmetricJointBlock, update_y=True @@ -839,9 +811,7 @@ class MochiTransformer3DModel(nn.Module): blocks.append(block) self.blocks = nn.ModuleList(blocks) - self.final_layer = FinalLayer( - hidden_size_x, patch_size, self.out_channels, device=device - ) + self.final_layer = FinalLayer(hidden_size_x, patch_size, self.out_channels, device=device) def embed_x(self, x: torch.Tensor) -> torch.Tensor: """ @@ -878,9 +848,7 @@ class MochiTransformer3DModel(nn.Module): pH, pW = H // self.patch_size, W // self.patch_size N = T * pH * pW assert x.size(1) == N - pos = create_position_matrix( - T, pH=pH, pW=pW, device=x.device, dtype=torch.float32 - ) # (N, 3) + pos = create_position_matrix(T, pH=pH, pW=pW, device=x.device, dtype=torch.float32) # (N, 3) rope_cos, rope_sin = compute_mixed_rotation( freqs=self.pos_frequencies, pos=pos ) # Each are (N, num_heads, dim // 2) @@ -896,9 +864,7 @@ class MochiTransformer3DModel(nn.Module): t5_feat.size(1) == self.t5_token_length ), f"Expected L={self.t5_token_length}, got {t5_feat.shape} for y_feat." t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D) - assert ( - t5_y_pool.size(0) == B - ), f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool." + assert t5_y_pool.size(0) == B, f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool." c = c_t + t5_y_pool @@ -921,15 +887,15 @@ class MochiTransformer3DModel(nn.Module): Args: x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images) sigma: (B,) tensor of noise standard deviations - y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048) + y_feat: + List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, + y_feat_dim=2048) y_mask: List((B, L) boolean tensor indicating which tokens are not padding) packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices. 
""" B, _, T, H, W = x.shape - x, c, y_feat, rope_cos, rope_sin = self.prepare( - x, sigma, y_feat[0], y_mask[0] - ) + x, c, y_feat, rope_cos, rope_sin = self.prepare(x, sigma, y_feat[0], y_mask[0]) del y_mask for i, block in enumerate(self.blocks): From 0e9e281ad1dd13dcccf81ba6217d4e00514eca22 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 02:40:31 +0200 Subject: [PATCH 05/19] fix --- src/diffusers/models/attention_processor.py | 212 ++++++++++++++++++ src/diffusers/models/embeddings.py | 6 +- .../models/transformers/transformer_mochi.py | 52 +++-- .../transformer_mochi_original.py | 24 +- 4 files changed, 258 insertions(+), 36 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e735c4ee7d..bf8e67a2ab 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -717,6 +717,218 @@ class Attention(nn.Module): self.fused_projections = fuse +class AsymmetricAttention(nn.Module): + def __init__( + self, + query_dim: int, + query_context_dim: int, + num_attention_heads: int = 8, + attention_head_dim: int = 64, + bias: bool = False, + context_bias: bool = False, + out_dim: Optional[int] = None, + out_context_dim: Optional[int] = None, + qk_norm: Optional[str] = None, + eps: float = 1e-5, + elementwise_affine: bool = True, + processor: Optional["AttnProcessor"] = None, + ) -> None: + super().__init__() + + from .normalization import RMSNorm + + self.query_dim = query_dim + self.query_context_dim = query_context_dim + self.inner_dim = out_dim if out_dim is not None else num_attention_heads * attention_head_dim + self.out_dim = out_dim if out_dim is not None else query_dim + + self.scale = attention_head_dim ** -0.5 + self.num_attention_heads = out_dim // attention_head_dim if out_dim is not None else num_attention_heads + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + self.norm_context_q = None + self.norm_context_k = None + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) + self.norm_context_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) + self.norm_context_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) + else: + raise ValueError((f"Unknown qk_norm: {qk_norm}. Should be None or `rms_norm`.")) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) + + self.to_context_q = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias) + self.to_context_k = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias) + self.to_context_v = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias) + + # TODO(aryan): Take care of dropouts for training purpose in future + self.to_out = nn.ModuleList([ + nn.Linear(self.inner_dim, self.out_dim) + ]) + + self.to_context_out = None + if out_context_dim is not None: + self.to_context_out = nn.ModuleList([ + nn.Linear(self.inner_dim, out_context_dim) + ]) + + if processor is None: + processor = AsymmetricAttnProcessor2_0() + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. 
+ """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self) -> "AttentionProcessor": + r""" + Get the attention processor in use. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + return self.processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + +class AsymmetricAttnProcessor2_0: + r""" + Processor for implementing Asymmetric SDPA as described in Genmo/Mochi (TODO(aryan) add link). 
+ """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AsymmetricAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: AsymmetricAttention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + batch_size = hidden_states.size(0) + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query_context = attn.to_context_q(encoder_hidden_states) + key_context = attn.to_context_k(encoder_hidden_states) + value_context = attn.to_context_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim / attn.num_attention_heads + + query = query.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) + key = key.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) + value = value.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) + + query_context = query_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) + key_context = key_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) + value_context = value_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + if attn.norm_context_q is not None: + query_context = attn.norm_context_q(query_context) + if attn.norm_context_k is not None: + key_context = attn.norm_context_k(key_context) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + sequence_length = query.size(1) + context_sequence_length = query_context.size(1) + + query = torch.cat([query, query_context], dim=1) + key = torch.cat([key, key_context], dim=1) + value = torch.cat([value, value_context], dim=1) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes([sequence_length, context_sequence_length], dim=1) + + hidden_states = attn.to_out[0](hidden_states) + encoder_hidden_states = attn.to_context_out[0](encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + class AttnProcessor: r""" Default processor for performing attention-related computations. 
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 896f479139..3788829f16 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1304,16 +1304,16 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module): class MochiCombinedTimestepCaptionEmbedding(nn.Module): def __init__( - self, embedding_dim: int, pooled_projection_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8 + self, embedding_dim: int, pooled_projection_dim: int, text_embed_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8 ) -> None: super().__init__() self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0) self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim) self.pooler = MochiAttentionPool( - num_attention_heads=num_attention_heads, embed_dim=pooled_projection_dim, output_dim=embedding_dim + num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim ) - self.caption_proj = nn.Linear(embedding_dim, pooled_projection_dim) + self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim) def forward( self, diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 92227e215c..2fa8f77443 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -21,7 +21,8 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention, FeedForward, JointAttnProcessor2_0 +from ..attention import FeedForward +from ..attention_processor import AsymmetricAttention, AsymmetricAttnProcessor2_0 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -46,23 +47,27 @@ class MochiTransformerBlock(nn.Module): super().__init__() self.context_pre_only = context_pre_only + self.ff_inner_dim = (4 * dim * 2) // 3 + self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3 self.norm1 = MochiRMSNormZero(dim, 4 * dim) - if context_pre_only: + if not context_pre_only: self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim) else: - self.norm1_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False) + self.norm1_context = nn.Linear(dim, pooled_projection_dim) - self.attn = Attention( + self.attn = AsymmetricAttention( query_dim=dim, - heads=num_attention_heads, + query_context_dim=pooled_projection_dim, + num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, - out_dim=4 * dim, + out_dim=dim, + out_context_dim=None if context_pre_only else pooled_projection_dim, qk_norm=qk_norm, eps=1e-6, elementwise_affine=False, - processor=JointAttnProcessor2_0(), + processor=AsymmetricAttnProcessor2_0(), ) self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) @@ -71,8 +76,10 @@ class MochiTransformerBlock(nn.Module): self.norm3 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) self.norm3_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False) - self.ff = FeedForward(dim, mult=4, activation_fn=activation_fn) - self.ff_context = FeedForward(pooled_projection_dim, mult=4, activation_fn=activation_fn) + self.ff = FeedForward(dim, 
inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) + self.ff_context = None + if not context_pre_only: + self.ff_context = FeedForward(pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False) self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False) @@ -110,10 +117,10 @@ class MochiTransformerBlock(nn.Module): ) ff_output = self.ff(hidden_states) - context_ff_output = self.ff_context(encoder_hidden_states) - hidden_states = hidden_states + ff_output * torch.tanh(gate_mlp).unsqueeze(1) + if not self.context_pre_only: + context_ff_output = self.ff_context(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states + context_ff_output * torch.tanh(enc_gate_mlp).unsqueeze(0) return hidden_states, encoder_hidden_states @@ -131,11 +138,9 @@ class MochiTransformer3D(ModelMixin, ConfigMixin): attention_head_dim: int = 128, num_layers: int = 48, pooled_projection_dim: int = 1536, - in_channels=12, + in_channels: int = 12, out_channels: Optional[int] = None, qk_norm: str = "rms_norm", - timestep_mlp_bias=True, - timestep_scale=1000.0, text_embed_dim: int = 4096, time_embed_dim: int = 256, activation_fn: str = "swiglu", @@ -146,19 +151,20 @@ class MochiTransformer3D(ModelMixin, ConfigMixin): inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels - self.time_embed = MochiCombinedTimestepCaptionEmbedding( - embedding_dim=text_embed_dim, - pooled_projection_dim=pooled_projection_dim, - time_embed_dim=time_embed_dim, - num_attention_heads=8, - ) - self.patch_embed = PatchEmbed( patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, ) + self.time_embed = MochiCombinedTimestepCaptionEmbedding( + embedding_dim=inner_dim, + pooled_projection_dim=pooled_projection_dim, + text_embed_dim=text_embed_dim, + time_embed_dim=time_embed_dim, + num_attention_heads=8, + ) + self.pos_frequencies = nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2)) self.transformer_blocks = nn.ModuleList( @@ -170,7 +176,7 @@ class MochiTransformer3D(ModelMixin, ConfigMixin): pooled_projection_dim=pooled_projection_dim, qk_norm=qk_norm, activation_fn=activation_fn, - context_pre_only=i < num_layers - 1, + context_pre_only=i == num_layers - 1, ) for i in range(num_layers) ] @@ -196,7 +202,7 @@ class MochiTransformer3D(ModelMixin, ConfigMixin): post_patch_height = height // p post_patch_width = width // p - temb, caption_proj = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask) + temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask) hidden_states = self.patch_embed(hidden_states) diff --git a/src/diffusers/models/transformers/transformer_mochi_original.py b/src/diffusers/models/transformers/transformer_mochi_original.py index 022492c6ce..9c8924decb 100644 --- a/src/diffusers/models/transformers/transformer_mochi_original.py +++ b/src/diffusers/models/transformers/transformer_mochi_original.py @@ -96,7 +96,8 @@ def compute_mixed_rotation( num_heads: int Returns: - freqs_cos: [N, num_heads, num_freqs] - cosine components freqs_sin: [N, num_heads, num_freqs] - sine components + freqs_cos: [N, num_heads, num_freqs] - cosine components + freqs_sin: [N, num_heads, num_freqs] - sine components """ with torch.autocast("cuda", enabled=False): assert freqs.ndim == 3 @@ -469,7 +470,8 @@ class AsymmetricJointBlock(nn.Module): num_frames: 
Number of frames in the video. N = num_frames * num_spatial_tokens Returns: - x: (B, N, dim) tensor of visual tokens after block y: (B, L, dim) tensor of text tokens after block + x: (B, N, dim) tensor of visual tokens after block + y: (B, L, dim) tensor of text tokens after block """ N = x.size(1) @@ -668,12 +670,16 @@ class AsymmetricAttention(nn.Module): # ) -> Tuple[torch.Tensor, torch.Tensor]: # """Forward pass of asymmetric multi-modal attention. - # Args: # x: (B, N, dim_x) tensor for visual tokens # y: (B, L, dim_y) tensor of text token features # - packed_indices: Dict with keys for Flash Attention # num_frames: Number of frames in the video. N = num_frames * - num_spatial_tokens + # Args: + # x: (B, N, dim_x) tensor for visual tokens + # y: (B, L, dim_y) tensor of text token features + # packed_indices: Dict with keys for Flash Attention + # num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens - # Returns: # x: (B, N, dim_x) tensor of visual tokens after multi-modal attention # y: (B, L, dim_y) tensor of text - token features after multi-modal attention #""" + # Returns: + # x: (B, N, dim_x) tensor of visual tokens after multi-modal attention + # y: (B, L, dim_y) tensor of text token features after multi-modal attention + # """ # B, L, _ = y.shape # _, M, _ = x.shape @@ -887,9 +893,7 @@ class MochiTransformer3DModel(nn.Module): Args: x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images) sigma: (B,) tensor of noise standard deviations - y_feat: - List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, - y_feat_dim=2048) + y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048) y_mask: List((B, L) boolean tensor indicating which tokens are not padding) packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices. 
""" From c2a155714b5c477b0b849bd0420c09176529c761 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 03:48:10 +0200 Subject: [PATCH 06/19] add conversion script --- scripts/convert_mochi_to_diffusers.py | 185 ++++++++++++++++++ src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/attention_processor.py | 5 +- src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_mochi.py | 8 +- src/diffusers/utils/dummy_pt_objects.py | 15 ++ 7 files changed, 213 insertions(+), 5 deletions(-) create mode 100644 scripts/convert_mochi_to_diffusers.py diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py new file mode 100644 index 0000000000..da3cae97da --- /dev/null +++ b/scripts/convert_mochi_to_diffusers.py @@ -0,0 +1,185 @@ +import argparse +from contextlib import nullcontext + +import torch +from accelerate import init_empty_weights +from safetensors.torch import load_file +# from transformers import T5EncoderModel, T5Tokenizer + +from diffusers import MochiTransformer3DModel +from diffusers.utils.import_utils import is_accelerate_available + + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +TOKENIZER_MAX_LENGTH = 224 + +parser = argparse.ArgumentParser() +parser.add_argument("--transformer_checkpoint_path", default=None, type=str) +# parser.add_argument("--vae_checkpoint_path", default=None, type=str) +parser.add_argument("--output_path", required=True, type=str) +parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving") +parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory") +parser.add_argument("--dtype", type=str, default=None) + +args = parser.parse_args() + + +# This is specific to `AdaLayerNormContinuous`: +# Diffusers implementation split the linear projection into the scale, shift while Mochi split it into shift, scale +def swap_scale_shift(weight, dim): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path): + original_state_dict = load_file(ckpt_path, device="cpu") + new_state_dict = {} + + # Convert patch_embed + new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight") + new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias") + + # Convert time_embed + new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight") + new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias") + new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight") + new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias") + new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight") + new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias") + new_state_dict["time_embed.pooler.to_q.weight"] = original_state_dict.pop("t5_y_embedder.to_q.weight") + new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias") + new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight") + 
new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias") + new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight") + new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias") + + # Convert transformer blocks + num_layers = 48 + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + old_prefix = f"blocks.{i}." + + # norm1 + new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight") + new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias") + if i < num_layers - 1: + new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop( + old_prefix + "mod_y.weight" + ) + new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop( + old_prefix + "mod_y.bias" + ) + else: + new_state_dict[block_prefix + "norm1_context.weight"] = original_state_dict.pop( + old_prefix + "mod_y.weight" + ) + new_state_dict[block_prefix + "norm1_context.bias"] = original_state_dict.pop(old_prefix + "mod_y.bias") + + # Visual attention + qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.to_q.weight"] = q + new_state_dict[block_prefix + "attn1.to_k.weight"] = k + new_state_dict[block_prefix + "attn1.to_v.weight"] = v + new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop( + old_prefix + "attn.q_norm_x.weight" + ) + new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop( + old_prefix + "attn.k_norm_x.weight" + ) + new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop( + old_prefix + "attn.proj_x.weight" + ) + new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias") + + # Context attention + qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.to_context_q.weight"] = q + new_state_dict[block_prefix + "attn1.to_context_k.weight"] = k + new_state_dict[block_prefix + "attn1.to_context_v.weight"] = v + new_state_dict[block_prefix + "attn1.norm_context_q.weight"] = original_state_dict.pop( + old_prefix + "attn.q_norm_y.weight" + ) + new_state_dict[block_prefix + "attn1.norm_context_k.weight"] = original_state_dict.pop( + old_prefix + "attn.k_norm_y.weight" + ) + if i < num_layers - 1: + new_state_dict[block_prefix + "attn1.to_context_out.0.weight"] = original_state_dict.pop( + old_prefix + "attn.proj_y.weight" + ) + new_state_dict[block_prefix + "attn1.to_context_out.0.bias"] = original_state_dict.pop( + old_prefix + "attn.proj_y.bias" + ) + + # MLP + new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w1.weight") + new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight") + if i < num_layers - 1: + new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = original_state_dict.pop( + old_prefix + "mlp_y.w1.weight" + ) + new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop( + old_prefix + "mlp_y.w2.weight" + ) + + # Output layers + new_state_dict["norm_out.linear.weight"] = original_state_dict.pop("final_layer.mod.weight") + new_state_dict["norm_out.linear.bias"] = 
original_state_dict.pop("final_layer.mod.bias") + new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") + new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") + + new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies") + + print("Remaining Keys:", original_state_dict.keys()) + + return new_state_dict + + +# def convert_mochi_vae_checkpoint_to_diffusers(ckpt_path, vae_config): +# original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] +# return convert_ldm_vae_checkpoint(original_state_dict, vae_config) + + +def main(args): + if args.dtype is None: + dtype = None + if args.dtype == "fp16": + dtype = torch.float16 + elif args.dtype == "bf16": + dtype = torch.bfloat16 + elif args.dtype == "fp32": + dtype = torch.float32 + else: + raise ValueError(f"Unsupported dtype: {args.dtype}") + + transformer = None + vae = None + + if args.transformer_checkpoint_path is not None: + converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers( + args.transformer_checkpoint_path + ) + transformer = MochiTransformer3DModel() + transformer.load_state_dict(converted_transformer_state_dict, strict=True) + if dtype is not None: + # Original checkpoint data type will be preserved + transformer = transformer.to(dtype=dtype) + + # text_encoder_id = "google/t5-v1_1-xxl" + # tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) + # text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + + # # Apparently, the conversion does not work anymore without this :shrug: + # for param in text_encoder.parameters(): + # param.data = param.data.contiguous() + + transformer.save_pretrained("/raid/aryan/mochi-diffusers", subfolder="transformer") + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 789458a262..c71cbfd5a4 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -100,6 +100,7 @@ else: "Kandinsky3UNet", "LatteTransformer3DModel", "LuminaNextDiT2DModel", + "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", "MultiAdapter", @@ -579,6 +580,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: Kandinsky3UNet, LatteTransformer3DModel, LuminaNextDiT2DModel, + MochiTransformer3DModel, ModelMixin, MotionAdapter, MultiAdapter, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 4dda8c36ba..27177b2adc 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -56,6 +56,7 @@ if is_torch_available(): _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] @@ -106,6 +107,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: HunyuanDiT2DModel, LatteTransformer3DModel, LuminaNextDiT2DModel, + MochiTransformer3DModel, PixArtTransformer2DModel, PriorTransformer, SD3Transformer2DModel, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py 
index bf8e67a2ab..d298793bb6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -771,11 +771,14 @@ class AsymmetricAttention(nn.Module): nn.Linear(self.inner_dim, self.out_dim) ]) - self.to_context_out = None if out_context_dim is not None: self.to_context_out = nn.ModuleList([ nn.Linear(self.inner_dim, out_context_dim) ]) + else: + self.to_context_out = nn.ModuleList([ + nn.Identity() + ]) if processor is None: processor = AsymmetricAttnProcessor2_0() diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 58787c079e..e1c2c1edf1 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -16,5 +16,6 @@ if is_torch_available(): from .transformer_2d import Transformer2DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 2fa8f77443..bcf55f780d 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -57,7 +57,7 @@ class MochiTransformerBlock(nn.Module): else: self.norm1_context = nn.Linear(dim, pooled_projection_dim) - self.attn = AsymmetricAttention( + self.attn1 = AsymmetricAttention( query_dim=dim, query_context_dim=pooled_projection_dim, num_attention_heads=num_attention_heads, @@ -66,7 +66,7 @@ class MochiTransformerBlock(nn.Module): out_context_dim=None if context_pre_only else pooled_projection_dim, qk_norm=qk_norm, eps=1e-6, - elementwise_affine=False, + elementwise_affine=True, processor=AsymmetricAttnProcessor2_0(), ) @@ -100,7 +100,7 @@ class MochiTransformerBlock(nn.Module): else: norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) - attn_hidden_states, context_attn_hidden_states = self.attn( + attn_hidden_states, context_attn_hidden_states = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, @@ -127,7 +127,7 @@ class MochiTransformerBlock(nn.Module): @maybe_allow_in_graph -class MochiTransformer3D(ModelMixin, ConfigMixin): +class MochiTransformer3DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 10d0399a67..908865be5d 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -347,6 +347,21 @@ class LuminaNextDiT2DModel(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class MochiTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ModelMixin(metaclass=DummyObject): _backends = ["torch"] From be5bbe53e149d56ed9b91bb0a37e0f635d47d81c Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 03:48:31 +0200 Subject: [PATCH 07/19] update --- scripts/convert_mochi_to_diffusers.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py index da3cae97da..24c3954915 100644 --- a/scripts/convert_mochi_to_diffusers.py +++ b/scripts/convert_mochi_to_diffusers.py @@ -12,7 +12,7 @@ from diffusers.utils.import_utils import is_accelerate_available CTX = init_empty_weights if is_accelerate_available else nullcontext -TOKENIZER_MAX_LENGTH = 224 +TOKENIZER_MAX_LENGTH = 256 parser = argparse.ArgumentParser() parser.add_argument("--transformer_checkpoint_path", default=None, type=str) From 1e9bc91b5cf5e392a3306c5342950eb5b9aa6400 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 04:02:37 +0200 Subject: [PATCH 08/19] fix --- src/diffusers/models/transformers/transformer_mochi.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index bcf55f780d..a898840d2d 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -107,20 +107,21 @@ class MochiTransformerBlock(nn.Module): ) hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1) - hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + if not self.context_pre_only: encoder_hidden_states = encoder_hidden_states + self.norm2_context( context_attn_hidden_states ) * torch.tanh(enc_gate_msa).unsqueeze(1) - encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * ( + norm_encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * ( 1 + enc_scale_mlp.unsqueeze(1) ) - ff_output = self.ff(hidden_states) + ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + ff_output * torch.tanh(gate_mlp).unsqueeze(1) if not self.context_pre_only: - context_ff_output = self.ff_context(encoder_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) encoder_hidden_states = encoder_hidden_states + context_ff_output * torch.tanh(enc_gate_mlp).unsqueeze(0) return hidden_states, encoder_hidden_states From 98a4554ac6b31002a9881897881346c96ebc1ee8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 05:04:43 +0200 Subject: [PATCH 09/19] update --- scripts/convert_mochi_to_diffusers.py | 14 +- src/diffusers/models/attention_processor.py | 219 +----------------- .../models/transformers/transformer_mochi.py | 27 ++- 3 files changed, 26 insertions(+), 234 deletions(-) diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py index 24c3954915..83f642f65c 100644 --- a/scripts/convert_mochi_to_diffusers.py +++ b/scripts/convert_mochi_to_diffusers.py @@ -99,20 +99,20 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path): qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight") q, k, v = qkv_weight.chunk(3, dim=0) - new_state_dict[block_prefix + "attn1.to_context_q.weight"] = q - new_state_dict[block_prefix + "attn1.to_context_k.weight"] = k - new_state_dict[block_prefix + "attn1.to_context_v.weight"] = v - new_state_dict[block_prefix + "attn1.norm_context_q.weight"] = original_state_dict.pop( + new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q + new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k + new_state_dict[block_prefix + 
"attn1.add_v_proj.weight"] = v + new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = original_state_dict.pop( old_prefix + "attn.q_norm_y.weight" ) - new_state_dict[block_prefix + "attn1.norm_context_k.weight"] = original_state_dict.pop( + new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop( old_prefix + "attn.k_norm_y.weight" ) if i < num_layers - 1: - new_state_dict[block_prefix + "attn1.to_context_out.0.weight"] = original_state_dict.pop( + new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop( old_prefix + "attn.proj_y.weight" ) - new_state_dict[block_prefix + "attn1.to_context_out.0.bias"] = original_state_dict.pop( + new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop( old_prefix + "attn.proj_y.bias" ) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d298793bb6..cfbc2bc140 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -120,6 +120,7 @@ class Attention(nn.Module): _from_deprecated_attn_block: bool = False, processor: Optional["AttnProcessor"] = None, out_dim: int = None, + out_context_dim: int = None, context_pre_only=None, pre_only=False, elementwise_affine: bool = True, @@ -142,6 +143,7 @@ class Attention(nn.Module): self.dropout = dropout self.fused_projections = False self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim self.context_pre_only = context_pre_only self.pre_only = pre_only @@ -241,7 +243,7 @@ class Attention(nn.Module): self.to_out.append(nn.Dropout(dropout)) if self.context_pre_only is not None and not self.context_pre_only: - self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) if qk_norm is not None and added_kv_proj_dim is not None: if qk_norm == "fp32_layer_norm": @@ -717,221 +719,6 @@ class Attention(nn.Module): self.fused_projections = fuse -class AsymmetricAttention(nn.Module): - def __init__( - self, - query_dim: int, - query_context_dim: int, - num_attention_heads: int = 8, - attention_head_dim: int = 64, - bias: bool = False, - context_bias: bool = False, - out_dim: Optional[int] = None, - out_context_dim: Optional[int] = None, - qk_norm: Optional[str] = None, - eps: float = 1e-5, - elementwise_affine: bool = True, - processor: Optional["AttnProcessor"] = None, - ) -> None: - super().__init__() - - from .normalization import RMSNorm - - self.query_dim = query_dim - self.query_context_dim = query_context_dim - self.inner_dim = out_dim if out_dim is not None else num_attention_heads * attention_head_dim - self.out_dim = out_dim if out_dim is not None else query_dim - - self.scale = attention_head_dim ** -0.5 - self.num_attention_heads = out_dim // attention_head_dim if out_dim is not None else num_attention_heads - - if qk_norm is None: - self.norm_q = None - self.norm_k = None - self.norm_context_q = None - self.norm_context_k = None - elif qk_norm == "rms_norm": - self.norm_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) - self.norm_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) - self.norm_context_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) - self.norm_context_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine) - 
else: - raise ValueError((f"Unknown qk_norm: {qk_norm}. Should be None or `rms_norm`.")) - - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) - - self.to_context_q = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias) - self.to_context_k = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias) - self.to_context_v = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias) - - # TODO(aryan): Take care of dropouts for training purpose in future - self.to_out = nn.ModuleList([ - nn.Linear(self.inner_dim, self.out_dim) - ]) - - if out_context_dim is not None: - self.to_context_out = nn.ModuleList([ - nn.Linear(self.inner_dim, out_context_dim) - ]) - else: - self.to_context_out = nn.ModuleList([ - nn.Identity() - ]) - - if processor is None: - processor = AsymmetricAttnProcessor2_0() - - self.set_processor(processor) - - def set_processor(self, processor: "AttnProcessor") -> None: - r""" - Set the attention processor to use. - - Args: - processor (`AttnProcessor`): - The attention processor to use. - """ - # if current processor is in `self._modules` and if passed `processor` is not, we need to - # pop `processor` from `self._modules` - if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) - ): - logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") - self._modules.pop("processor") - - self.processor = processor - - def get_processor(self) -> "AttentionProcessor": - r""" - Get the attention processor in use. - - Returns: - "AttentionProcessor": The attention processor in use. - """ - return self.processor - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **cross_attention_kwargs, - ) -> torch.Tensor: - r""" - The forward method of the `Attention` class. - - Args: - hidden_states (`torch.Tensor`): - The hidden states of the query. - encoder_hidden_states (`torch.Tensor`, *optional*): - The hidden states of the encoder. - attention_mask (`torch.Tensor`, *optional*): - The attention mask to use. If `None`, no mask is applied. - **cross_attention_kwargs: - Additional keyword arguments to pass along to the cross attention. - - Returns: - `torch.Tensor`: The output of the attention layer. - """ - # The `Attention` class can call different attention processors / attention functions - # here we simply pass along all tensors to the selected processor class - # For standard processors that are defined here, `**cross_attention_kwargs` is empty - - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"ip_adapter_masks"} - unused_kwargs = [ - k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters - ] - if len(unused_kwargs) > 0: - logger.warning( - f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
- ) - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} - - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - -class AsymmetricAttnProcessor2_0: - r""" - Processor for implementing Asymmetric SDPA as described in Genmo/Mochi (TODO(aryan) add link). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AsymmetricAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: AsymmetricAttention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - ) -> torch.Tensor: - batch_size = hidden_states.size(0) - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - query_context = attn.to_context_q(encoder_hidden_states) - key_context = attn.to_context_k(encoder_hidden_states) - value_context = attn.to_context_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim / attn.num_attention_heads - - query = query.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) - key = key.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) - value = value.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) - - query_context = query_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) - key_context = key_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) - value_context = value_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - if attn.norm_context_q is not None: - query_context = attn.norm_context_q(query_context) - if attn.norm_context_k is not None: - key_context = attn.norm_context_k(key_context) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - sequence_length = query.size(1) - context_sequence_length = query_context.size(1) - - query = torch.cat([query, query_context], dim=1) - key = torch.cat([key, key_context], dim=1) - value = torch.cat([value, value_context], dim=1) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - hidden_states, encoder_hidden_states = hidden_states.split_with_sizes([sequence_length, context_sequence_length], dim=1) - - hidden_states = attn.to_out[0](hidden_states) - encoder_hidden_states = attn.to_context_out[0](encoder_hidden_states) - - return hidden_states, encoder_hidden_states - - class AttnProcessor: r""" Default processor for performing attention-related computations. 
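For readers following the attention refactor in this patch series: the removed `AsymmetricAttnProcessor2_0` above and the `MochiAttnProcessor2_0` added later in the series both implement the same joint-attention pattern. The snippet below is a minimal, self-contained sketch (not diffusers code; all shapes, names, and sizes are illustrative assumptions) of that pattern: per-stream query/key/value projections for visual and text tokens, concatenation along the sequence dimension, a single call to `torch.nn.functional.scaled_dot_product_attention`, and a split of the output back into the two streams.

```py
# Minimal sketch of joint attention over concatenated visual and text streams.
# Not diffusers code; RoPE and QK normalization are omitted, sizes are made up.
import torch
import torch.nn.functional as F

batch, heads, head_dim = 2, 4, 64
num_visual, num_text = 16, 8  # N visual tokens, L text tokens

def fake_proj(batch, seq, heads, head_dim):
    # Stand-in for the per-stream q/k/v projections of the real module.
    return torch.randn(batch, heads, seq, head_dim)

q_x, k_x, v_x = (fake_proj(batch, num_visual, heads, head_dim) for _ in range(3))
q_y, k_y, v_y = (fake_proj(batch, num_text, heads, head_dim) for _ in range(3))

# Concatenate the two streams along the sequence dimension and attend jointly.
query = torch.cat([q_x, q_y], dim=2)
key = torch.cat([k_x, k_y], dim=2)
value = torch.cat([v_x, v_y], dim=2)

out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

# Merge heads, then split the joint output back into visual and text parts
# before each stream goes through its own output projection.
out = out.transpose(1, 2).flatten(2, 3)  # (B, N + L, heads * head_dim)
visual_out, text_out = out.split_with_sizes([num_visual, num_text], dim=1)
print(visual_out.shape, text_out.shape)  # (2, 16, 256) and (2, 8, 256)
```

Running one fused attention call over the concatenated sequence is what lets the text stream attend to the video stream (and vice versa) without a separate cross-attention layer.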
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index a898840d2d..8aa7e48d3f 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import AsymmetricAttention, AsymmetricAttnProcessor2_0 +from ..attention_processor import Attention, FluxAttnProcessor2_0 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -57,17 +57,21 @@ class MochiTransformerBlock(nn.Module): else: self.norm1_context = nn.Linear(dim, pooled_projection_dim) - self.attn1 = AsymmetricAttention( + self.attn1 = Attention( query_dim=dim, - query_context_dim=pooled_projection_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - out_dim=dim, - out_context_dim=None if context_pre_only else pooled_projection_dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=False, qk_norm=qk_norm, + added_kv_proj_dim=pooled_projection_dim, + added_proj_bias=False, + out_dim=dim, + out_context_dim=pooled_projection_dim, + context_pre_only=context_pre_only, + processor=FluxAttnProcessor2_0(), eps=1e-6, elementwise_affine=True, - processor=AsymmetricAttnProcessor2_0(), ) self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) @@ -93,7 +97,7 @@ class MochiTransformerBlock(nn.Module): ) -> Tuple[torch.Tensor, torch.Tensor]: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) - if self.context_pre_only: + if not self.context_pre_only: norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context( encoder_hidden_states, temb ) @@ -203,9 +207,11 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin): post_patch_height = height // p post_patch_width = width // p - temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask) + temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) for i, block in enumerate(self.transformer_blocks): hidden_states, encoder_hidden_states = block( @@ -216,7 +222,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin): ) # TODO(aryan): do something with self.pos_frequencies - hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From 85c8734cdcac9a37691d31fb39e2d6bea1ea9c29 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 05:47:03 +0200 Subject: [PATCH 10/19] fix --- scripts/convert_mochi_to_diffusers.py | 10 ++++---- src/diffusers/models/attention_processor.py | 4 +++- src/diffusers/models/embeddings.py | 7 +++++- src/diffusers/models/normalization.py | 13 ++++++----- .../models/transformers/transformer_mochi.py | 23 ++++++++++++++----- 5 files changed, 39 insertions(+), 18 deletions(-) diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py index 83f642f65c..1d1d10a6ad 100644 
--- a/scripts/convert_mochi_to_diffusers.py +++ b/scripts/convert_mochi_to_diffusers.py @@ -4,8 +4,8 @@ from contextlib import nullcontext import torch from accelerate import init_empty_weights from safetensors.torch import load_file -# from transformers import T5EncoderModel, T5Tokenizer +# from transformers import T5EncoderModel, T5Tokenizer from diffusers import MochiTransformer3DModel from diffusers.utils.import_utils import is_accelerate_available @@ -72,10 +72,12 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path): old_prefix + "mod_y.bias" ) else: - new_state_dict[block_prefix + "norm1_context.weight"] = original_state_dict.pop( + new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop( old_prefix + "mod_y.weight" ) - new_state_dict[block_prefix + "norm1_context.bias"] = original_state_dict.pop(old_prefix + "mod_y.bias") + new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop( + old_prefix + "mod_y.bias" + ) # Visual attention qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight") @@ -158,7 +160,7 @@ def main(args): raise ValueError(f"Unsupported dtype: {args.dtype}") transformer = None - vae = None + # vae = None if args.transformer_checkpoint_path is not None: converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers( diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index cfbc2bc140..ce0f9d87c8 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1794,7 +1794,9 @@ class FluxAttnProcessor2_0: hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if hasattr(attn, "to_add_out"): + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states else: diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 3788829f16..3cf808430c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1304,7 +1304,12 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module): class MochiCombinedTimestepCaptionEmbedding(nn.Module): def __init__( - self, embedding_dim: int, pooled_projection_dim: int, text_embed_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8 + self, + embedding_dim: int, + pooled_projection_dim: int, + text_embed_dim: int, + time_embed_dim: int = 256, + num_attention_heads: int = 8, ) -> None: super().__init__() diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index e11faee490..9058320998 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -385,20 +385,21 @@ class LuminaLayerNormContinuous(nn.Module): out_dim: Optional[int] = None, ): super().__init__() + # AdaLN self.silu = nn.SiLU() self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + if norm_type == "layer_norm": self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + if norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) else: raise ValueError(f"unknown norm_type {norm_type}") - # linear_2 + + self.linear_2 = None if out_dim is not None: - self.linear_2 = nn.Linear( - embedding_dim, - out_dim, - bias=bias, - ) + self.linear_2 = nn.Linear(embedding_dim, out_dim, 
bias=bias) def forward( self, diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 8aa7e48d3f..7ece241e4b 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -26,7 +26,7 @@ from ..attention_processor import Attention, FluxAttnProcessor2_0 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, MochiRMSNormZero, RMSNorm +from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -55,7 +55,14 @@ class MochiTransformerBlock(nn.Module): if not context_pre_only: self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim) else: - self.norm1_context = nn.Linear(dim, pooled_projection_dim) + self.norm1_context = LuminaLayerNormContinuous( + embedding_dim=pooled_projection_dim, + conditioning_embedding_dim=dim, + eps=1e-6, + elementwise_affine=False, + norm_type="rms_norm", + out_dim=None, + ) self.attn1 = Attention( query_dim=dim, @@ -83,7 +90,9 @@ class MochiTransformerBlock(nn.Module): self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) self.ff_context = None if not context_pre_only: - self.ff_context = FeedForward(pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False) + self.ff_context = FeedForward( + pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False + ) self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False) @@ -102,7 +111,7 @@ class MochiTransformerBlock(nn.Module): encoder_hidden_states, temb ) else: - norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) attn_hidden_states, context_attn_hidden_states = self.attn1( hidden_states=norm_hidden_states, @@ -112,7 +121,7 @@ class MochiTransformerBlock(nn.Module): hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1) norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1)) - + if not self.context_pre_only: encoder_hidden_states = encoder_hidden_states + self.norm2_context( context_attn_hidden_states @@ -207,7 +216,9 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin): post_patch_height = height // p post_patch_width = width // p - temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype) + temb, encoder_hidden_states = self.time_embed( + timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype + ) hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.patch_embed(hidden_states) From ccc1b36b09d3e297c1013c71df0adc272eaf3502 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 24 Oct 2024 09:40:10 +0200 Subject: [PATCH 11/19] update --- .../pipelines/mochi/pipeline_mochi.py | 654 ++++++++++++++++++ .../pipelines/mochi/pipeline_output.py | 20 + 2 files changed, 674 insertions(+) create mode 100644 src/diffusers/pipelines/mochi/pipeline_mochi.py create mode 100644 
src/diffusers/pipelines/mochi/pipeline_output.py diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py new file mode 100644 index 0000000000..f9a026440c --- /dev/null +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -0,0 +1,654 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import MochiTransformer3D +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import MochiPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import MochiPipeline + + >>> pipe = MochiPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. 
+ device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class MochiPipeline( + DiffusionPipeline, +): + r""" + The Flux pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: MochiTransformer3D, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False)[0] + + dtype = self.text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer` and `text_encoder`. 
If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 16, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer` and `text_encoder`. 
If not defined, `prompt` is + used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.MochiPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`.
`callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return MochiPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/mochi/pipeline_output.py b/src/diffusers/pipelines/mochi/pipeline_output.py new file mode 100644 index 0000000000..cc14372794 --- /dev/null +++ b/src/diffusers/pipelines/mochi/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class MochiPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ """ + + frames: torch.Tensor From 2fd2ec40250e00de6965ec005aab1423b21b9291 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 13:48:22 +0200 Subject: [PATCH 12/19] fixes --- src/diffusers/models/attention_processor.py | 86 +++- src/diffusers/models/normalization.py | 4 +- .../models/transformers/transformer_mochi.py | 65 ++- .../transformer_mochi_original.py | 402 ++++++++++++------ 4 files changed, 410 insertions(+), 147 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ce0f9d87c8..dd61a6ab2c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1795,8 +1795,7 @@ class FluxAttnProcessor2_0: # dropout hidden_states = attn.to_out[1](hidden_states) - if hasattr(attn, "to_add_out"): - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states else: @@ -3082,6 +3081,89 @@ class LuminaAttnProcessor2_0: return hidden_states +class MochiAttnProcessor2_0: + """Attention processor used in Mochi.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + breakpoint() + batch_size = hidden_states.size(0) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + if image_rotary_emb is not None: + def apply_rotary_emb(x, freqs_cos, freqs_sin): + x_even = x[..., 0::2].float() + x_odd = x[..., 1::2].float() + + cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) + sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) + + return torch.stack([cos, sin], dim=-1).flatten(-2) + + query = apply_rotary_emb(query, *image_rotary_emb) + key = apply_rotary_emb(key, *image_rotary_emb) + + query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) + encoder_query, encoder_key, encoder_value = encoder_query.transpose(1, 2), encoder_key.transpose(1, 2), encoder_value.transpose(1, 2) + + sequence_length = query.size(2) + encoder_sequence_length = encoder_query.size(2) + + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = 
hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length), dim=1) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + class FusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 9058320998..dcfaed90b3 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -246,13 +246,13 @@ class MochiRMSNormZero(nn.Module): """ def __init__( - self, embedding_dim: int, hidden_dim: int, norm_eps: float = 1e-5, elementwise_affine: bool = False + self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False ) -> None: super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, hidden_dim) - self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=elementwise_affine) + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) def forward( self, hidden_states: torch.Tensor, emb: torch.Tensor diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 7ece241e4b..7938d5e39f 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import Attention, FluxAttnProcessor2_0 +from ..attention_processor import Attention, MochiAttnProcessor2_0 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -43,6 +43,7 @@ class MochiTransformerBlock(nn.Module): qk_norm: str = "rms_norm", activation_fn: str = "swiglu", context_pre_only: bool = True, + eps: float = 1e-6, ) -> None: super().__init__() @@ -50,15 +51,15 @@ class MochiTransformerBlock(nn.Module): self.ff_inner_dim = (4 * dim * 2) // 3 self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3 - self.norm1 = MochiRMSNormZero(dim, 4 * dim) + self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False) if not context_pre_only: - self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim) + self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False) else: self.norm1_context = LuminaLayerNormContinuous( embedding_dim=pooled_projection_dim, conditioning_embedding_dim=dim, - eps=1e-6, + eps=eps, elementwise_affine=False, norm_type="rms_norm", out_dim=None, @@ -76,16 +77,16 @@ class MochiTransformerBlock(nn.Module): out_dim=dim, out_context_dim=pooled_projection_dim, context_pre_only=context_pre_only, - processor=FluxAttnProcessor2_0(), - eps=1e-6, + processor=MochiAttnProcessor2_0(), + eps=eps, elementwise_affine=True, ) - self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) - self.norm2_context = RMSNorm(pooled_projection_dim, eps=1e-6, 
elementwise_affine=False) + self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) - self.norm3 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) - self.norm3_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False) + self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) self.ff_context = None @@ -94,8 +95,8 @@ class MochiTransformerBlock(nn.Module): pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False ) - self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) - self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False) + self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) def forward( self, @@ -104,6 +105,7 @@ class MochiTransformerBlock(nn.Module): temb: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + breakpoint() norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) if not self.context_pre_only: @@ -140,6 +142,40 @@ class MochiTransformerBlock(nn.Module): return hidden_states, encoder_hidden_states +class MochiRoPE(nn.Module): + def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None: + super().__init__() + + self.target_area = base_height * base_width + + def _centers(self, start, stop, num, device, dtype) -> torch.Tensor: + edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype) + return (edges[:-1] + edges[1:]) / 2 + + def _get_positions(self, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + scale = (self.target_area / (height * width)) ** 0.5 + + t = torch.arange(num_frames, device=device, dtype=dtype) + h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype) + w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype) + + grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") + + positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3) + return positions + + def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: + freqs = torch.einsum("nd,dhf->nhf", pos, freqs) + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + return freqs_cos, freqs_sin + + def forward(self, pos_frequencies: torch.Tensor, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]: + pos = self._get_positions(num_frames, height, width, device, dtype) + rope_cos, rope_sin = self._create_rope(pos_frequencies, pos) + return rope_cos, rope_sin + + @maybe_allow_in_graph class MochiTransformer3DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @@ -169,6 +205,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin): patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, + pos_embed_type=None, ) self.time_embed = MochiCombinedTimestepCaptionEmbedding( @@ -180,6 +217,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin): ) self.pos_frequencies = 
nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2)) + self.rope = MochiRoPE() self.transformer_blocks = nn.ModuleList( [ @@ -207,7 +245,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin): encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, ) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape @@ -224,6 +261,8 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin): hidden_states = self.patch_embed(hidden_states) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + image_rotary_emb = self.rope(self.pos_frequencies, num_frames, post_patch_height, post_patch_width, device=hidden_states.device, dtype=torch.float32) + for i, block in enumerate(self.transformer_blocks): hidden_states, encoder_hidden_states = block( hidden_states=hidden_states, diff --git a/src/diffusers/models/transformers/transformer_mochi_original.py b/src/diffusers/models/transformers/transformer_mochi_original.py index 9c8924decb..2dad5c5c86 100644 --- a/src/diffusers/models/transformers/transformer_mochi_original.py +++ b/src/diffusers/models/transformers/transformer_mochi_original.py @@ -2,7 +2,7 @@ import collections import functools import itertools import math -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -473,6 +473,7 @@ class AsymmetricJointBlock(nn.Module): x: (B, N, dim) tensor of visual tokens after block y: (B, L, dim) tensor of text tokens after block """ + breakpoint() N = x.size(1) c = F.silu(c) @@ -559,152 +560,291 @@ class AsymmetricAttention(nn.Module): self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device) self.proj_y = nn.Linear(dim_x, dim_y, bias=out_bias, device=device) if update_y else nn.Identity() - # def run_qkv_y(self, y): - # cp_rank, cp_size = cp.get_cp_rank_size() - # local_heads = self.num_heads // cp_size + def run_qkv_y(self, y): + qkv_y = self.qkv_y(y) + qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, -1, self.head_dim) + q_y, k_y, v_y = qkv_y.unbind(2) + return q_y, k_y, v_y - # if cp.is_cp_active(): - # # Only predict local heads. - # assert not self.qkv_bias - # W_qkv_y = self.qkv_y.weight.view( - # 3, self.num_heads, self.head_dim, self.dim_y - # ) - # W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads) - # W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y) - # qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim) - # else: - # qkv_y = self.qkv_y(y) # (B, L, 3 * dim) + # cp_rank, cp_size = cp.get_cp_rank_size() + # local_heads = self.num_heads // cp_size - # qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim) - # q_y, k_y, v_y = qkv_y.unbind(2) - # return q_y, k_y, v_y + # if cp.is_cp_active(): + # # Only predict local heads. 
+ # assert not self.qkv_bias + # W_qkv_y = self.qkv_y.weight.view( + # 3, self.num_heads, self.head_dim, self.dim_y + # ) + # W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads) + # W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y) + # qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim) + # else: + # qkv_y = self.qkv_y(y) # (B, L, 3 * dim) - # def prepare_qkv( - # self, - # x: torch.Tensor, # (B, N, dim_x) - # y: torch.Tensor, # (B, L, dim_y) - # *, - # scale_x: torch.Tensor, - # scale_y: torch.Tensor, - # rope_cos: torch.Tensor, - # rope_sin: torch.Tensor, - # valid_token_indices: torch.Tensor, - # ): - # # Pre-norm for visual features - # x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size + # qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim) + # q_y, k_y, v_y = qkv_y.unbind(2) + # return q_y, k_y, v_y - # # Process visual features - # qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x) - # assert qkv_x.dtype == torch.bfloat16 - # qkv_x = cp.all_to_all_collect_tokens( - # qkv_x, self.num_heads - # ) # (3, B, N, local_h, head_dim) + def prepare_qkv( + self, + x: torch.Tensor, # (B, N, dim_x) + y: torch.Tensor, # (B, L, dim_y) + *, + scale_x: torch.Tensor, + scale_y: torch.Tensor, + rope_cos: torch.Tensor, + rope_sin: torch.Tensor, + valid_token_indices: torch.Tensor = None, + ): + breakpoint() + # Pre-norm for visual features + x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size - # # Process text features - # y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y) - # q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim) - # q_y = self.q_norm_y(q_y) - # k_y = self.k_norm_y(k_y) + # Process visual features + qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x) + # assert qkv_x.dtype == torch.bfloat16 + # qkv_x = cp.all_to_all_collect_tokens( + # qkv_x, self.num_heads + # ) # (3, B, N, local_h, head_dim) + B, M, _ = qkv_x.size() + qkv_x = qkv_x.view(B, M, 3, -1, 128) + qkv_x = qkv_x.permute(2, 0, 1, 3, 4) - # # Split qkv_x into q, k, v - # q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim) - # q_x = self.q_norm_x(q_x) - # q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin) - # k_x = self.k_norm_x(k_x) - # k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin) + # Process text features + y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y) + q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim) + q_y = self.q_norm_y(q_y) + k_y = self.k_norm_y(k_y) - # # Unite streams - # qkv = unify_streams( - # q_x, - # k_x, - # v_x, - # q_y, - # k_y, - # v_y, - # valid_token_indices, - # ) + # Split qkv_x into q, k, v + q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim) + q_x = self.q_norm_x(q_x) + q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin) + k_x = self.k_norm_x(k_x) + k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin) - # return qkv + # Unite streams + qkv = unify_streams( + q_x, + k_x, + v_x, + q_y, + k_y, + v_y, + valid_token_indices, + ) - # @torch.compiler.disable() - # def run_attention( - # self, - # qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim) - # *, - # B: int, - # L: int, - # M: int, - # cu_seqlens: torch.Tensor, - # max_seqlen_in_batch: int, - # valid_token_indices: torch.Tensor, - # ): - # with torch.autocast("cuda", enabled=False): - # out: torch.Tensor = flash_attn_varlen_qkvpacked_func( - # qkv, - # cu_seqlens=cu_seqlens, - # max_seqlen=max_seqlen_in_batch, - # dropout_p=0.0, - # 
softmax_scale=self.softmax_scale, - # ) # (total, local_heads, head_dim) - # out = out.view(total, local_dim) + return qkv - # x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype) - # assert x.size() == (B, N, local_dim) - # assert y.size() == (B, L, local_dim) + @torch.compiler.disable() + def run_attention( + self, + qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim) + *, + B: int, + L: int, + M: int, + cu_seqlens: torch.Tensor = None, + max_seqlen_in_batch: int = None, + valid_token_indices: torch.Tensor = None, + ): + breakpoint() + N = M + local_heads = self.num_heads + local_dim = local_heads * self.head_dim + # with torch.autocast("cuda", enabled=False): + # out: torch.Tensor = flash_attn_varlen_qkvpacked_func( + # qkv, + # cu_seqlens=cu_seqlens, + # max_seqlen=max_seqlen_in_batch, + # dropout_p=0.0, + # softmax_scale=self.softmax_scale, + # ) # (total, local_heads, head_dim) + # out = out.view(total, local_dim) - # x = x.view(B, N, local_heads, self.head_dim) - # x = self.proj_x(x) # (B, M, dim_x) + q, k, v = qkv.unbind(1) + out = F.scaled_dot_product_attention(q, k, v) - # y = self.proj_y(y) # (B, L, dim_y) - # return x, y + # x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype) + x, y = out.split_with_sizes((N, L), dim=0) + # assert x.size() == (B, N, local_dim) + # assert y.size() == (B, L, local_dim) - # def forward( - # self, - # x: torch.Tensor, # (B, N, dim_x) - # y: torch.Tensor, # (B, L, dim_y) - # *, - # scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. - # scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. - # packed_indices: Dict[str, torch.Tensor] = None, - # **rope_rotation, - # ) -> Tuple[torch.Tensor, torch.Tensor]: - # """Forward pass of asymmetric multi-modal attention. + x = x.view(B, -1, local_heads, self.head_dim).flatten(2, 3) + x = self.proj_x(x) # (B, M, dim_x) - # Args: - # x: (B, N, dim_x) tensor for visual tokens - # y: (B, L, dim_y) tensor of text token features - # packed_indices: Dict with keys for Flash Attention - # num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens + y = y.view(B, -1, local_heads, self.head_dim).flatten(2, 3) + y = self.proj_y(y) # (B, L, dim_y) + return x, y - # Returns: - # x: (B, N, dim_x) tensor of visual tokens after multi-modal attention - # y: (B, L, dim_y) tensor of text token features after multi-modal attention - # """ - # B, L, _ = y.shape - # _, M, _ = x.shape + def forward( + self, + x: torch.Tensor, # (B, N, dim_x) + y: torch.Tensor, # (B, L, dim_y) + *, + scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. + scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. + packed_indices: Dict[str, torch.Tensor] = None, + **rope_rotation, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass of asymmetric multi-modal attention. - # # Predict a packed QKV tensor from visual and text features. - # # Don't checkpoint the all_to_all. - # qkv = self.prepare_qkv( - # x=x, - # y=y, - # scale_x=scale_x, - # scale_y=scale_y, - # rope_cos=rope_rotation.get("rope_cos"), - # rope_sin=rope_rotation.get("rope_sin"), - # valid_token_indices=packed_indices["valid_token_indices_kv"], - # ) # (total <= B * (N + L), 3, local_heads, head_dim) + Args: + x: (B, N, dim_x) tensor for visual tokens + y: (B, L, dim_y) tensor of text token features + packed_indices: Dict with keys for Flash Attention + num_frames: Number of frames in the video. 
N = num_frames * num_spatial_tokens - # x, y = self.run_attention( - # qkv, - # B=B, - # L=L, - # M=M, - # cu_seqlens=packed_indices["cu_seqlens_kv"], - # max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"], - # valid_token_indices=packed_indices["valid_token_indices_kv"], - # ) - # return x, y + Returns: + x: (B, N, dim_x) tensor of visual tokens after multi-modal attention + y: (B, L, dim_y) tensor of text token features after multi-modal attention + """ + B, L, _ = y.shape + _, M, _ = x.shape + + # Predict a packed QKV tensor from visual and text features. + # Don't checkpoint the all_to_all. + qkv = self.prepare_qkv( + x=x, + y=y, + scale_x=scale_x, + scale_y=scale_y, + rope_cos=rope_rotation.get("rope_cos"), + rope_sin=rope_rotation.get("rope_sin"), + # valid_token_indices=packed_indices["valid_token_indices_kv"], + ) # (total <= B * (N + L), 3, local_heads, head_dim) + + x, y = self.run_attention( + qkv, + B=B, + L=L, + M=M, + # cu_seqlens=packed_indices["cu_seqlens_kv"], + # max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"], + # valid_token_indices=packed_indices["valid_token_indices_kv"], + ) + return x, y + +def apply_rotary_emb_qk_real( + xqk: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers. + + Args: + xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D) + Can be either just query or just key, or both stacked along some batch or * dim. + freqs_cos (torch.Tensor): Precomputed cosine frequency tensor. + freqs_sin (torch.Tensor): Precomputed sine frequency tensor. + + Returns: + torch.Tensor: The input tensor with rotary embeddings applied. + """ + # assert xqk.dtype == torch.bfloat16 + # Split the last dimension into even and odd parts + xqk_even = xqk[..., 0::2] + xqk_odd = xqk[..., 1::2] + + # Apply rotation + cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk) + sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk) + + # Interleave the results back into the original shape + out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2) + # assert out.dtype == torch.bfloat16 + return out + +class PadSplitXY(torch.autograd.Function): + """ + Merge heads, pad and extract visual and text tokens, + and split along the sequence length. + """ + + @staticmethod + def forward( + ctx, + xy: torch.Tensor, + indices: torch.Tensor, + B: int, + N: int, + L: int, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim). + indices: Valid token indices out of unpacked tensor. Shape: (total,) + + Returns: + x: Visual tokens. Shape: (B, N, num_heads * head_dim). + y: Text tokens. Shape: (B, L, num_heads * head_dim). + """ + ctx.save_for_backward(indices) + ctx.B, ctx.N, ctx.L = B, N, L + D = xy.size(1) + + # Pad sequences to (B, N + L, dim). + assert indices.ndim == 1 + output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype) + indices = indices.unsqueeze(1).expand( + -1, D + ) # (total,) -> (total, num_heads * head_dim) + output.scatter_(0, indices, xy) + xy = output.view(B, N + L, D) + + # Split visual and text tokens along the sequence length. 
+ return torch.tensor_split(xy, (N,), dim=1) + + +def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]: + return PadSplitXY.apply(xy, indices, B, N, L, dtype) + +class UnifyStreams(torch.autograd.Function): + """Unify visual and text streams.""" + + @staticmethod + def forward( + ctx, + q_x: torch.Tensor, + k_x: torch.Tensor, + v_x: torch.Tensor, + q_y: torch.Tensor, + k_y: torch.Tensor, + v_y: torch.Tensor, + indices: torch.Tensor, + ): + """ + Args: + q_x: (B, N, num_heads, head_dim) + k_x: (B, N, num_heads, head_dim) + v_x: (B, N, num_heads, head_dim) + q_y: (B, L, num_heads, head_dim) + k_y: (B, L, num_heads, head_dim) + v_y: (B, L, num_heads, head_dim) + indices: (total <= B * (N + L)) + + Returns: + qkv: (total <= B * (N + L), 3, num_heads, head_dim) + """ + ctx.save_for_backward(indices) + B, N, num_heads, head_dim = q_x.size() + ctx.B, ctx.N, ctx.L = B, N, q_y.size(1) + D = num_heads * head_dim + + q = torch.cat([q_x, q_y], dim=1) + k = torch.cat([k_x, k_y], dim=1) + v = torch.cat([v_x, v_y], dim=1) + qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D) + + # indices = indices[:, None, None].expand(-1, 3, D) + # qkv = torch.gather(qkv, 0, indices) # (total, 3, num_heads * head_dim) + return qkv.unflatten(2, (num_heads, head_dim)) + + +def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor: + return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices) class FinalLayer(nn.Module): @@ -837,6 +977,7 @@ class MochiTransformer3DModel(nn.Module): t5_mask: torch.Tensor, ): """Prepare input and conditioning embeddings.""" + breakpoint() with torch.profiler.record_function("x_emb_pe"): # Visual patch embeddings with positional encoding. @@ -901,6 +1042,7 @@ class MochiTransformer3DModel(nn.Module): x, c, y_feat, rope_cos, rope_sin = self.prepare(x, sigma, y_feat[0], y_mask[0]) del y_mask + breakpoint() for i, block in enumerate(self.blocks): x, y_feat = block( From 46f95d5cdbb17e15f454d7a959c8244b30ddcb7e Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 13:49:12 +0200 Subject: [PATCH 13/19] make style --- src/diffusers/models/attention_processor.py | 18 ++++++---- .../models/transformers/transformer_mochi.py | 36 +++++++++++++++---- .../transformer_mochi_original.py | 29 ++++++++------- 3 files changed, 54 insertions(+), 29 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index dd61a6ab2c..c17556463c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3096,9 +3096,6 @@ class MochiAttnProcessor2_0: attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - breakpoint() - batch_size = hidden_states.size(0) - query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) @@ -3124,8 +3121,9 @@ class MochiAttnProcessor2_0: encoder_query = attn.norm_added_q(encoder_query) if attn.norm_added_k is not None: encoder_key = attn.norm_added_k(encoder_key) - + if image_rotary_emb is not None: + def apply_rotary_emb(x, freqs_cos, freqs_sin): x_even = x[..., 0::2].float() x_odd = x[..., 1::2].float() @@ -3137,9 +3135,13 @@ class MochiAttnProcessor2_0: query = apply_rotary_emb(query, *image_rotary_emb) key = apply_rotary_emb(key, *image_rotary_emb) - + query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) - encoder_query, encoder_key, encoder_value = encoder_query.transpose(1, 2), 
encoder_key.transpose(1, 2), encoder_value.transpose(1, 2) + encoder_query, encoder_key, encoder_value = ( + encoder_query.transpose(1, 2), + encoder_key.transpose(1, 2), + encoder_value.transpose(1, 2), + ) sequence_length = query.size(2) encoder_sequence_length = encoder_query.size(2) @@ -3152,7 +3154,9 @@ class MochiAttnProcessor2_0: hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.to(query.dtype) - hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length), dim=1) + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( + (sequence_length, encoder_sequence_length), dim=1 + ) # linear proj hidden_states = attn.to_out[0](hidden_states) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 7938d5e39f..3b6c0decbe 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -145,16 +145,23 @@ class MochiTransformerBlock(nn.Module): class MochiRoPE(nn.Module): def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None: super().__init__() - + self.target_area = base_height * base_width - + def _centers(self, start, stop, num, device, dtype) -> torch.Tensor: edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype) return (edges[:-1] + edges[1:]) / 2 - - def _get_positions(self, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + + def _get_positions( + self, + num_frames: int, + height: int, + width: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: scale = (self.target_area / (height * width)) ** 0.5 - + t = torch.arange(num_frames, device=device, dtype=dtype) h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype) w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype) @@ -170,7 +177,15 @@ class MochiRoPE(nn.Module): freqs_sin = torch.sin(freqs) return freqs_cos, freqs_sin - def forward(self, pos_frequencies: torch.Tensor, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, + pos_frequencies: torch.Tensor, + num_frames: int, + height: int, + width: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: pos = self._get_positions(num_frames, height, width, device, dtype) rope_cos, rope_sin = self._create_rope(pos_frequencies, pos) return rope_cos, rope_sin @@ -261,7 +276,14 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin): hidden_states = self.patch_embed(hidden_states) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) - image_rotary_emb = self.rope(self.pos_frequencies, num_frames, post_patch_height, post_patch_width, device=hidden_states.device, dtype=torch.float32) + image_rotary_emb = self.rope( + self.pos_frequencies, + num_frames, + post_patch_height, + post_patch_width, + device=hidden_states.device, + dtype=torch.float32, + ) for i, block in enumerate(self.transformer_blocks): hidden_states, encoder_hidden_states = block( diff --git a/src/diffusers/models/transformers/transformer_mochi_original.py b/src/diffusers/models/transformers/transformer_mochi_original.py 
index 2dad5c5c86..a428e57a3b 100644 --- a/src/diffusers/models/transformers/transformer_mochi_original.py +++ b/src/diffusers/models/transformers/transformer_mochi_original.py @@ -96,8 +96,7 @@ def compute_mixed_rotation( num_heads: int Returns: - freqs_cos: [N, num_heads, num_freqs] - cosine components - freqs_sin: [N, num_heads, num_freqs] - sine components + freqs_cos: [N, num_heads, num_freqs] - cosine components freqs_sin: [N, num_heads, num_freqs] - sine components """ with torch.autocast("cuda", enabled=False): assert freqs.ndim == 3 @@ -470,8 +469,7 @@ class AsymmetricJointBlock(nn.Module): num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens Returns: - x: (B, N, dim) tensor of visual tokens after block - y: (B, L, dim) tensor of text tokens after block + x: (B, N, dim) tensor of visual tokens after block y: (B, L, dim) tensor of text tokens after block """ breakpoint() N = x.size(1) @@ -651,7 +649,7 @@ class AsymmetricAttention(nn.Module): breakpoint() N = M local_heads = self.num_heads - local_dim = local_heads * self.head_dim + # local_dim = local_heads * self.head_dim # with torch.autocast("cuda", enabled=False): # out: torch.Tensor = flash_attn_varlen_qkvpacked_func( # qkv, @@ -696,8 +694,8 @@ class AsymmetricAttention(nn.Module): num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens Returns: - x: (B, N, dim_x) tensor of visual tokens after multi-modal attention - y: (B, L, dim_y) tensor of text token features after multi-modal attention + x: (B, N, dim_x) tensor of visual tokens after multi-modal attention y: (B, L, dim_y) tensor of text token + features after multi-modal attention """ B, L, _ = y.shape _, M, _ = x.shape @@ -725,6 +723,7 @@ class AsymmetricAttention(nn.Module): ) return x, y + def apply_rotary_emb_qk_real( xqk: torch.Tensor, freqs_cos: torch.Tensor, @@ -756,10 +755,10 @@ def apply_rotary_emb_qk_real( # assert out.dtype == torch.bfloat16 return out + class PadSplitXY(torch.autograd.Function): """ - Merge heads, pad and extract visual and text tokens, - and split along the sequence length. + Merge heads, pad and extract visual and text tokens, and split along the sequence length. """ @staticmethod @@ -778,8 +777,7 @@ class PadSplitXY(torch.autograd.Function): indices: Valid token indices out of unpacked tensor. Shape: (total,) Returns: - x: Visual tokens. Shape: (B, N, num_heads * head_dim). - y: Text tokens. Shape: (B, L, num_heads * head_dim). + x: Visual tokens. Shape: (B, N, num_heads * head_dim). y: Text tokens. Shape: (B, L, num_heads * head_dim). """ ctx.save_for_backward(indices) ctx.B, ctx.N, ctx.L = B, N, L @@ -788,9 +786,7 @@ class PadSplitXY(torch.autograd.Function): # Pad sequences to (B, N + L, dim). 
assert indices.ndim == 1 output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype) - indices = indices.unsqueeze(1).expand( - -1, D - ) # (total,) -> (total, num_heads * head_dim) + indices = indices.unsqueeze(1).expand(-1, D) # (total,) -> (total, num_heads * head_dim) output.scatter_(0, indices, xy) xy = output.view(B, N + L, D) @@ -801,6 +797,7 @@ class PadSplitXY(torch.autograd.Function): def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]: return PadSplitXY.apply(xy, indices, B, N, L, dtype) + class UnifyStreams(torch.autograd.Function): """Unify visual and text streams.""" @@ -1034,7 +1031,9 @@ class MochiTransformer3DModel(nn.Module): Args: x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images) sigma: (B,) tensor of noise standard deviations - y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048) + y_feat: + List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, + y_feat_dim=2048) y_mask: List((B, L) boolean tensor indicating which tokens are not padding) packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices. """ From 275041d21e6bf786708eae66dc2aad1e7758e499 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 24 Oct 2024 14:26:23 +0200 Subject: [PATCH 14/19] update --- .../pipelines/mochi/pipeline_mochi.py | 41 ++++++++++++++++--- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index f9a026440c..dcfed214c5 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -139,6 +139,7 @@ def retrieve_timesteps( class MochiPipeline( DiffusionPipeline, + TextualInversionLoaderMixin ): r""" The Flux pipeline for text-to-image generation. 
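For reference, a minimal self-contained sketch of the scatter-based unpadding used by PadSplitXY/pad_and_split_xy above, with small hypothetical sizes (B, N, L, D are illustrative, not model values):

import torch
import torch.nn.functional as F

B, N, L, D = 2, 3, 2, 4                                        # hypothetical batch/visual/text/feature sizes
text_mask = torch.tensor([[True, True], [True, False]])        # (B, L) validity of text tokens
full_mask = F.pad(text_mask, (N, 0), value=True)               # visual tokens are always valid -> (B, N + L)
indices = torch.nonzero(full_mask.flatten(), as_tuple=False).flatten()  # flat positions of valid tokens

xy = torch.randn(indices.numel(), D)                           # packed tokens, shape (total, D)

# Scatter the packed tokens back into the padded (B, N + L, D) layout, then split the streams.
output = torch.zeros(B * (N + L), D, dtype=xy.dtype)
output.scatter_(0, indices.unsqueeze(1).expand(-1, D), xy)
x, y = torch.tensor_split(output.view(B, N + L, D), (N,), dim=1)
print(x.shape, y.shape)  # torch.Size([2, 3, 4]) torch.Size([2, 2, 4])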
@@ -187,14 +188,17 @@ class MochiPipeline( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 - ) + #TODO: determine these scaling factors from model parameters + self.vae_spatial_scale_factor = 8 + self.vae_temporal_scale_factor = 6 + self.patch_size = 2 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_height = 64 + self.default_width = 64 def _get_t5_prompt_embeds( self, @@ -235,7 +239,7 @@ class MochiPipeline( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False)[0] + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False).last_hidden_state dtype = self.text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -246,7 +250,32 @@ class MochiPipeline( prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - return prompt_embeds + return prompt_embeds, prompt_attention_mask + + def _pack_indices(self, attention_mask, latent_frames_dim, latent_height_dim, latent_width_dim): + N = latent_frames_dim * latent_height_dim * latent_width_dim // (self.patch_size**2) + + # Create an expanded token mask saying which tokens are valid across both visual and text tokens. + assert N > 0 and len(attention_mask) == 1 + attention_mask = attention_mask[0] + + mask = F.pad(attention_mask, (N, 0), value=True) # (B, N + L) + seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32) # (B,) + valid_token_indices = torch.nonzero( + mask.flatten(), as_tuple=False + ).flatten() # up to (B * (N + L),) + + assert valid_token_indices.size(0) >= attention_mask.size(0) * N # At least (B * N,) + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + max_seqlen_in_batch = seqlens_in_batch.max().item() + + return { + "cu_seqlens_kv": cu_seqlens, + "max_seqlen_in_batch_kv": max_seqlen_in_batch, + "valid_token_indices_kv": valid_token_indices, + } def encode_prompt( self, From 44987ad98cd92a2d91a8cb8dba8d7503c57711d7 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 24 Oct 2024 16:31:10 +0200 Subject: [PATCH 15/19] update --- .../pipelines/mochi/pipeline_mochi.py | 175 ++++++++---------- 1 file changed, 74 insertions(+), 101 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index dcfed214c5..3d140b8864 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -15,22 +15,18 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch from transformers import T5EncoderModel, T5TokenizerFast from ...image_processor import VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.transformers import MochiTransformer3D +from ...models.transformers import MochiTransformer3DModel from 
...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -53,13 +49,13 @@ EXAMPLE_DOC_STRING = """ >>> import torch >>> from diffusers import MochiPipeline - >>> pipe = MochiPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe = MochiPipeline.from_pretrained("black-forest-labs/mochi.1-schnell", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A cat holding a sign that says hello world" >>> # Depending on the variant being used, the pipeline call will slightly vary. >>> # Refer to the pipeline documentation for more details. >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] - >>> image.save("flux.png") + >>> image.save("mochi.png") ``` """ @@ -77,6 +73,24 @@ def calculate_shift( return mu +# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 +def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + return sigma_schedule + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -137,17 +151,14 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class MochiPipeline( - DiffusionPipeline, - TextualInversionLoaderMixin -): +class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" - The Flux pipeline for text-to-image generation. + The mochi pipeline for text-to-image generation. Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ Args: - transformer ([`FluxTransformer2DModel`]): + transformer ([`mochiTransformer2DModel`]): Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 
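As a quick sanity check on the linear_quadratic_schedule helper added above (assuming it is in scope), the returned schedule runs from 1.0 down to 0.0 with num_steps + 1 entries; the step count and threshold below are illustrative, not tuned settings:

sigmas = linear_quadratic_schedule(8, 0.025)
print(len(sigmas))            # 9: eight steps plus the terminal 0.0
print(sigmas[0], sigmas[-1])  # 1.0 and 0.0; values decrease monotonically in between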
@@ -177,7 +188,7 @@ class MochiPipeline( vae: AutoencoderKL, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, - transformer: MochiTransformer3D, + transformer: MochiTransformer3DModel, ): super().__init__() @@ -188,22 +199,22 @@ class MochiPipeline( transformer=transformer, scheduler=scheduler, ) - #TODO: determine these scaling factors from model parameters + # TODO: determine these scaling factors from model parameters self.vae_spatial_scale_factor = 8 self.vae_temporal_scale_factor = 6 self.patch_size = 2 - + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_height = 64 - self.default_width = 64 + self.default_height = 480 + self.default_width = 848 def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, + num_videos_per_prompt: int = 1, max_sequence_length: int = 512, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, @@ -227,10 +238,8 @@ class MochiPipeline( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - prompt_attention_mask = text_inputs.attention_mask - prompt_attention_mask = prompt_attention_mask.to(device) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) @@ -239,7 +248,9 @@ class MochiPipeline( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False).last_hidden_state + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device), output_hidden_states=False + ).last_hidden_state dtype = self.text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -247,43 +258,17 @@ class MochiPipeline( _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) return prompt_embeds, prompt_attention_mask - def _pack_indices(self, attention_mask, latent_frames_dim, latent_height_dim, latent_width_dim): - N = latent_frames_dim * latent_height_dim * latent_width_dim // (self.patch_size**2) - - # Create an expanded token mask saying which tokens are valid across both visual and text tokens. 
- assert N > 0 and len(attention_mask) == 1 - attention_mask = attention_mask[0] - - mask = F.pad(attention_mask, (N, 0), value=True) # (B, N + L) - seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32) # (B,) - valid_token_indices = torch.nonzero( - mask.flatten(), as_tuple=False - ).flatten() # up to (B * (N + L),) - - assert valid_token_indices.size(0) >= attention_mask.size(0) * N # At least (B * N,) - cu_seqlens = F.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) - ) - max_seqlen_in_batch = seqlens_in_batch.max().item() - - return { - "cu_seqlens_kv": cu_seqlens, - "max_seqlen_in_batch_kv": max_seqlen_in_batch, - "valid_token_indices_kv": valid_token_indices, - } - def encode_prompt( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, + num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 512, lora_scale: Optional[float] = None, ): @@ -297,7 +282,7 @@ class MochiPipeline( used in all text-encoders device: (`torch.device`): torch device - num_images_per_prompt (`int`): + num_videos_per_prompt (`int`): number of images that should be generated per prompt prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not @@ -308,30 +293,29 @@ class MochiPipeline( lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ - device = device ori self._execution_device + device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt if prompt_embeds is None: - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, device=device, ) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + # TODO: Add negative prompts back return prompt_embeds def check_inputs( self, prompt, - prompt_2, height, width, prompt_embeds=None, - pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): @@ -350,25 +334,12 @@ class MochiPipeline( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") @@ -407,11 +378,16 @@ class MochiPipeline( num_channels_latents, height, width, + num_frames, dtype, device, generator, latents=None, ): + height = height // self.vae_spatial_scale_factor + width = width // self.vae_spatial_scale_factor + num_frames = (num_frames - 1) // (self.vae_temporal_scale_factor + 1) + shape = (batch_size, num_channels_latents, num_frames, height, width) if latents is not None: @@ -429,6 +405,10 @@ class MochiPipeline( def guidance_scale(self): return self._guidance_scale + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + @property def joint_attention_kwargs(self): return self._joint_attention_kwargs @@ -470,13 +450,12 @@ class MochiPipeline( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer` and `text_encoder`. If not defined, `prompt` is - will be used instead height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to 16): + The number of video frames to generate num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -490,7 +469,7 @@ class MochiPipeline( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): + num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) @@ -509,7 +488,7 @@ class MochiPipeline( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.flux.MochiPipelineOutput`] instead of a plain tuple. 
+ Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -528,13 +507,12 @@ class MochiPipeline( Examples: Returns: - [`~pipelines.flux.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ - - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor + height = height or self.default_height + width = width or self.default_width # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -542,7 +520,6 @@ class MochiPipeline( height, width, prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) @@ -564,22 +541,18 @@ class MochiPipeline( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) - ( - prompt_embeds, - pooled_prompt_embeds, - text_ids, - ) = self.encode_prompt( + (prompt_embeds) = self.encode_prompt( prompt=prompt, prompt_embeds=prompt_embeds, device=device, - num_images_per_prompt=num_images_per_prompt, + num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, lora_scale=lora_scale, ) # 4. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - latents, latent_image_ids = self.prepare_latents( + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -591,8 +564,12 @@ class MochiPipeline( latents, ) - # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + # 5. 
Prepare timestep + + # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 + threshold_noise = 0.025 + sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) + image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, @@ -624,18 +601,14 @@ class MochiPipeline( for i, t in enumerate(timesteps): if self.interrupt: continue - + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, + hidden_states=latent_model_input, + timestep=timestep, encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] From ebcbad2f380d888ae132a01cfaa7ca5d52be675b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 24 Oct 2024 18:07:43 +0200 Subject: [PATCH 16/19] update --- .../pipelines/mochi/pipeline_mochi.py | 51 +++++++++++++------ 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 3d140b8864..c8b8c0af98 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -261,15 +261,18 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - return prompt_embeds, prompt_attention_mask + return prompt_embeds def encode_prompt( self, prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, device: Optional[torch.device] = None, num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 512, + do_classifier_free_guidance=True, lora_scale: Optional[float] = None, ): r""" @@ -277,9 +280,6 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer` and `text_encoder`. If not defined, `prompt` is - used in all text-encoders device: (`torch.device`): torch device num_videos_per_prompt (`int`): @@ -287,14 +287,15 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_embeds = self._get_t5_prompt_embeds( @@ -307,8 +308,32 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype # TODO: Add negative prompts back + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + ) - return prompt_embeds + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + return prompt_embeds, negative_prompt_embeds def check_inputs( self, @@ -541,7 +566,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) - (prompt_embeds) = self.encode_prompt( + (prompt_embeds, negative_prompt_embeds) = self.encode_prompt( prompt=prompt, prompt_embeds=prompt_embeds, device=device, @@ -589,12 +614,8 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: From 8700d64d62e3b66b5bdf0bbbd12ef803575c5656 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 24 Oct 2024 22:44:54 +0200 Subject: [PATCH 17/19] update --- .../pipelines/mochi/pipeline_mochi.py | 42 ++++++++----------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index c8b8c0af98..e2b3196a00 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -104,6 +104,7 @@ def retrieve_timesteps( Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + Args: scheduler (`SchedulerMixin`): The scheduler to get timesteps from. 
@@ -272,8 +273,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 512, - do_classifier_free_guidance=True, - lora_scale: Optional[float] = None, + do_classifier_free_guidance: bool = True, ): r""" @@ -305,14 +305,12 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): device=device, ) - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + prompt_embeds = prompt_embeds.to(self.text_encoder.dtype) - # TODO: Add negative prompts back if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - ) if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( @@ -332,6 +330,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): max_sequence_length=max_sequence_length, device=device, ) + negative_prompt_embeds = negative_prompt_embeds.to(self.text_encoder.dtype) return prompt_embeds, negative_prompt_embeds @@ -532,9 +531,9 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): Examples: Returns: - [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated - images. + [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. 
""" height = height or self.default_height width = width or self.default_width @@ -595,21 +594,12 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): threshold_noise = 0.025 sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) - image_seq_len = latents.shape[1] - mu = calculate_shift( - image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, - ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas, - mu=mu, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -628,12 +618,16 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): noise_pred = self.transformer( hidden_states=latent_model_input, - timestep=timestep, + timestep=timestep / 1000, encoder_hidden_states=prompt_embeds, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] @@ -660,18 +654,16 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): xm.mark_step() if output_type == "latent": - image = latents + video = latents else: - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - image = self.vae.decode(latents, return_dict=False)[0] - image = self.image_processor.postprocess(image, output_type=output_type) + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess(video, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: - return (image,) + return (video,) - return MochiPipelineOutput(images=image) + return MochiPipelineOutput(frames=video) From 969c3aba883d01a8f08133f4f74e540897b1012c Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 24 Oct 2024 22:50:54 +0200 Subject: [PATCH 18/19] update --- src/diffusers/pipelines/mochi/pipeline_mochi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index e2b3196a00..233b1a892d 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -18,7 +18,6 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import T5EncoderModel, T5TokenizerFast -from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import MochiTransformer3DModel @@ -29,6 +28,7 @@ from ...utils import ( replace_example_docstring, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import MochiPipelineOutput @@ -205,7 +205,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): self.vae_temporal_scale_factor = 6 self.patch_size = 2 - self.image_processor = 
VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) From 6552653f11391775c73af9b78beb91ef98c9ab63 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 25 Oct 2024 08:31:59 +0200 Subject: [PATCH 19/19] update --- src/diffusers/pipelines/mochi/pipeline_mochi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 233b1a892d..72068f938c 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -618,7 +618,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin): noise_pred = self.transformer( hidden_states=latent_model_input, - timestep=timestep / 1000, + timestep=timestep, encoder_hidden_states=prompt_embeds, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False,
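As a closing note on shapes: with the scale factors set in __init__ earlier in this series (vae_spatial_scale_factor = 8, vae_temporal_scale_factor = 6) and the default 480x848 resolution, prepare_latents implies the latent grid below; the pixel-space frame count is illustrative, not an asserted model setting:

height, width, num_frames = 480, 848, 85                             # requested pixel-space size (num_frames hypothetical)
vae_spatial_scale_factor, vae_temporal_scale_factor = 8, 6

latent_height = height // vae_spatial_scale_factor                   # 60
latent_width = width // vae_spatial_scale_factor                     # 106
latent_frames = (num_frames - 1) // (vae_temporal_scale_factor + 1)  # 12
print(latent_frames, latent_height, latent_width)                    # 12 60 106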