diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py new file mode 100644 index 0000000000..1d1d10a6ad --- /dev/null +++ b/scripts/convert_mochi_to_diffusers.py @@ -0,0 +1,187 @@ +import argparse +from contextlib import nullcontext + +import torch +from accelerate import init_empty_weights +from safetensors.torch import load_file + +# from transformers import T5EncoderModel, T5Tokenizer +from diffusers import MochiTransformer3DModel +from diffusers.utils.import_utils import is_accelerate_available + + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +TOKENIZER_MAX_LENGTH = 256 + +parser = argparse.ArgumentParser() +parser.add_argument("--transformer_checkpoint_path", default=None, type=str) +# parser.add_argument("--vae_checkpoint_path", default=None, type=str) +parser.add_argument("--output_path", required=True, type=str) +parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving") +parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory") +parser.add_argument("--dtype", type=str, default=None) + +args = parser.parse_args() + + +# This is specific to `AdaLayerNormContinuous`: +# Diffusers implementation split the linear projection into the scale, shift while Mochi split it into shift, scale +def swap_scale_shift(weight, dim): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path): + original_state_dict = load_file(ckpt_path, device="cpu") + new_state_dict = {} + + # Convert patch_embed + new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight") + new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias") + + # Convert time_embed + new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight") + new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias") + new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight") + new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias") + new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight") + new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias") + new_state_dict["time_embed.pooler.to_q.weight"] = original_state_dict.pop("t5_y_embedder.to_q.weight") + new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias") + new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight") + new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias") + new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight") + new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias") + + # Convert transformer blocks + num_layers = 48 + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + old_prefix = f"blocks.{i}." 
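+        # Key mapping (see the assignments below): the original Mochi checkpoint stores block weights under
+        # `blocks.{i}.*` with fused `qkv_x`/`qkv_y` projections, while diffusers expects `transformer_blocks.{i}.*`
+        # with separate `to_q`/`to_k`/`to_v` (and `add_q_proj`/`add_k_proj`/`add_v_proj` for the text stream).
+        # The final block is handled specially: its `mod_y` maps to a single `norm1_context.linear_1`, and the
+        # `attn.proj_y` / `mlp_y` entries are only converted for the earlier blocks.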
+ + # norm1 + new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight") + new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias") + if i < num_layers - 1: + new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop( + old_prefix + "mod_y.weight" + ) + new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop( + old_prefix + "mod_y.bias" + ) + else: + new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop( + old_prefix + "mod_y.weight" + ) + new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop( + old_prefix + "mod_y.bias" + ) + + # Visual attention + qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.to_q.weight"] = q + new_state_dict[block_prefix + "attn1.to_k.weight"] = k + new_state_dict[block_prefix + "attn1.to_v.weight"] = v + new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop( + old_prefix + "attn.q_norm_x.weight" + ) + new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop( + old_prefix + "attn.k_norm_x.weight" + ) + new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop( + old_prefix + "attn.proj_x.weight" + ) + new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias") + + # Context attention + qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q + new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k + new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v + new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = original_state_dict.pop( + old_prefix + "attn.q_norm_y.weight" + ) + new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop( + old_prefix + "attn.k_norm_y.weight" + ) + if i < num_layers - 1: + new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop( + old_prefix + "attn.proj_y.weight" + ) + new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop( + old_prefix + "attn.proj_y.bias" + ) + + # MLP + new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w1.weight") + new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight") + if i < num_layers - 1: + new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = original_state_dict.pop( + old_prefix + "mlp_y.w1.weight" + ) + new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop( + old_prefix + "mlp_y.w2.weight" + ) + + # Output layers + new_state_dict["norm_out.linear.weight"] = original_state_dict.pop("final_layer.mod.weight") + new_state_dict["norm_out.linear.bias"] = original_state_dict.pop("final_layer.mod.bias") + new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") + new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") + + new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies") + + print("Remaining Keys:", original_state_dict.keys()) + + return new_state_dict + + +# def convert_mochi_vae_checkpoint_to_diffusers(ckpt_path, vae_config): +# 
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+#     return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
+
+
+def main(args):
+    if args.dtype is None:
+        dtype = None
+    elif args.dtype == "fp16":
+        dtype = torch.float16
+    elif args.dtype == "bf16":
+        dtype = torch.bfloat16
+    elif args.dtype == "fp32":
+        dtype = torch.float32
+    else:
+        raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+    transformer = None
+    # vae = None
+
+    if args.transformer_checkpoint_path is not None:
+        converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
+            args.transformer_checkpoint_path
+        )
+        transformer = MochiTransformer3DModel()
+        transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+        if dtype is not None:
+            # Cast to the requested dtype; otherwise the original checkpoint dtype is preserved
+            transformer = transformer.to(dtype=dtype)
+
+    # text_encoder_id = "google/t5-v1_1-xxl"
+    # tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+    # text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+    # # Apparently, the conversion does not work anymore without this :shrug:
+    # for param in text_encoder.parameters():
+    #     param.data = param.data.contiguous()
+
+    transformer.save_pretrained(args.output_path, subfolder="transformer", push_to_hub=args.push_to_hub)
+
+
+if __name__ == "__main__":
+    main(args)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 789458a262..c71cbfd5a4 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -100,6 +100,7 @@ else:
         "Kandinsky3UNet",
         "LatteTransformer3DModel",
         "LuminaNextDiT2DModel",
+        "MochiTransformer3DModel",
         "ModelMixin",
         "MotionAdapter",
         "MultiAdapter",
@@ -579,6 +580,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         Kandinsky3UNet,
         LatteTransformer3DModel,
         LuminaNextDiT2DModel,
+        MochiTransformer3DModel,
         ModelMixin,
         MotionAdapter,
         MultiAdapter,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 4dda8c36ba..27177b2adc 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -56,6 +56,7 @@ if is_torch_available():
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
+    _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -106,6 +107,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         HunyuanDiT2DModel,
         LatteTransformer3DModel,
         LuminaNextDiT2DModel,
+        MochiTransformer3DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
         SD3Transformer2DModel,
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index e735c4ee7d..c17556463c 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -120,6 +120,7 @@ class Attention(nn.Module):
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
         out_dim: int = None,
+        out_context_dim: int = None,
         context_pre_only=None,
         pre_only=False,
         elementwise_affine: bool = True,
@@ -142,6 +143,7 @@ class Attention(nn.Module):
         self.dropout = dropout
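+        # `out_context_dim` (stored below) lets the added `to_add_out` projection map the context (text) stream
+        # back to its own width instead of defaulting to `query_dim`; Mochi needs this because its text tokens
+        # (pooled_projection_dim=1536) are narrower than its visual tokens (24 * 128 = 3072).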
self.fused_projections = False self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim self.context_pre_only = context_pre_only self.pre_only = pre_only @@ -241,7 +243,7 @@ class Attention(nn.Module): self.to_out.append(nn.Dropout(dropout)) if self.context_pre_only is not None and not self.context_pre_only: - self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) if qk_norm is not None and added_kv_proj_dim is not None: if qk_norm == "fp32_layer_norm": @@ -1792,6 +1794,7 @@ class FluxAttnProcessor2_0: hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states @@ -3078,6 +3081,93 @@ class LuminaAttnProcessor2_0: return hidden_states +class MochiAttnProcessor2_0: + """Attention processor used in Mochi.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + if image_rotary_emb is not None: + + def apply_rotary_emb(x, freqs_cos, freqs_sin): + x_even = x[..., 0::2].float() + x_odd = x[..., 1::2].float() + + cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) + sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) + + return torch.stack([cos, sin], dim=-1).flatten(-2) + + query = apply_rotary_emb(query, *image_rotary_emb) + key = apply_rotary_emb(key, *image_rotary_emb) + + query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) + encoder_query, encoder_key, encoder_value = ( + encoder_query.transpose(1, 2), + encoder_key.transpose(1, 2), + encoder_value.transpose(1, 2), + ) + + sequence_length = query.size(2) + encoder_sequence_length = encoder_query.size(2) + + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states, encoder_hidden_states = 
hidden_states.split_with_sizes( + (sequence_length, encoder_sequence_length), dim=1 + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + class FusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 4ccddbbaf4..3cf808430c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1302,6 +1302,41 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module): return conditioning +class MochiCombinedTimestepCaptionEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + pooled_projection_dim: int, + text_embed_dim: int, + time_embed_dim: int = 256, + num_attention_heads: int = 8, + ) -> None: + super().__init__() + + self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0) + self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim) + self.pooler = MochiAttentionPool( + num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim + ) + self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim) + + def forward( + self, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + hidden_dtype: Optional[torch.dtype] = None, + ): + time_proj = self.time_proj(timestep) + time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype)) + + pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask) + caption_proj = self.caption_proj(encoder_hidden_states) + + conditioning = time_emb + pooled_projections + return conditioning, caption_proj + + class TextTimeEmbedding(nn.Module): def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): super().__init__() @@ -1445,7 +1480,7 @@ class MochiAttentionPool(nn.Module): self.to_kv = nn.Linear(embed_dim, 2 * embed_dim) self.to_q = nn.Linear(embed_dim, embed_dim) self.to_out = nn.Linear(embed_dim, self.output_dim) - + @staticmethod def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor: """ @@ -1504,9 +1539,7 @@ class MochiAttentionPool(nn.Module): q = q.unsqueeze(2) # (B, H, 1, head_dim) # Compute attention. - x = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, dropout_p=0.0 - ) # (B, H, 1, head_dim) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim) # Concatenate heads and run output. x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 029c147fcb..dcfaed90b3 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -237,6 +237,33 @@ class LuminaRMSNormZero(nn.Module): return x, gate_msa, scale_mlp, gate_mlp +class MochiRMSNormZero(nn.Module): + r""" + Adaptive RMS Norm used in Mochi. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
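+        hidden_dim (`int`): The output size of the modulation linear layer; it is chunked into four equal parts
+            (`scale_msa`, `gate_msa`, `scale_mlp`, `gate_mlp`) in `forward`, so Mochi passes `4 * dim` for the
+            visual stream and `4 * pooled_projection_dim` for the text stream.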
+ """ + + def __init__( + self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, hidden_dim) + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward( + self, hidden_states: torch.Tensor, emb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + + return hidden_states, gate_msa, scale_mlp, gate_mlp + + class AdaLayerNormSingle(nn.Module): r""" Norm layer adaptive layer norm single (adaLN-single). @@ -358,20 +385,21 @@ class LuminaLayerNormContinuous(nn.Module): out_dim: Optional[int] = None, ): super().__init__() + # AdaLN self.silu = nn.SiLU() self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + if norm_type == "layer_norm": self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + if norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) else: raise ValueError(f"unknown norm_type {norm_type}") - # linear_2 + + self.linear_2 = None if out_dim is not None: - self.linear_2 = nn.Linear( - embedding_dim, - out_dim, - bias=bias, - ) + self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) def forward( self, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 58787c079e..e1c2c1edf1 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -16,5 +16,6 @@ if is_torch_available(): from .transformer_2d import Transformer2DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index c56a7845cb..3b6c0decbe 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Optional, Tuple, Union +from typing import Optional, Tuple import torch import torch.nn as nn @@ -21,11 +21,12 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention, FeedForward -from ..embeddings import PatchEmbed, MochiAttentionPool, TimestepEmbedding, Timesteps +from ..attention import FeedForward +from ..attention_processor import Attention, MochiAttnProcessor2_0 +from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNorm +from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -38,36 +39,160 @@ class MochiTransformerBlock(nn.Module): dim: int, num_attention_heads: int, attention_head_dim: int, - caption_dim: int, - update_captions: bool = True, + pooled_projection_dim: int, + qk_norm: str = "rms_norm", + activation_fn: str = "swiglu", + context_pre_only: bool = True, + eps: float = 1e-6, ) -> None: super().__init__() - # TODO: Replace this with norm - self.mod_x = nn.Linear(dim, 4 * dim) - if self.update_y: - self.mod_y = nn.Linear(dim, 4 * caption_dim) + self.context_pre_only = context_pre_only + self.ff_inner_dim = (4 * dim * 2) // 3 + self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3 + + self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False) + + if not context_pre_only: + self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False) else: - self.mod_y = nn.Linear(dim, caption_dim) - - # TODO(aryan): attention class does not look compatible - self.attn1 = Attention(...) 
- # norms go in attention - # self.q_norm_x = RMSNorm(attention_head_dim) - # self.k_norm_x = RMSNorm(attention_head_dim) - # self.q_norm_y = RMSNorm(attention_head_dim) - # self.k_norm_y = RMSNorm(attention_head_dim) + self.norm1_context = LuminaLayerNormContinuous( + embedding_dim=pooled_projection_dim, + conditioning_embedding_dim=dim, + eps=eps, + elementwise_affine=False, + norm_type="rms_norm", + out_dim=None, + ) - self.proj_x = nn.Linear(dim, dim) + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=False, + qk_norm=qk_norm, + added_kv_proj_dim=pooled_projection_dim, + added_proj_bias=False, + out_dim=dim, + out_context_dim=pooled_projection_dim, + context_pre_only=context_pre_only, + processor=MochiAttnProcessor2_0(), + eps=eps, + elementwise_affine=True, + ) - self.proj_y = nn.Linear(dim, caption_dim) if update_captions else None - - def forward(self): - pass + self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + + self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + + self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) + self.ff_context = None + if not context_pre_only: + self.ff_context = FeedForward( + pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False + ) + + self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + breakpoint() + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + + if not self.context_pre_only: + norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context( + encoder_hidden_states, temb + ) + else: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + + attn_hidden_states, context_attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1) + norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + + if not self.context_pre_only: + encoder_hidden_states = encoder_hidden_states + self.norm2_context( + context_attn_hidden_states + ) * torch.tanh(enc_gate_msa).unsqueeze(1) + norm_encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * ( + 1 + enc_scale_mlp.unsqueeze(1) + ) + + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output * torch.tanh(gate_mlp).unsqueeze(1) + + if not self.context_pre_only: + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + context_ff_output * torch.tanh(enc_gate_mlp).unsqueeze(0) + + return hidden_states, encoder_hidden_states + + +class MochiRoPE(nn.Module): + def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None: + super().__init__() + + self.target_area = base_height * base_width 
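+        # `_get_positions` rescales the spatial coordinates by sqrt(target_area / (height * width)), so the grid
+        # covers roughly the same area as the base 192 x 192 grid at any resolution before the positions are
+        # mixed with the learned `pos_frequencies` in `_create_rope`.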
+ + def _centers(self, start, stop, num, device, dtype) -> torch.Tensor: + edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype) + return (edges[:-1] + edges[1:]) / 2 + + def _get_positions( + self, + num_frames: int, + height: int, + width: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + scale = (self.target_area / (height * width)) ** 0.5 + + t = torch.arange(num_frames, device=device, dtype=dtype) + h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype) + w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype) + + grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") + + positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3) + return positions + + def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: + freqs = torch.einsum("nd,dhf->nhf", pos, freqs) + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + return freqs_cos, freqs_sin + + def forward( + self, + pos_frequencies: torch.Tensor, + num_frames: int, + height: int, + width: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + pos = self._get_positions(num_frames, height, width, device, dtype) + rope_cos, rope_sin = self._create_rope(pos_frequencies, pos) + return rope_cos, rope_sin @maybe_allow_in_graph -class MochiTransformer3D(ModelMixin, ConfigMixin): +class MochiTransformer3DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config @@ -77,42 +202,105 @@ class MochiTransformer3D(ModelMixin, ConfigMixin): num_attention_heads: int = 24, attention_head_dim: int = 128, num_layers: int = 48, - caption_dim=1536, - mlp_ratio_x=4.0, - mlp_ratio_y=4.0, - in_channels=12, - qk_norm=True, - qkv_bias=False, - out_bias=True, - timestep_mlp_bias=True, - timestep_scale=1000.0, - text_embed_dim=4096, - max_sequence_length=256, + pooled_projection_dim: int = 1536, + in_channels: int = 12, + out_channels: Optional[int] = None, + qk_norm: str = "rms_norm", + text_embed_dim: int = 4096, + time_embed_dim: int = 256, + activation_fn: str = "swiglu", + max_sequence_length: int = 256, ) -> None: super().__init__() inner_dim = num_attention_heads * attention_head_dim - + out_channels = out_channels or in_channels + self.patch_embed = PatchEmbed( patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, + pos_embed_type=None, ) - self.caption_embedder = MochiAttentionPool(num_attention_heads=8, embed_dim=text_embed_dim, output_dim=inner_dim) - self.caption_proj = nn.Linear(text_embed_dim, caption_dim) - - self.pos_frequencies = nn.Parameter( - torch.empty(3, num_attention_heads, attention_head_dim // 2) + self.time_embed = MochiCombinedTimestepCaptionEmbedding( + embedding_dim=inner_dim, + pooled_projection_dim=pooled_projection_dim, + text_embed_dim=text_embed_dim, + time_embed_dim=time_embed_dim, + num_attention_heads=8, ) - self.transformer_blocks = nn.ModuleList([ - MochiTransformerBlock( - dim=inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - caption_dim=caption_dim, - update_captions=i < num_layers - 1, + self.pos_frequencies = nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2)) + self.rope = MochiRoPE() + + self.transformer_blocks = nn.ModuleList( + [ + MochiTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, 
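+                    # Only the final block is built with context_pre_only=True (set below): it still attends over
+                    # the text tokens but no longer updates them, matching the checkpoint, where the last block
+                    # carries no `proj_y` / `mlp_y` weights.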
+ pooled_projection_dim=pooled_projection_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + context_pre_only=i == num_layers - 1, + ) + for i in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm" + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + return_dict: bool = True, + ) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p = self.config.patch_size + + post_patch_height = height // p + post_patch_width = width // p + + temb, encoder_hidden_states = self.time_embed( + timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype + ) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + + image_rotary_emb = self.rope( + self.pos_frequencies, + num_frames, + post_patch_height, + post_patch_width, + device=hidden_states.device, + dtype=torch.float32, + ) + + for i, block in enumerate(self.transformer_blocks): + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, ) - for i in range(num_layers) - ]) + + # TODO(aryan): do something with self.pos_frequencies + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) + hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) + output = hidden_states.reshape(batch_size, -1, num_frames, height, width) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_mochi_original.py b/src/diffusers/models/transformers/transformer_mochi_original.py index 52bdfa0710..a428e57a3b 100644 --- a/src/diffusers/models/transformers/transformer_mochi_original.py +++ b/src/diffusers/models/transformers/transformer_mochi_original.py @@ -2,7 +2,7 @@ import collections import functools import itertools import math -from typing import Any, Callable, Dict, Optional, List +from typing import Callable, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -19,8 +19,10 @@ def _ntuple(n): return parse + to_2tuple = _ntuple(2) + def centers(start: float, stop, num, dtype=None, device=None): """linspace through bin centers. 
@@ -94,8 +96,7 @@ def compute_mixed_rotation( num_heads: int Returns: - freqs_cos: [N, num_heads, num_freqs] - cosine components - freqs_sin: [N, num_heads, num_freqs] - sine components + freqs_cos: [N, num_heads, num_freqs] - cosine components freqs_sin: [N, num_heads, num_freqs] - sine components """ with torch.autocast("cuda", enabled=False): assert freqs.ndim == 3 @@ -132,9 +133,7 @@ class TimestepEmbedder(nn.Module): args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t): @@ -220,15 +219,17 @@ class PatchEmbed(nn.Module): device=device, ) assert norm_layer is None - self.norm = ( - norm_layer(embed_dim, device=device) if norm_layer else nn.Identity() - ) + self.norm = norm_layer(embed_dim, device=device) if norm_layer else nn.Identity() def forward(self, x): B, _C, T, H, W = x.shape if not self.dynamic_img_pad: - assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." - assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." + assert ( + H % self.patch_size[0] == 0 + ), f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." + assert ( + W % self.patch_size[1] == 0 + ), f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." else: pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] @@ -337,9 +338,7 @@ class AttentionPool(nn.Module): q = q.unsqueeze(2) # (B, H, 1, head_dim) # Compute attention. - x = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, dropout_p=0.0 - ) # (B, H, 1, head_dim) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim) # Concatenate heads and run output. x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) @@ -470,9 +469,9 @@ class AsymmetricJointBlock(nn.Module): num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens Returns: - x: (B, N, dim) tensor of visual tokens after block - y: (B, L, dim) tensor of text tokens after block + x: (B, N, dim) tensor of visual tokens after block y: (B, L, dim) tensor of text tokens after block """ + breakpoint() N = x.size(1) c = F.silu(c) @@ -540,9 +539,7 @@ class AsymmetricAttention(nn.Module): self.update_y = update_y self.softmax_scale = softmax_scale if dim_x % num_heads != 0: - raise ValueError( - f"dim_x={dim_x} should be divisible by num_heads={num_heads}" - ) + raise ValueError(f"dim_x={dim_x} should be divisible by num_heads={num_heads}") # Input layers. self.qkv_bias = qkv_bias @@ -559,158 +556,292 @@ class AsymmetricAttention(nn.Module): # Output layers. y features go back down from dim_x -> dim_y. 
self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device) - self.proj_y = ( - nn.Linear(dim_x, dim_y, bias=out_bias, device=device) - if update_y - else nn.Identity() + self.proj_y = nn.Linear(dim_x, dim_y, bias=out_bias, device=device) if update_y else nn.Identity() + + def run_qkv_y(self, y): + qkv_y = self.qkv_y(y) + qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, -1, self.head_dim) + q_y, k_y, v_y = qkv_y.unbind(2) + return q_y, k_y, v_y + + # cp_rank, cp_size = cp.get_cp_rank_size() + # local_heads = self.num_heads // cp_size + + # if cp.is_cp_active(): + # # Only predict local heads. + # assert not self.qkv_bias + # W_qkv_y = self.qkv_y.weight.view( + # 3, self.num_heads, self.head_dim, self.dim_y + # ) + # W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads) + # W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y) + # qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim) + # else: + # qkv_y = self.qkv_y(y) # (B, L, 3 * dim) + + # qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim) + # q_y, k_y, v_y = qkv_y.unbind(2) + # return q_y, k_y, v_y + + def prepare_qkv( + self, + x: torch.Tensor, # (B, N, dim_x) + y: torch.Tensor, # (B, L, dim_y) + *, + scale_x: torch.Tensor, + scale_y: torch.Tensor, + rope_cos: torch.Tensor, + rope_sin: torch.Tensor, + valid_token_indices: torch.Tensor = None, + ): + breakpoint() + # Pre-norm for visual features + x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size + + # Process visual features + qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x) + # assert qkv_x.dtype == torch.bfloat16 + # qkv_x = cp.all_to_all_collect_tokens( + # qkv_x, self.num_heads + # ) # (3, B, N, local_h, head_dim) + B, M, _ = qkv_x.size() + qkv_x = qkv_x.view(B, M, 3, -1, 128) + qkv_x = qkv_x.permute(2, 0, 1, 3, 4) + + # Process text features + y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y) + q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim) + q_y = self.q_norm_y(q_y) + k_y = self.k_norm_y(k_y) + + # Split qkv_x into q, k, v + q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim) + q_x = self.q_norm_x(q_x) + q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin) + k_x = self.k_norm_x(k_x) + k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin) + + # Unite streams + qkv = unify_streams( + q_x, + k_x, + v_x, + q_y, + k_y, + v_y, + valid_token_indices, ) - # def run_qkv_y(self, y): - # cp_rank, cp_size = cp.get_cp_rank_size() - # local_heads = self.num_heads // cp_size + return qkv - # if cp.is_cp_active(): - # # Only predict local heads. 
- # assert not self.qkv_bias - # W_qkv_y = self.qkv_y.weight.view( - # 3, self.num_heads, self.head_dim, self.dim_y - # ) - # W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads) - # W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y) - # qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim) - # else: - # qkv_y = self.qkv_y(y) # (B, L, 3 * dim) + @torch.compiler.disable() + def run_attention( + self, + qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim) + *, + B: int, + L: int, + M: int, + cu_seqlens: torch.Tensor = None, + max_seqlen_in_batch: int = None, + valid_token_indices: torch.Tensor = None, + ): + breakpoint() + N = M + local_heads = self.num_heads + # local_dim = local_heads * self.head_dim + # with torch.autocast("cuda", enabled=False): + # out: torch.Tensor = flash_attn_varlen_qkvpacked_func( + # qkv, + # cu_seqlens=cu_seqlens, + # max_seqlen=max_seqlen_in_batch, + # dropout_p=0.0, + # softmax_scale=self.softmax_scale, + # ) # (total, local_heads, head_dim) + # out = out.view(total, local_dim) - # qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim) - # q_y, k_y, v_y = qkv_y.unbind(2) - # return q_y, k_y, v_y + q, k, v = qkv.unbind(1) + out = F.scaled_dot_product_attention(q, k, v) - # def prepare_qkv( - # self, - # x: torch.Tensor, # (B, N, dim_x) - # y: torch.Tensor, # (B, L, dim_y) - # *, - # scale_x: torch.Tensor, - # scale_y: torch.Tensor, - # rope_cos: torch.Tensor, - # rope_sin: torch.Tensor, - # valid_token_indices: torch.Tensor, - # ): - # # Pre-norm for visual features - # x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size + # x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype) + x, y = out.split_with_sizes((N, L), dim=0) + # assert x.size() == (B, N, local_dim) + # assert y.size() == (B, L, local_dim) - # # Process visual features - # qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x) - # assert qkv_x.dtype == torch.bfloat16 - # qkv_x = cp.all_to_all_collect_tokens( - # qkv_x, self.num_heads - # ) # (3, B, N, local_h, head_dim) + x = x.view(B, -1, local_heads, self.head_dim).flatten(2, 3) + x = self.proj_x(x) # (B, M, dim_x) - # # Process text features - # y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y) - # q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim) - # q_y = self.q_norm_y(q_y) - # k_y = self.k_norm_y(k_y) + y = y.view(B, -1, local_heads, self.head_dim).flatten(2, 3) + y = self.proj_y(y) # (B, L, dim_y) + return x, y - # # Split qkv_x into q, k, v - # q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim) - # q_x = self.q_norm_x(q_x) - # q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin) - # k_x = self.k_norm_x(k_x) - # k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin) + def forward( + self, + x: torch.Tensor, # (B, N, dim_x) + y: torch.Tensor, # (B, L, dim_y) + *, + scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. + scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. + packed_indices: Dict[str, torch.Tensor] = None, + **rope_rotation, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass of asymmetric multi-modal attention. - # # Unite streams - # qkv = unify_streams( - # q_x, - # k_x, - # v_x, - # q_y, - # k_y, - # v_y, - # valid_token_indices, - # ) + Args: + x: (B, N, dim_x) tensor for visual tokens + y: (B, L, dim_y) tensor of text token features + packed_indices: Dict with keys for Flash Attention + num_frames: Number of frames in the video. 
N = num_frames * num_spatial_tokens - # return qkv + Returns: + x: (B, N, dim_x) tensor of visual tokens after multi-modal attention y: (B, L, dim_y) tensor of text token + features after multi-modal attention + """ + B, L, _ = y.shape + _, M, _ = x.shape - # @torch.compiler.disable() - # def run_attention( - # self, - # qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim) - # *, - # B: int, - # L: int, - # M: int, - # cu_seqlens: torch.Tensor, - # max_seqlen_in_batch: int, - # valid_token_indices: torch.Tensor, - # ): - # with torch.autocast("cuda", enabled=False): - # out: torch.Tensor = flash_attn_varlen_qkvpacked_func( - # qkv, - # cu_seqlens=cu_seqlens, - # max_seqlen=max_seqlen_in_batch, - # dropout_p=0.0, - # softmax_scale=self.softmax_scale, - # ) # (total, local_heads, head_dim) - # out = out.view(total, local_dim) + # Predict a packed QKV tensor from visual and text features. + # Don't checkpoint the all_to_all. + qkv = self.prepare_qkv( + x=x, + y=y, + scale_x=scale_x, + scale_y=scale_y, + rope_cos=rope_rotation.get("rope_cos"), + rope_sin=rope_rotation.get("rope_sin"), + # valid_token_indices=packed_indices["valid_token_indices_kv"], + ) # (total <= B * (N + L), 3, local_heads, head_dim) - # x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype) - # assert x.size() == (B, N, local_dim) - # assert y.size() == (B, L, local_dim) + x, y = self.run_attention( + qkv, + B=B, + L=L, + M=M, + # cu_seqlens=packed_indices["cu_seqlens_kv"], + # max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"], + # valid_token_indices=packed_indices["valid_token_indices_kv"], + ) + return x, y - # x = x.view(B, N, local_heads, self.head_dim) - # x = self.proj_x(x) # (B, M, dim_x) - # y = self.proj_y(y) # (B, L, dim_y) - # return x, y +def apply_rotary_emb_qk_real( + xqk: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers. - # def forward( - # self, - # x: torch.Tensor, # (B, N, dim_x) - # y: torch.Tensor, # (B, L, dim_y) - # *, - # scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. - # scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. - # packed_indices: Dict[str, torch.Tensor] = None, - # **rope_rotation, - # ) -> Tuple[torch.Tensor, torch.Tensor]: - # """Forward pass of asymmetric multi-modal attention. + Args: + xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D) + Can be either just query or just key, or both stacked along some batch or * dim. + freqs_cos (torch.Tensor): Precomputed cosine frequency tensor. + freqs_sin (torch.Tensor): Precomputed sine frequency tensor. - # Args: - # x: (B, N, dim_x) tensor for visual tokens - # y: (B, L, dim_y) tensor of text token features - # packed_indices: Dict with keys for Flash Attention - # num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens + Returns: + torch.Tensor: The input tensor with rotary embeddings applied. 
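+
+    Example (illustrative shapes, using Mochi's num_heads=24 and head_dim=128):
+        >>> xqk = torch.randn(2, 8, 24, 128)        # (B, S, num_heads, head_dim)
+        >>> freqs_cos = torch.randn(8, 24, 64)      # (S, num_heads, head_dim // 2)
+        >>> freqs_sin = torch.randn(8, 24, 64)
+        >>> apply_rotary_emb_qk_real(xqk, freqs_cos, freqs_sin).shape
+        torch.Size([2, 8, 24, 128])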
+ """ + # assert xqk.dtype == torch.bfloat16 + # Split the last dimension into even and odd parts + xqk_even = xqk[..., 0::2] + xqk_odd = xqk[..., 1::2] - # Returns: - # x: (B, N, dim_x) tensor of visual tokens after multi-modal attention - # y: (B, L, dim_y) tensor of text token features after multi-modal attention - # """ - # B, L, _ = y.shape - # _, M, _ = x.shape + # Apply rotation + cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk) + sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk) - # # Predict a packed QKV tensor from visual and text features. - # # Don't checkpoint the all_to_all. - # qkv = self.prepare_qkv( - # x=x, - # y=y, - # scale_x=scale_x, - # scale_y=scale_y, - # rope_cos=rope_rotation.get("rope_cos"), - # rope_sin=rope_rotation.get("rope_sin"), - # valid_token_indices=packed_indices["valid_token_indices_kv"], - # ) # (total <= B * (N + L), 3, local_heads, head_dim) + # Interleave the results back into the original shape + out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2) + # assert out.dtype == torch.bfloat16 + return out - # x, y = self.run_attention( - # qkv, - # B=B, - # L=L, - # M=M, - # cu_seqlens=packed_indices["cu_seqlens_kv"], - # max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"], - # valid_token_indices=packed_indices["valid_token_indices_kv"], - # ) - # return x, y + +class PadSplitXY(torch.autograd.Function): + """ + Merge heads, pad and extract visual and text tokens, and split along the sequence length. + """ + + @staticmethod + def forward( + ctx, + xy: torch.Tensor, + indices: torch.Tensor, + B: int, + N: int, + L: int, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim). + indices: Valid token indices out of unpacked tensor. Shape: (total,) + + Returns: + x: Visual tokens. Shape: (B, N, num_heads * head_dim). y: Text tokens. Shape: (B, L, num_heads * head_dim). + """ + ctx.save_for_backward(indices) + ctx.B, ctx.N, ctx.L = B, N, L + D = xy.size(1) + + # Pad sequences to (B, N + L, dim). + assert indices.ndim == 1 + output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype) + indices = indices.unsqueeze(1).expand(-1, D) # (total,) -> (total, num_heads * head_dim) + output.scatter_(0, indices, xy) + xy = output.view(B, N + L, D) + + # Split visual and text tokens along the sequence length. 
+ return torch.tensor_split(xy, (N,), dim=1) + + +def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]: + return PadSplitXY.apply(xy, indices, B, N, L, dtype) + + +class UnifyStreams(torch.autograd.Function): + """Unify visual and text streams.""" + + @staticmethod + def forward( + ctx, + q_x: torch.Tensor, + k_x: torch.Tensor, + v_x: torch.Tensor, + q_y: torch.Tensor, + k_y: torch.Tensor, + v_y: torch.Tensor, + indices: torch.Tensor, + ): + """ + Args: + q_x: (B, N, num_heads, head_dim) + k_x: (B, N, num_heads, head_dim) + v_x: (B, N, num_heads, head_dim) + q_y: (B, L, num_heads, head_dim) + k_y: (B, L, num_heads, head_dim) + v_y: (B, L, num_heads, head_dim) + indices: (total <= B * (N + L)) + + Returns: + qkv: (total <= B * (N + L), 3, num_heads, head_dim) + """ + ctx.save_for_backward(indices) + B, N, num_heads, head_dim = q_x.size() + ctx.B, ctx.N, ctx.L = B, N, q_y.size(1) + D = num_heads * head_dim + + q = torch.cat([q_x, q_y], dim=1) + k = torch.cat([k_x, k_y], dim=1) + v = torch.cat([v_x, v_y], dim=1) + qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D) + + # indices = indices[:, None, None].expand(-1, 3, D) + # qkv = torch.gather(qkv, 0, indices) # (total, 3, num_heads * head_dim) + return qkv.unflatten(2, (num_heads, head_dim)) + + +def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor: + return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices) class FinalLayer(nn.Module): @@ -726,13 +857,9 @@ class FinalLayer(nn.Module): device: Optional[torch.device] = None, ): super().__init__() - self.norm_final = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6, device=device - ) + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, device=device) self.mod = nn.Linear(hidden_size, 2 * hidden_size, device=device) - self.linear = nn.Linear( - hidden_size, patch_size * patch_size * out_channels, device=device - ) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, device=device) def forward(self, x, c): c = F.silu(c) @@ -777,15 +904,11 @@ class MochiTransformer3DModel(nn.Module): self.num_heads = num_heads self.hidden_size_x = hidden_size_x self.hidden_size_y = hidden_size_y - self.head_dim = ( - hidden_size_x // num_heads - ) # Head dimension and count is determined by visual. + self.head_dim = hidden_size_x // num_heads # Head dimension and count is determined by visual. self.use_extended_posenc = use_extended_posenc self.t5_token_length = t5_token_length self.t5_feat_dim = t5_feat_dim - self.rope_theta = ( - rope_theta # Scaling factor for frequency computation for temporal RoPE. - ) + self.rope_theta = rope_theta # Scaling factor for frequency computation for temporal RoPE. 
self.x_embedder = PatchEmbed( patch_size=patch_size, @@ -796,24 +919,16 @@ class MochiTransformer3DModel(nn.Module): ) # Conditionings # Timestep - self.t_embedder = TimestepEmbedder( - hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale - ) + self.t_embedder = TimestepEmbedder(hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale) # Caption Pooling (T5) - self.t5_y_embedder = AttentionPool( - t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device - ) + self.t5_y_embedder = AttentionPool(t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device) # Dense Embedding Projection (T5) - self.t5_yproj = nn.Linear( - t5_feat_dim, hidden_size_y, bias=True, device=device - ) + self.t5_yproj = nn.Linear(t5_feat_dim, hidden_size_y, bias=True, device=device) # Initialize pos_frequencies as an empty parameter. - self.pos_frequencies = nn.Parameter( - torch.empty(3, self.num_heads, self.head_dim // 2, device=device) - ) + self.pos_frequencies = nn.Parameter(torch.empty(3, self.num_heads, self.head_dim // 2, device=device)) # for depth 48: # b = 0: AsymmetricJointBlock, update_y=True @@ -839,9 +954,7 @@ class MochiTransformer3DModel(nn.Module): blocks.append(block) self.blocks = nn.ModuleList(blocks) - self.final_layer = FinalLayer( - hidden_size_x, patch_size, self.out_channels, device=device - ) + self.final_layer = FinalLayer(hidden_size_x, patch_size, self.out_channels, device=device) def embed_x(self, x: torch.Tensor) -> torch.Tensor: """ @@ -861,6 +974,7 @@ class MochiTransformer3DModel(nn.Module): t5_mask: torch.Tensor, ): """Prepare input and conditioning embeddings.""" + breakpoint() with torch.profiler.record_function("x_emb_pe"): # Visual patch embeddings with positional encoding. @@ -878,9 +992,7 @@ class MochiTransformer3DModel(nn.Module): pH, pW = H // self.patch_size, W // self.patch_size N = T * pH * pW assert x.size(1) == N - pos = create_position_matrix( - T, pH=pH, pW=pW, device=x.device, dtype=torch.float32 - ) # (N, 3) + pos = create_position_matrix(T, pH=pH, pW=pW, device=x.device, dtype=torch.float32) # (N, 3) rope_cos, rope_sin = compute_mixed_rotation( freqs=self.pos_frequencies, pos=pos ) # Each are (N, num_heads, dim // 2) @@ -896,9 +1008,7 @@ class MochiTransformer3DModel(nn.Module): t5_feat.size(1) == self.t5_token_length ), f"Expected L={self.t5_token_length}, got {t5_feat.shape} for y_feat." t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D) - assert ( - t5_y_pool.size(0) == B - ), f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool." + assert t5_y_pool.size(0) == B, f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool." c = c_t + t5_y_pool @@ -921,16 +1031,17 @@ class MochiTransformer3DModel(nn.Module): Args: x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images) sigma: (B,) tensor of noise standard deviations - y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048) + y_feat: + List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, + y_feat_dim=2048) y_mask: List((B, L) boolean tensor indicating which tokens are not padding) packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices. 
""" B, _, T, H, W = x.shape - x, c, y_feat, rope_cos, rope_sin = self.prepare( - x, sigma, y_feat[0], y_mask[0] - ) + x, c, y_feat, rope_cos, rope_sin = self.prepare(x, sigma, y_feat[0], y_mask[0]) del y_mask + breakpoint() for i, block in enumerate(self.blocks): x, y_feat = block( diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 10d0399a67..908865be5d 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -347,6 +347,21 @@ class LuminaNextDiT2DModel(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class MochiTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ModelMixin(metaclass=DummyObject): _backends = ["torch"]