
make style

This commit is contained in:
Aryan
2024-10-24 13:49:12 +02:00
parent 2fd2ec4025
commit 46f95d5cdb
3 changed files with 54 additions and 29 deletions


@@ -3096,9 +3096,6 @@ class MochiAttnProcessor2_0:
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
- breakpoint()
batch_size = hidden_states.size(0)
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
@@ -3124,8 +3121,9 @@ class MochiAttnProcessor2_0:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
if image_rotary_emb is not None:
def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float()
x_odd = x[..., 1::2].float()
@@ -3137,9 +3135,13 @@ class MochiAttnProcessor2_0:
query = apply_rotary_emb(query, *image_rotary_emb)
key = apply_rotary_emb(key, *image_rotary_emb)
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
- encoder_query, encoder_key, encoder_value = encoder_query.transpose(1, 2), encoder_key.transpose(1, 2), encoder_value.transpose(1, 2)
+ encoder_query, encoder_key, encoder_value = (
+     encoder_query.transpose(1, 2),
+     encoder_key.transpose(1, 2),
+     encoder_value.transpose(1, 2),
+ )
sequence_length = query.size(2)
encoder_sequence_length = encoder_query.size(2)
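
Note: the apply_rotary_emb helper used in this hunk applies real-valued (interleaved) RoPE; only the x_even/x_odd split is visible above, so the stack/flatten tail below is an assumption following the standard real-valued formulation. A minimal sketch:

```python
import torch

def apply_rotary_emb(x, freqs_cos, freqs_sin):
    x_even = x[..., 0::2].float()  # channels 0, 2, 4, ...
    x_odd = x[..., 1::2].float()   # channels 1, 3, 5, ...
    cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
    sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
    # Re-interleave the rotated pairs back into the original channel order
    # (this tail is assumed; the hunk cuts the body off).
    return torch.stack([cos, sin], dim=-1).flatten(-2)

query = torch.randn(1, 16, 8, 64)  # (batch, seq, heads, head_dim), hypothetical
freqs = torch.randn(16, 1, 32)     # (seq, 1, head_dim // 2), broadcast over heads
rotated = apply_rotary_emb(query, torch.cos(freqs), torch.sin(freqs))
print(rotated.shape)               # torch.Size([1, 16, 8, 64])
```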
@@ -3152,7 +3154,9 @@ class MochiAttnProcessor2_0:
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
- hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length), dim=1)
+ hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+     (sequence_length, encoder_sequence_length), dim=1
+ )
# linear proj
hidden_states = attn.to_out[0](hidden_states)
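
For context, the reflowed split above is the tail of a joint-attention pattern: visual and text tokens are attended over as one packed sequence, then split back apart. A compact sketch; the attention call itself falls between the hunks, so SDPA is an assumption inferred from the "2_0" processor naming:

```python
import torch
import torch.nn.functional as F

B, H, S, L, D = 1, 4, 16, 7, 32          # hypothetical sizes
query = torch.randn(B, H, S + L, D)      # packed visual + text tokens
key, value = torch.randn_like(query), torch.randn_like(query)

out = F.scaled_dot_product_attention(query, key, value)
out = out.transpose(1, 2).flatten(2, 3)  # (B, S + L, H * D)
hidden_states, encoder_hidden_states = out.split_with_sizes((S, L), dim=1)
print(hidden_states.shape, encoder_hidden_states.shape)  # (1, 16, 128) (1, 7, 128)
```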


@@ -145,16 +145,23 @@ class MochiTransformerBlock(nn.Module):
class MochiRoPE(nn.Module):
def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None:
super().__init__()
self.target_area = base_height * base_width
def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
return (edges[:-1] + edges[1:]) / 2
- def _get_positions(self, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+ def _get_positions(
+     self,
+     num_frames: int,
+     height: int,
+     width: int,
+     device: Optional[torch.device] = None,
+     dtype: Optional[torch.dtype] = None,
+ ) -> torch.Tensor:
scale = (self.target_area / (height * width)) ** 0.5
t = torch.arange(num_frames, device=device, dtype=dtype)
h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
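
A small sketch of the position-grid construction visible in this hunk: bin centers come from linspace edges, and the grid is rescaled so that height * width always covers target_area regardless of resolution. The sizes below are hypothetical:

```python
import torch

def centers(start, stop, num):
    edges = torch.linspace(start, stop, num + 1)
    return (edges[:-1] + edges[1:]) / 2  # midpoint of each bin

target_area = 192 * 192
height, width = 30, 45
scale = (target_area / (height * width)) ** 0.5
h = centers(-height * scale / 2, height * scale / 2, height)  # shape (30,)
w = centers(-width * scale / 2, width * scale / 2, width)     # shape (45,)
```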
@@ -170,7 +177,15 @@ class MochiRoPE(nn.Module):
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin
- def forward(self, pos_frequencies: torch.Tensor, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ def forward(
+     self,
+     pos_frequencies: torch.Tensor,
+     num_frames: int,
+     height: int,
+     width: int,
+     device: Optional[torch.device] = None,
+     dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
pos = self._get_positions(num_frames, height, width, device, dtype)
rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
return rope_cos, rope_sin
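
The meshgrid/stack step of _get_positions falls outside these hunks; a plausible sketch of assembling the (num_frames * height * width, 3) grid from the t/h/w axes computed above (an assumption, not the verbatim code):

```python
import torch

t = torch.arange(2, dtype=torch.float32)
h = torch.linspace(-1.0, 1.0, 3)
w = torch.linspace(-1.0, 1.0, 4)
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1).reshape(-1, 3)  # (24, 3)
```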
@@ -261,7 +276,14 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
- image_rotary_emb = self.rope(self.pos_frequencies, num_frames, post_patch_height, post_patch_width, device=hidden_states.device, dtype=torch.float32)
+ image_rotary_emb = self.rope(
+     self.pos_frequencies,
+     num_frames,
+     post_patch_height,
+     post_patch_width,
+     device=hidden_states.device,
+     dtype=torch.float32,
+ )
for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block(


@@ -96,8 +96,7 @@ def compute_mixed_rotation(
num_heads: int
Returns:
- freqs_cos: [N, num_heads, num_freqs] - cosine components
- freqs_sin: [N, num_heads, num_freqs] - sine components
+ freqs_cos: [N, num_heads, num_freqs] - cosine components freqs_sin: [N, num_heads, num_freqs] - sine components
"""
with torch.autocast("cuda", enabled=False):
assert freqs.ndim == 3
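
A sketch matching the compute_mixed_rotation docstring above: learned frequencies mix the (t, h, w) position of each token into per-head angles, giving [N, num_heads, num_freqs] cosine/sine tables. The shapes below are hypothetical:

```python
import torch

N, num_heads, num_freqs = 120, 24, 32
pos = torch.randn(N, 3)                        # (t, h, w) per token
mixing = torch.randn(3, num_heads, num_freqs)  # learned, like pos_frequencies
freqs = torch.einsum("nd,dhf->nhf", pos, mixing)
freqs_cos, freqs_sin = torch.cos(freqs), torch.sin(freqs)  # [N, heads, freqs]
```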
@@ -470,8 +469,7 @@ class AsymmetricJointBlock(nn.Module):
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
Returns:
- x: (B, N, dim) tensor of visual tokens after block
- y: (B, L, dim) tensor of text tokens after block
+ x: (B, N, dim) tensor of visual tokens after block y: (B, L, dim) tensor of text tokens after block
"""
breakpoint()
N = x.size(1)
@@ -651,7 +649,7 @@ class AsymmetricAttention(nn.Module):
breakpoint()
N = M
local_heads = self.num_heads
- local_dim = local_heads * self.head_dim
+ # local_dim = local_heads * self.head_dim
# with torch.autocast("cuda", enabled=False):
# out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
# qkv,
@@ -696,8 +694,8 @@ class AsymmetricAttention(nn.Module):
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
Returns:
- x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
- y: (B, L, dim_y) tensor of text token features after multi-modal attention
+ x: (B, N, dim_x) tensor of visual tokens after multi-modal attention y: (B, L, dim_y) tensor of text token
+ features after multi-modal attention
"""
B, L, _ = y.shape
_, M, _ = x.shape
@@ -725,6 +723,7 @@ class AsymmetricAttention(nn.Module):
)
return x, y
def apply_rotary_emb_qk_real(
xqk: torch.Tensor,
freqs_cos: torch.Tensor,
@@ -756,10 +755,10 @@ def apply_rotary_emb_qk_real(
# assert out.dtype == torch.bfloat16
return out
class PadSplitXY(torch.autograd.Function):
"""
- Merge heads, pad and extract visual and text tokens,
- and split along the sequence length.
+ Merge heads, pad and extract visual and text tokens, and split along the sequence length.
"""
@staticmethod
@@ -778,8 +777,7 @@ class PadSplitXY(torch.autograd.Function):
indices: Valid token indices out of unpacked tensor. Shape: (total,)
Returns:
- x: Visual tokens. Shape: (B, N, num_heads * head_dim).
- y: Text tokens. Shape: (B, L, num_heads * head_dim).
+ x: Visual tokens. Shape: (B, N, num_heads * head_dim). y: Text tokens. Shape: (B, L, num_heads * head_dim).
"""
ctx.save_for_backward(indices)
ctx.B, ctx.N, ctx.L = B, N, L
@@ -788,9 +786,7 @@ class PadSplitXY(torch.autograd.Function):
# Pad sequences to (B, N + L, dim).
assert indices.ndim == 1
output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
- indices = indices.unsqueeze(1).expand(
-     -1, D
- ) # (total,) -> (total, num_heads * head_dim)
+ indices = indices.unsqueeze(1).expand(-1, D) # (total,) -> (total, num_heads * head_dim)
output.scatter_(0, indices, xy)
xy = output.view(B, N + L, D)
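
A self-contained sketch of the scatter-based pad-and-split in this hunk: packed valid tokens of shape (total, D) are scattered into a zero-padded (B * (N + L), D) buffer via flat indices, then viewed as (B, N + L, D) and split into visual and text streams:

```python
import torch

B, N, L, D = 2, 3, 2, 4
mask = torch.tensor([1, 1, 1, 1, 0, 1, 1, 0, 1, 0], dtype=torch.bool)  # (B*(N+L),)
indices = mask.nonzero(as_tuple=True)[0]     # (total,) flat valid positions
xy = torch.randn(indices.numel(), D)         # packed valid tokens

output = torch.zeros(B * (N + L), D)
output.scatter_(0, indices.unsqueeze(1).expand(-1, D), xy)
xy_padded = output.view(B, N + L, D)
x, y = torch.split(xy_padded, [N, L], dim=1)  # visual (B, N, D), text (B, L, D)
```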
@@ -801,6 +797,7 @@ class PadSplitXY(torch.autograd.Function):
def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
return PadSplitXY.apply(xy, indices, B, N, L, dtype)
class UnifyStreams(torch.autograd.Function):
"""Unify visual and text streams."""
@@ -1034,7 +1031,9 @@ class MochiTransformer3DModel(nn.Module):
Args:
x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
sigma: (B,) tensor of noise standard deviations
- y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
+ y_feat:
+     List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77,
+     y_feat_dim=2048)
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
"""