From 2fd2ec40250e00de6965ec005aab1423b21b9291 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 24 Oct 2024 13:48:22 +0200
Subject: [PATCH] fixes

---
 src/diffusers/models/attention_processor.py  |  86 +++-
 src/diffusers/models/normalization.py        |   4 +-
 .../models/transformers/transformer_mochi.py |  65 ++-
 .../transformer_mochi_original.py            | 402 ++++++++++++------
 4 files changed, 410 insertions(+), 147 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index ce0f9d87c8..dd61a6ab2c 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1795,8 +1795,7 @@ class FluxAttnProcessor2_0:
             # dropout
             hidden_states = attn.to_out[1](hidden_states)

-            if hasattr(attn, "to_add_out"):
-                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

             return hidden_states, encoder_hidden_states
         else:
@@ -3082,6 +3081,89 @@ class LuminaAttnProcessor2_0:
         return hidden_states


+class MochiAttnProcessor2_0:
+    """Attention processor used in Mochi."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size = hidden_states.size(0)
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_added_q is not None:
+            encoder_query = attn.norm_added_q(encoder_query)
+        if attn.norm_added_k is not None:
+            encoder_key = attn.norm_added_k(encoder_key)
+
+        if image_rotary_emb is not None:
+            def apply_rotary_emb(x, freqs_cos, freqs_sin):
+                x_even = x[..., 0::2].float()
+                x_odd = x[..., 1::2].float()
+
+                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
+                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+
+                return torch.stack([cos, sin], dim=-1).flatten(-2)
+
+            query = apply_rotary_emb(query, *image_rotary_emb)
+            key = apply_rotary_emb(key, *image_rotary_emb)
+
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = encoder_query.transpose(1, 2), encoder_key.transpose(1, 2), encoder_value.transpose(1, 2)
+
+        sequence_length = query.size(2)
+        encoder_sequence_length = encoder_query.size(2)
+
+        query = torch.cat([query, encoder_query], dim=2)
+        key = torch.cat([key, encoder_key], dim=2)
+        value = torch.cat([value, encoder_value], dim=2)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
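+        # NOTE: attention runs once over the concatenated visual + text sequence;
+        # the joint output is split back into the two streams just below.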
hidden_states = hidden_states.to(query.dtype) + + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length), dim=1) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + class FusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 9058320998..dcfaed90b3 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -246,13 +246,13 @@ class MochiRMSNormZero(nn.Module): """ def __init__( - self, embedding_dim: int, hidden_dim: int, norm_eps: float = 1e-5, elementwise_affine: bool = False + self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False ) -> None: super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, hidden_dim) - self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=elementwise_affine) + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) def forward( self, hidden_states: torch.Tensor, emb: torch.Tensor diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 7ece241e4b..7938d5e39f 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import Attention, FluxAttnProcessor2_0 +from ..attention_processor import Attention, MochiAttnProcessor2_0 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -43,6 +43,7 @@ class MochiTransformerBlock(nn.Module): qk_norm: str = "rms_norm", activation_fn: str = "swiglu", context_pre_only: bool = True, + eps: float = 1e-6, ) -> None: super().__init__() @@ -50,15 +51,15 @@ class MochiTransformerBlock(nn.Module): self.ff_inner_dim = (4 * dim * 2) // 3 self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3 - self.norm1 = MochiRMSNormZero(dim, 4 * dim) + self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False) if not context_pre_only: - self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim) + self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False) else: self.norm1_context = LuminaLayerNormContinuous( embedding_dim=pooled_projection_dim, conditioning_embedding_dim=dim, - eps=1e-6, + eps=eps, elementwise_affine=False, norm_type="rms_norm", out_dim=None, @@ -76,16 +77,16 @@ class MochiTransformerBlock(nn.Module): out_dim=dim, out_context_dim=pooled_projection_dim, context_pre_only=context_pre_only, - processor=FluxAttnProcessor2_0(), - eps=1e-6, + processor=MochiAttnProcessor2_0(), + eps=eps, elementwise_affine=True, ) - self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) - self.norm2_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False) + self.norm2 = RMSNorm(dim, eps=eps, 
elementwise_affine=False) + self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) - self.norm3 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) - self.norm3_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False) + self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) self.ff_context = None @@ -94,8 +95,8 @@ class MochiTransformerBlock(nn.Module): pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False ) - self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False) - self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False) + self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) def forward( self, @@ -104,6 +105,7 @@ class MochiTransformerBlock(nn.Module): temb: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + breakpoint() norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) if not self.context_pre_only: @@ -140,6 +142,40 @@ class MochiTransformerBlock(nn.Module): return hidden_states, encoder_hidden_states +class MochiRoPE(nn.Module): + def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None: + super().__init__() + + self.target_area = base_height * base_width + + def _centers(self, start, stop, num, device, dtype) -> torch.Tensor: + edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype) + return (edges[:-1] + edges[1:]) / 2 + + def _get_positions(self, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + scale = (self.target_area / (height * width)) ** 0.5 + + t = torch.arange(num_frames, device=device, dtype=dtype) + h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype) + w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype) + + grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") + + positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3) + return positions + + def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: + freqs = torch.einsum("nd,dhf->nhf", pos, freqs) + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + return freqs_cos, freqs_sin + + def forward(self, pos_frequencies: torch.Tensor, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]: + pos = self._get_positions(num_frames, height, width, device, dtype) + rope_cos, rope_sin = self._create_rope(pos_frequencies, pos) + return rope_cos, rope_sin + + @maybe_allow_in_graph class MochiTransformer3DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @@ -169,6 +205,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin): patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, + pos_embed_type=None, ) self.time_embed = MochiCombinedTimestepCaptionEmbedding( @@ -180,6 +217,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin): ) self.pos_frequencies = nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2)) + self.rope 
= MochiRoPE() self.transformer_blocks = nn.ModuleList( [ @@ -207,7 +245,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin): encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, ) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape @@ -224,6 +261,8 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin): hidden_states = self.patch_embed(hidden_states) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + image_rotary_emb = self.rope(self.pos_frequencies, num_frames, post_patch_height, post_patch_width, device=hidden_states.device, dtype=torch.float32) + for i, block in enumerate(self.transformer_blocks): hidden_states, encoder_hidden_states = block( hidden_states=hidden_states, diff --git a/src/diffusers/models/transformers/transformer_mochi_original.py b/src/diffusers/models/transformers/transformer_mochi_original.py index 9c8924decb..2dad5c5c86 100644 --- a/src/diffusers/models/transformers/transformer_mochi_original.py +++ b/src/diffusers/models/transformers/transformer_mochi_original.py @@ -2,7 +2,7 @@ import collections import functools import itertools import math -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -473,6 +473,7 @@ class AsymmetricJointBlock(nn.Module): x: (B, N, dim) tensor of visual tokens after block y: (B, L, dim) tensor of text tokens after block """ + breakpoint() N = x.size(1) c = F.silu(c) @@ -559,152 +560,291 @@ class AsymmetricAttention(nn.Module): self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device) self.proj_y = nn.Linear(dim_x, dim_y, bias=out_bias, device=device) if update_y else nn.Identity() - # def run_qkv_y(self, y): - # cp_rank, cp_size = cp.get_cp_rank_size() - # local_heads = self.num_heads // cp_size + def run_qkv_y(self, y): + qkv_y = self.qkv_y(y) + qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, -1, self.head_dim) + q_y, k_y, v_y = qkv_y.unbind(2) + return q_y, k_y, v_y - # if cp.is_cp_active(): - # # Only predict local heads. - # assert not self.qkv_bias - # W_qkv_y = self.qkv_y.weight.view( - # 3, self.num_heads, self.head_dim, self.dim_y - # ) - # W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads) - # W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y) - # qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim) - # else: - # qkv_y = self.qkv_y(y) # (B, L, 3 * dim) + # cp_rank, cp_size = cp.get_cp_rank_size() + # local_heads = self.num_heads // cp_size - # qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim) - # q_y, k_y, v_y = qkv_y.unbind(2) - # return q_y, k_y, v_y + # if cp.is_cp_active(): + # # Only predict local heads. 
+ # assert not self.qkv_bias + # W_qkv_y = self.qkv_y.weight.view( + # 3, self.num_heads, self.head_dim, self.dim_y + # ) + # W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads) + # W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y) + # qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim) + # else: + # qkv_y = self.qkv_y(y) # (B, L, 3 * dim) - # def prepare_qkv( - # self, - # x: torch.Tensor, # (B, N, dim_x) - # y: torch.Tensor, # (B, L, dim_y) - # *, - # scale_x: torch.Tensor, - # scale_y: torch.Tensor, - # rope_cos: torch.Tensor, - # rope_sin: torch.Tensor, - # valid_token_indices: torch.Tensor, - # ): - # # Pre-norm for visual features - # x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size + # qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim) + # q_y, k_y, v_y = qkv_y.unbind(2) + # return q_y, k_y, v_y - # # Process visual features - # qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x) - # assert qkv_x.dtype == torch.bfloat16 - # qkv_x = cp.all_to_all_collect_tokens( - # qkv_x, self.num_heads - # ) # (3, B, N, local_h, head_dim) + def prepare_qkv( + self, + x: torch.Tensor, # (B, N, dim_x) + y: torch.Tensor, # (B, L, dim_y) + *, + scale_x: torch.Tensor, + scale_y: torch.Tensor, + rope_cos: torch.Tensor, + rope_sin: torch.Tensor, + valid_token_indices: torch.Tensor = None, + ): + breakpoint() + # Pre-norm for visual features + x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size - # # Process text features - # y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y) - # q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim) - # q_y = self.q_norm_y(q_y) - # k_y = self.k_norm_y(k_y) + # Process visual features + qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x) + # assert qkv_x.dtype == torch.bfloat16 + # qkv_x = cp.all_to_all_collect_tokens( + # qkv_x, self.num_heads + # ) # (3, B, N, local_h, head_dim) + B, M, _ = qkv_x.size() + qkv_x = qkv_x.view(B, M, 3, -1, 128) + qkv_x = qkv_x.permute(2, 0, 1, 3, 4) - # # Split qkv_x into q, k, v - # q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim) - # q_x = self.q_norm_x(q_x) - # q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin) - # k_x = self.k_norm_x(k_x) - # k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin) + # Process text features + y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y) + q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim) + q_y = self.q_norm_y(q_y) + k_y = self.k_norm_y(k_y) - # # Unite streams - # qkv = unify_streams( - # q_x, - # k_x, - # v_x, - # q_y, - # k_y, - # v_y, - # valid_token_indices, - # ) + # Split qkv_x into q, k, v + q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim) + q_x = self.q_norm_x(q_x) + q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin) + k_x = self.k_norm_x(k_x) + k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin) - # return qkv + # Unite streams + qkv = unify_streams( + q_x, + k_x, + v_x, + q_y, + k_y, + v_y, + valid_token_indices, + ) - # @torch.compiler.disable() - # def run_attention( - # self, - # qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim) - # *, - # B: int, - # L: int, - # M: int, - # cu_seqlens: torch.Tensor, - # max_seqlen_in_batch: int, - # valid_token_indices: torch.Tensor, - # ): - # with torch.autocast("cuda", enabled=False): - # out: torch.Tensor = flash_attn_varlen_qkvpacked_func( - # qkv, - # cu_seqlens=cu_seqlens, - # max_seqlen=max_seqlen_in_batch, - # dropout_p=0.0, - # 
softmax_scale=self.softmax_scale, - # ) # (total, local_heads, head_dim) - # out = out.view(total, local_dim) + return qkv - # x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype) - # assert x.size() == (B, N, local_dim) - # assert y.size() == (B, L, local_dim) + @torch.compiler.disable() + def run_attention( + self, + qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim) + *, + B: int, + L: int, + M: int, + cu_seqlens: torch.Tensor = None, + max_seqlen_in_batch: int = None, + valid_token_indices: torch.Tensor = None, + ): + breakpoint() + N = M + local_heads = self.num_heads + local_dim = local_heads * self.head_dim + # with torch.autocast("cuda", enabled=False): + # out: torch.Tensor = flash_attn_varlen_qkvpacked_func( + # qkv, + # cu_seqlens=cu_seqlens, + # max_seqlen=max_seqlen_in_batch, + # dropout_p=0.0, + # softmax_scale=self.softmax_scale, + # ) # (total, local_heads, head_dim) + # out = out.view(total, local_dim) - # x = x.view(B, N, local_heads, self.head_dim) - # x = self.proj_x(x) # (B, M, dim_x) + q, k, v = qkv.unbind(1) + out = F.scaled_dot_product_attention(q, k, v) - # y = self.proj_y(y) # (B, L, dim_y) - # return x, y + # x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype) + x, y = out.split_with_sizes((N, L), dim=0) + # assert x.size() == (B, N, local_dim) + # assert y.size() == (B, L, local_dim) - # def forward( - # self, - # x: torch.Tensor, # (B, N, dim_x) - # y: torch.Tensor, # (B, L, dim_y) - # *, - # scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. - # scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. - # packed_indices: Dict[str, torch.Tensor] = None, - # **rope_rotation, - # ) -> Tuple[torch.Tensor, torch.Tensor]: - # """Forward pass of asymmetric multi-modal attention. + x = x.view(B, -1, local_heads, self.head_dim).flatten(2, 3) + x = self.proj_x(x) # (B, M, dim_x) - # Args: - # x: (B, N, dim_x) tensor for visual tokens - # y: (B, L, dim_y) tensor of text token features - # packed_indices: Dict with keys for Flash Attention - # num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens + y = y.view(B, -1, local_heads, self.head_dim).flatten(2, 3) + y = self.proj_y(y) # (B, L, dim_y) + return x, y - # Returns: - # x: (B, N, dim_x) tensor of visual tokens after multi-modal attention - # y: (B, L, dim_y) tensor of text token features after multi-modal attention - # """ - # B, L, _ = y.shape - # _, M, _ = x.shape + def forward( + self, + x: torch.Tensor, # (B, N, dim_x) + y: torch.Tensor, # (B, L, dim_y) + *, + scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. + scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. + packed_indices: Dict[str, torch.Tensor] = None, + **rope_rotation, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass of asymmetric multi-modal attention. - # # Predict a packed QKV tensor from visual and text features. - # # Don't checkpoint the all_to_all. - # qkv = self.prepare_qkv( - # x=x, - # y=y, - # scale_x=scale_x, - # scale_y=scale_y, - # rope_cos=rope_rotation.get("rope_cos"), - # rope_sin=rope_rotation.get("rope_sin"), - # valid_token_indices=packed_indices["valid_token_indices_kv"], - # ) # (total <= B * (N + L), 3, local_heads, head_dim) + Args: + x: (B, N, dim_x) tensor for visual tokens + y: (B, L, dim_y) tensor of text token features + packed_indices: Dict with keys for Flash Attention + num_frames: Number of frames in the video. 
N = num_frames * num_spatial_tokens - # x, y = self.run_attention( - # qkv, - # B=B, - # L=L, - # M=M, - # cu_seqlens=packed_indices["cu_seqlens_kv"], - # max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"], - # valid_token_indices=packed_indices["valid_token_indices_kv"], - # ) - # return x, y + Returns: + x: (B, N, dim_x) tensor of visual tokens after multi-modal attention + y: (B, L, dim_y) tensor of text token features after multi-modal attention + """ + B, L, _ = y.shape + _, M, _ = x.shape + + # Predict a packed QKV tensor from visual and text features. + # Don't checkpoint the all_to_all. + qkv = self.prepare_qkv( + x=x, + y=y, + scale_x=scale_x, + scale_y=scale_y, + rope_cos=rope_rotation.get("rope_cos"), + rope_sin=rope_rotation.get("rope_sin"), + # valid_token_indices=packed_indices["valid_token_indices_kv"], + ) # (total <= B * (N + L), 3, local_heads, head_dim) + + x, y = self.run_attention( + qkv, + B=B, + L=L, + M=M, + # cu_seqlens=packed_indices["cu_seqlens_kv"], + # max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"], + # valid_token_indices=packed_indices["valid_token_indices_kv"], + ) + return x, y + +def apply_rotary_emb_qk_real( + xqk: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers. + + Args: + xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D) + Can be either just query or just key, or both stacked along some batch or * dim. + freqs_cos (torch.Tensor): Precomputed cosine frequency tensor. + freqs_sin (torch.Tensor): Precomputed sine frequency tensor. + + Returns: + torch.Tensor: The input tensor with rotary embeddings applied. + """ + # assert xqk.dtype == torch.bfloat16 + # Split the last dimension into even and odd parts + xqk_even = xqk[..., 0::2] + xqk_odd = xqk[..., 1::2] + + # Apply rotation + cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk) + sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk) + + # Interleave the results back into the original shape + out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2) + # assert out.dtype == torch.bfloat16 + return out + +class PadSplitXY(torch.autograd.Function): + """ + Merge heads, pad and extract visual and text tokens, + and split along the sequence length. + """ + + @staticmethod + def forward( + ctx, + xy: torch.Tensor, + indices: torch.Tensor, + B: int, + N: int, + L: int, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim). + indices: Valid token indices out of unpacked tensor. Shape: (total,) + + Returns: + x: Visual tokens. Shape: (B, N, num_heads * head_dim). + y: Text tokens. Shape: (B, L, num_heads * head_dim). + """ + ctx.save_for_backward(indices) + ctx.B, ctx.N, ctx.L = B, N, L + D = xy.size(1) + + # Pad sequences to (B, N + L, dim). + assert indices.ndim == 1 + output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype) + indices = indices.unsqueeze(1).expand( + -1, D + ) # (total,) -> (total, num_heads * head_dim) + output.scatter_(0, indices, xy) + xy = output.view(B, N + L, D) + + # Split visual and text tokens along the sequence length. 
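+        # tensor_split at index N yields x: (B, N, D) visual tokens and y: (B, L, D) text tokens.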
+ return torch.tensor_split(xy, (N,), dim=1) + + +def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]: + return PadSplitXY.apply(xy, indices, B, N, L, dtype) + +class UnifyStreams(torch.autograd.Function): + """Unify visual and text streams.""" + + @staticmethod + def forward( + ctx, + q_x: torch.Tensor, + k_x: torch.Tensor, + v_x: torch.Tensor, + q_y: torch.Tensor, + k_y: torch.Tensor, + v_y: torch.Tensor, + indices: torch.Tensor, + ): + """ + Args: + q_x: (B, N, num_heads, head_dim) + k_x: (B, N, num_heads, head_dim) + v_x: (B, N, num_heads, head_dim) + q_y: (B, L, num_heads, head_dim) + k_y: (B, L, num_heads, head_dim) + v_y: (B, L, num_heads, head_dim) + indices: (total <= B * (N + L)) + + Returns: + qkv: (total <= B * (N + L), 3, num_heads, head_dim) + """ + ctx.save_for_backward(indices) + B, N, num_heads, head_dim = q_x.size() + ctx.B, ctx.N, ctx.L = B, N, q_y.size(1) + D = num_heads * head_dim + + q = torch.cat([q_x, q_y], dim=1) + k = torch.cat([k_x, k_y], dim=1) + v = torch.cat([v_x, v_y], dim=1) + qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D) + + # indices = indices[:, None, None].expand(-1, 3, D) + # qkv = torch.gather(qkv, 0, indices) # (total, 3, num_heads * head_dim) + return qkv.unflatten(2, (num_heads, head_dim)) + + +def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor: + return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices) class FinalLayer(nn.Module): @@ -837,6 +977,7 @@ class MochiTransformer3DModel(nn.Module): t5_mask: torch.Tensor, ): """Prepare input and conditioning embeddings.""" + breakpoint() with torch.profiler.record_function("x_emb_pe"): # Visual patch embeddings with positional encoding. @@ -901,6 +1042,7 @@ class MochiTransformer3DModel(nn.Module): x, c, y_feat, rope_cos, rope_sin = self.prepare(x, sigma, y_feat[0], y_mask[0]) del y_mask + breakpoint() for i, block in enumerate(self.blocks): x, y_feat = block(
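As a quick sanity check on the port (not part of the patch itself), the real-valued rotary helper added inside MochiAttnProcessor2_0 should be numerically equivalent to apply_rotary_emb_qk_real from transformer_mochi_original.py. A minimal sketch, assuming float32 inputs, the (batch, seq_len, heads, head_dim) layout used before the transpose, frequencies of shape (seq_len, heads, head_dim // 2) as produced by MochiRoPE, and hypothetical toy sizes:

import torch

def apply_rotary_emb(x, freqs_cos, freqs_sin):
    # Copy of the helper defined inside MochiAttnProcessor2_0.__call__.
    x_even = x[..., 0::2].float()
    x_odd = x[..., 1::2].float()
    cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
    sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
    return torch.stack([cos, sin], dim=-1).flatten(-2)

def apply_rotary_emb_qk_real(xqk, freqs_cos, freqs_sin):
    # Copy of the original Mochi implementation shown above.
    xqk_even = xqk[..., 0::2]
    xqk_odd = xqk[..., 1::2]
    cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
    sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
    return torch.stack([cos_part, sin_part], dim=-1).flatten(-2)

batch, seq_len, heads, head_dim = 1, 4, 3, 8  # hypothetical toy sizes
x = torch.randn(batch, seq_len, heads, head_dim)
freqs = torch.randn(seq_len, heads, head_dim // 2)
freqs_cos, freqs_sin = torch.cos(freqs), torch.sin(freqs)

# Both paths should agree on float32 inputs (the diffusers helper only adds an upcast).
assert torch.allclose(
    apply_rotary_emb(x, freqs_cos, freqs_sin),
    apply_rotary_emb_qk_real(x, freqs_cos, freqs_sin),
    atol=1e-6,
)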