Mirror of https://github.com/huggingface/diffusers.git
make style
@@ -3096,9 +3096,6 @@ class MochiAttnProcessor2_0:
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
- breakpoint()
batch_size = hidden_states.size(0)

query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
@@ -3124,8 +3121,9 @@ class MochiAttnProcessor2_0:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)

if image_rotary_emb is not None:

def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float()
x_odd = x[..., 1::2].float()
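The nested apply_rotary_emb helper is cut off by the hunk boundary. For context, a minimal standalone sketch of the real-valued rotary embedding it implements; the rotation-and-reinterleave lines after x_odd are an assumption based on the even/odd split shown above, not lines from the diff:

    import torch

    def apply_rotary_emb(x, freqs_cos, freqs_sin):
        # Treat interleaved even/odd channels as the real/imaginary parts of a
        # complex pair and rotate each pair by the per-position angle.
        x_even = x[..., 0::2].float()
        x_odd = x[..., 1::2].float()
        cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
        sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
        # Re-interleave so the output layout matches the input.
        return torch.stack([cos, sin], dim=-1).flatten(-2)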
@@ -3137,9 +3135,13 @@ class MochiAttnProcessor2_0:

query = apply_rotary_emb(query, *image_rotary_emb)
key = apply_rotary_emb(key, *image_rotary_emb)

query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
- encoder_query, encoder_key, encoder_value = encoder_query.transpose(1, 2), encoder_key.transpose(1, 2), encoder_value.transpose(1, 2)
+ encoder_query, encoder_key, encoder_value = (
+     encoder_query.transpose(1, 2),
+     encoder_key.transpose(1, 2),
+     encoder_value.transpose(1, 2),
+ )

sequence_length = query.size(2)
encoder_sequence_length = encoder_query.size(2)
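Around this hunk (not shown in the diff), the visual and text streams are concatenated along the sequence axis and attended over jointly. A toy sketch of that pattern, with all shapes chosen arbitrarily for illustration:

    import torch
    import torch.nn.functional as F

    B, heads, head_dim = 1, 2, 8
    seq_len, enc_seq_len = 6, 3  # visual vs. text sequence lengths (toy values)
    query, key, value = (torch.randn(B, heads, seq_len, head_dim) for _ in range(3))
    encoder_query, encoder_key, encoder_value = (
        torch.randn(B, heads, enc_seq_len, head_dim) for _ in range(3)
    )

    # Joint attention: concatenate both streams along the sequence dimension.
    q = torch.cat([query, encoder_query], dim=2)
    k = torch.cat([key, encoder_key], dim=2)
    v = torch.cat([value, encoder_value], dim=2)
    out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([1, 2, 9, 8])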
@@ -3152,7 +3154,9 @@ class MochiAttnProcessor2_0:
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)

- hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length), dim=1)
+ hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+     (sequence_length, encoder_sequence_length), dim=1
+ )

# linear proj
hidden_states = attn.to_out[0](hidden_states)
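split_with_sizes simply undoes that concatenation after attention. A tiny, self-contained example with made-up sizes:

    import torch

    sequence_length, encoder_sequence_length = 6, 2
    hidden_states = torch.randn(1, sequence_length + encoder_sequence_length, 4)
    visual, text = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length), dim=1)
    print(visual.shape, text.shape)  # torch.Size([1, 6, 4]) torch.Size([1, 2, 4])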
@@ -145,16 +145,23 @@ class MochiTransformerBlock(nn.Module):
class MochiRoPE(nn.Module):
def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None:
super().__init__()

self.target_area = base_height * base_width

def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
return (edges[:-1] + edges[1:]) / 2

- def _get_positions(self, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+ def _get_positions(
+     self,
+     num_frames: int,
+     height: int,
+     width: int,
+     device: Optional[torch.device] = None,
+     dtype: Optional[torch.dtype] = None,
+ ) -> torch.Tensor:
scale = (self.target_area / (height * width)) ** 0.5

t = torch.arange(num_frames, device=device, dtype=dtype)
h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
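The _centers helper returns the midpoints of num equal-width bins between start and stop, which MochiRoPE uses as (scaled) spatial coordinates. A quick standalone check:

    import torch

    def centers(start, stop, num):
        # num + 1 edges -> num midpoints, i.e. the centers of equal-width bins.
        edges = torch.linspace(start, stop, num + 1)
        return (edges[:-1] + edges[1:]) / 2

    print(centers(-2.0, 2.0, 4))  # tensor([-1.5000, -0.5000,  0.5000,  1.5000])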
@@ -170,7 +177,15 @@ class MochiRoPE(nn.Module):
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin

- def forward(self, pos_frequencies: torch.Tensor, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ def forward(
+     self,
+     pos_frequencies: torch.Tensor,
+     num_frames: int,
+     height: int,
+     width: int,
+     device: Optional[torch.device] = None,
+     dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
pos = self._get_positions(num_frames, height, width, device, dtype)
rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
return rope_cos, rope_sin
@@ -261,7 +276,14 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

- image_rotary_emb = self.rope(self.pos_frequencies, num_frames, post_patch_height, post_patch_width, device=hidden_states.device, dtype=torch.float32)
+ image_rotary_emb = self.rope(
+     self.pos_frequencies,
+     num_frames,
+     post_patch_height,
+     post_patch_width,
+     device=hidden_states.device,
+     dtype=torch.float32,
+ )

for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block(
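The unflatten/flatten pair above regroups per-frame patch tokens into one sequence per sample. A toy shape check (sizes are arbitrary):

    import torch

    batch_size, num_frames, tokens_per_frame, dim = 2, 3, 4, 8
    # patch_embed runs per frame, so the leading dimension is batch * frames.
    hidden_states = torch.randn(batch_size * num_frames, tokens_per_frame, dim)
    hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
    print(hidden_states.shape)  # torch.Size([2, 12, 8]) -> (batch, frames * tokens, dim)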
@@ -96,8 +96,7 @@ def compute_mixed_rotation(
num_heads: int

Returns:
- freqs_cos: [N, num_heads, num_freqs] - cosine components
- freqs_sin: [N, num_heads, num_freqs] - sine components
+ freqs_cos: [N, num_heads, num_freqs] - cosine components freqs_sin: [N, num_heads, num_freqs] - sine components
"""
with torch.autocast("cuda", enabled=False):
assert freqs.ndim == 3
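Only the docstring and the first two statements of compute_mixed_rotation appear in this hunk. A hedged sketch of how such learned ("mixed") rotations are commonly formed from positions and a per-head frequency tensor; the einsum and argument names below are assumptions, not necessarily the function's actual body:

    import torch

    def compute_mixed_rotation(freqs: torch.Tensor, pos: torch.Tensor):
        # freqs: (pos_dim, num_heads, num_freqs) learnable frequencies (assumed layout),
        # pos:   (N, pos_dim) token positions.
        assert freqs.ndim == 3
        angles = torch.einsum("nd,dhf->nhf", pos.to(freqs.dtype), freqs)
        return torch.cos(angles), torch.sin(angles)  # each (N, num_heads, num_freqs)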
@@ -470,8 +469,7 @@ class AsymmetricJointBlock(nn.Module):
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens

Returns:
- x: (B, N, dim) tensor of visual tokens after block
- y: (B, L, dim) tensor of text tokens after block
+ x: (B, N, dim) tensor of visual tokens after block y: (B, L, dim) tensor of text tokens after block
"""
breakpoint()
N = x.size(1)
@@ -651,7 +649,7 @@ class AsymmetricAttention(nn.Module):
breakpoint()
N = M
local_heads = self.num_heads
- local_dim = local_heads * self.head_dim
+ # local_dim = local_heads * self.head_dim
# with torch.autocast("cuda", enabled=False):
# out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
# qkv,
@@ -696,8 +694,8 @@ class AsymmetricAttention(nn.Module):
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens

Returns:
- x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
- y: (B, L, dim_y) tensor of text token features after multi-modal attention
+ x: (B, N, dim_x) tensor of visual tokens after multi-modal attention y: (B, L, dim_y) tensor of text token
+ features after multi-modal attention
"""
B, L, _ = y.shape
_, M, _ = x.shape
@@ -725,6 +723,7 @@ class AsymmetricAttention(nn.Module):
)
return x, y


def apply_rotary_emb_qk_real(
xqk: torch.Tensor,
freqs_cos: torch.Tensor,
@@ -756,10 +755,10 @@ def apply_rotary_emb_qk_real(
# assert out.dtype == torch.bfloat16
return out


class PadSplitXY(torch.autograd.Function):
"""
- Merge heads, pad and extract visual and text tokens,
- and split along the sequence length.
+ Merge heads, pad and extract visual and text tokens, and split along the sequence length.
"""

@staticmethod
@@ -778,8 +777,7 @@ class PadSplitXY(torch.autograd.Function):
indices: Valid token indices out of unpacked tensor. Shape: (total,)

Returns:
- x: Visual tokens. Shape: (B, N, num_heads * head_dim).
- y: Text tokens. Shape: (B, L, num_heads * head_dim).
+ x: Visual tokens. Shape: (B, N, num_heads * head_dim). y: Text tokens. Shape: (B, L, num_heads * head_dim).
"""
ctx.save_for_backward(indices)
ctx.B, ctx.N, ctx.L = B, N, L
@@ -788,9 +786,7 @@ class PadSplitXY(torch.autograd.Function):
# Pad sequences to (B, N + L, dim).
assert indices.ndim == 1
output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
- indices = indices.unsqueeze(1).expand(
-     -1, D
- ) # (total,) -> (total, num_heads * head_dim)
+ indices = indices.unsqueeze(1).expand(-1, D) # (total,) -> (total, num_heads * head_dim)
output.scatter_(0, indices, xy)
xy = output.view(B, N + L, D)
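PadSplitXY re-pads a packed tensor of valid tokens back to a dense (B, N + L, D) layout and then splits visual from text tokens. A toy end-to-end illustration with invented sizes and indices:

    import torch

    B, N, L, D = 2, 3, 2, 4
    xy = torch.arange(5 * D, dtype=torch.float32).view(5, D)  # 5 packed valid tokens
    indices = torch.tensor([0, 1, 2, 5, 6])                   # their flat positions in B * (N + L)

    output = torch.zeros(B * (N + L), D)
    output.scatter_(0, indices.unsqueeze(1).expand(-1, D), xy)  # scatter tokens into padded buffer
    xy_padded = output.view(B, N + L, D)
    x, y = xy_padded.split_with_sizes((N, L), dim=1)            # visual vs. text tokens
    print(x.shape, y.shape)  # torch.Size([2, 3, 4]) torch.Size([2, 2, 4])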
@@ -801,6 +797,7 @@ class PadSplitXY(torch.autograd.Function):
def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
return PadSplitXY.apply(xy, indices, B, N, L, dtype)


class UnifyStreams(torch.autograd.Function):
"""Unify visual and text streams."""
@@ -1034,7 +1031,9 @@ class MochiTransformer3DModel(nn.Module):
Args:
x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
sigma: (B,) tensor of noise standard deviations
- y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
+ y_feat:
+     List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77,
+     y_feat_dim=2048)
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
"""