From 46f95d5cdbb17e15f454d7a959c8244b30ddcb7e Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 24 Oct 2024 13:49:12 +0200
Subject: [PATCH] make style

---
 src/diffusers/models/attention_processor.py  | 18 ++++++----
 .../models/transformers/transformer_mochi.py | 36 +++++++++++++++----
 .../transformer_mochi_original.py            | 29 ++++++++-------
 3 files changed, 54 insertions(+), 29 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index dd61a6ab2c..c17556463c 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -3096,9 +3096,6 @@ class MochiAttnProcessor2_0:
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        breakpoint()
-        batch_size = hidden_states.size(0)
-
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
@@ -3124,8 +3121,9 @@ class MochiAttnProcessor2_0:
             encoder_query = attn.norm_added_q(encoder_query)
         if attn.norm_added_k is not None:
             encoder_key = attn.norm_added_k(encoder_key)
-
+
         if image_rotary_emb is not None:
+
             def apply_rotary_emb(x, freqs_cos, freqs_sin):
                 x_even = x[..., 0::2].float()
                 x_odd = x[..., 1::2].float()
@@ -3137,9 +3135,13 @@ class MochiAttnProcessor2_0:
 
             query = apply_rotary_emb(query, *image_rotary_emb)
             key = apply_rotary_emb(key, *image_rotary_emb)
-
+
         query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
-        encoder_query, encoder_key, encoder_value = encoder_query.transpose(1, 2), encoder_key.transpose(1, 2), encoder_value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = (
+            encoder_query.transpose(1, 2),
+            encoder_key.transpose(1, 2),
+            encoder_value.transpose(1, 2),
+        )
 
         sequence_length = query.size(2)
         encoder_sequence_length = encoder_query.size(2)
@@ -3152,7 +3154,9 @@ class MochiAttnProcessor2_0:
 
         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
-        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length), dim=1)
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            (sequence_length, encoder_sequence_length), dim=1
+        )
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py
index 7938d5e39f..3b6c0decbe 100644
--- a/src/diffusers/models/transformers/transformer_mochi.py
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -145,16 +145,23 @@ class MochiTransformerBlock(nn.Module):
 class MochiRoPE(nn.Module):
     def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None:
         super().__init__()
-
+
         self.target_area = base_height * base_width
-
+
     def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
         edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
         return (edges[:-1] + edges[1:]) / 2
-
-    def _get_positions(self, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+
+    def _get_positions(
+        self,
+        num_frames: int,
+        height: int,
+        width: int,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> torch.Tensor:
         scale = (self.target_area / (height * width)) ** 0.5
-
+
         t = torch.arange(num_frames, device=device, dtype=dtype)
         h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
         w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
@@ -170,7 +177,15 @@ class MochiRoPE(nn.Module):
         freqs_sin = torch.sin(freqs)
         return freqs_cos, freqs_sin
 
-    def forward(self, pos_frequencies: torch.Tensor, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self,
+        pos_frequencies: torch.Tensor,
+        num_frames: int,
+        height: int,
+        width: int,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         pos = self._get_positions(num_frames, height, width, device, dtype)
         rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
         return rope_cos, rope_sin
@@ -261,7 +276,14 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
         hidden_states = self.patch_embed(hidden_states)
         hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
 
-        image_rotary_emb = self.rope(self.pos_frequencies, num_frames, post_patch_height, post_patch_width, device=hidden_states.device, dtype=torch.float32)
+        image_rotary_emb = self.rope(
+            self.pos_frequencies,
+            num_frames,
+            post_patch_height,
+            post_patch_width,
+            device=hidden_states.device,
+            dtype=torch.float32,
+        )
 
         for i, block in enumerate(self.transformer_blocks):
             hidden_states, encoder_hidden_states = block(
diff --git a/src/diffusers/models/transformers/transformer_mochi_original.py b/src/diffusers/models/transformers/transformer_mochi_original.py
index 2dad5c5c86..a428e57a3b 100644
--- a/src/diffusers/models/transformers/transformer_mochi_original.py
+++ b/src/diffusers/models/transformers/transformer_mochi_original.py
@@ -96,8 +96,7 @@ def compute_mixed_rotation(
         num_heads: int
 
     Returns:
-        freqs_cos: [N, num_heads, num_freqs] - cosine components
-        freqs_sin: [N, num_heads, num_freqs] - sine components
+        freqs_cos: [N, num_heads, num_freqs] - cosine components freqs_sin: [N, num_heads, num_freqs] - sine components
     """
     with torch.autocast("cuda", enabled=False):
         assert freqs.ndim == 3
@@ -470,8 +469,7 @@ class AsymmetricJointBlock(nn.Module):
             num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
 
        Returns:
-            x: (B, N, dim) tensor of visual tokens after block
-            y: (B, L, dim) tensor of text tokens after block
+            x: (B, N, dim) tensor of visual tokens after block y: (B, L, dim) tensor of text tokens after block
        """
        breakpoint()
        N = x.size(1)
@@ -651,7 +649,7 @@ class AsymmetricAttention(nn.Module):
         breakpoint()
         N = M
         local_heads = self.num_heads
-        local_dim = local_heads * self.head_dim
+        # local_dim = local_heads * self.head_dim
         # with torch.autocast("cuda", enabled=False):
         #     out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
         #         qkv,
@@ -696,8 +694,8 @@ class AsymmetricAttention(nn.Module):
             num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
 
         Returns:
-            x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
-            y: (B, L, dim_y) tensor of text token features after multi-modal attention
+            x: (B, N, dim_x) tensor of visual tokens after multi-modal attention y: (B, L, dim_y) tensor of text token
+            features after multi-modal attention
         """
         B, L, _ = y.shape
         _, M, _ = x.shape
@@ -725,6 +723,7 @@ class AsymmetricAttention(nn.Module):
         )
         return x, y
 
+
 def apply_rotary_emb_qk_real(
     xqk: torch.Tensor,
     freqs_cos: torch.Tensor,
@@ -756,10 +755,10 @@ def apply_rotary_emb_qk_real(
     # assert out.dtype == torch.bfloat16
     return out
 
+
 class PadSplitXY(torch.autograd.Function):
     """
-    Merge heads, pad and extract visual and text tokens,
-    and split along the sequence length.
+    Merge heads, pad and extract visual and text tokens, and split along the sequence length.
     """
 
     @staticmethod
@@ -778,8 +777,7 @@ class PadSplitXY(torch.autograd.Function):
            indices: Valid token indices out of unpacked tensor. Shape: (total,)
 
        Returns:
-            x: Visual tokens. Shape: (B, N, num_heads * head_dim).
-            y: Text tokens. Shape: (B, L, num_heads * head_dim).
+            x: Visual tokens. Shape: (B, N, num_heads * head_dim). y: Text tokens. Shape: (B, L, num_heads * head_dim).
        """
        ctx.save_for_backward(indices)
        ctx.B, ctx.N, ctx.L = B, N, L
@@ -788,9 +786,7 @@ class PadSplitXY(torch.autograd.Function):
         # Pad sequences to (B, N + L, dim).
         assert indices.ndim == 1
         output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
-        indices = indices.unsqueeze(1).expand(
-            -1, D
-        )  # (total,) -> (total, num_heads * head_dim)
+        indices = indices.unsqueeze(1).expand(-1, D)  # (total,) -> (total, num_heads * head_dim)
         output.scatter_(0, indices, xy)
         xy = output.view(B, N + L, D)
 
@@ -801,6 +797,7 @@ def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
     return PadSplitXY.apply(xy, indices, B, N, L, dtype)
 
+
 class UnifyStreams(torch.autograd.Function):
     """Unify visual and text streams."""
 
     @staticmethod
@@ -1034,7 +1031,9 @@ class MochiTransformer3DModel(nn.Module):
         Args:
             x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
             sigma: (B,) tensor of noise standard deviations
-            y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
+            y_feat:
+                List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77,
+                y_feat_dim=2048)
             y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
             packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
         """
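
Note (not part of the patch): below is a minimal, self-contained sketch of the real-valued rotary embedding that the apply_rotary_emb closure in the attention_processor.py hunks applies to query and key. The cos/sin rotation lines fall between hunks and are not visible above, so they are written here following the standard interleaved RoPE formulation; the tensor shapes and frequency values in the demo are illustrative assumptions, not values taken from the Mochi model.

# Standalone sketch of interleaved (even/odd channel) real-valued RoPE.
import torch


def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    # x: (..., dim) with an even last dimension.
    # freqs_cos / freqs_sin: broadcastable to (..., dim // 2).
    x_even = x[..., 0::2].float()
    x_odd = x[..., 1::2].float()

    # Standard RoPE rotation of each (even, odd) channel pair; assumed here,
    # since these lines are not shown in the patch hunks.
    rotated_even = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
    rotated_odd = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)

    # Re-interleave the rotated pairs back into the last dimension.
    return torch.stack([rotated_even, rotated_odd], dim=-1).flatten(-2)


if __name__ == "__main__":
    # Illustrative shapes only: (batch, seq_len, heads, head_dim).
    batch, seq_len, heads, head_dim = 1, 4, 2, 8
    query = torch.randn(batch, seq_len, heads, head_dim)

    # Hypothetical per-position angles, broadcast over heads.
    freqs = torch.randn(seq_len, 1, head_dim // 2)
    freqs_cos, freqs_sin = torch.cos(freqs), torch.sin(freqs)

    print(apply_rotary_emb(query, freqs_cos, freqs_sin).shape)  # torch.Size([1, 4, 2, 8])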