From 0a6189eb952b800ddb1c24da245d75490ac95478 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Fri, 25 Oct 2024 01:02:46 +0200
Subject: [PATCH] add

---
 .../autoencoders/autoencoder_kl_mochi.py      | 350 ++++++++++++++++--
 1 file changed, 318 insertions(+), 32 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
index 9afe9ac069..4f427f3ce1 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
@@ -40,7 +40,29 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-class MochiCausalConv3d(nn.Module):
+# YiYi to-do: replace this with nn.Conv3d
+class Conv1x1(nn.Linear):
+    """A 1x1 convolution implemented with a linear layer."""
+
+    def __init__(self, in_features: int, out_features: int, *args, **kwargs):
+        super().__init__(in_features, out_features, *args, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            x: Input tensor. Shape: [B, C, *] or [B, *, C].
+
+        Returns:
+            x: Output tensor. Shape: [B, C', *] or [B, *, C'].
+        """
+        x = x.movedim(1, -1)
+        x = super().forward(x)
+        x = x.movedim(-1, 1)
+        return x
+
+
+class MochiChunkedCausalConv3d(nn.Module):
-    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
+    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in the Mochi model.
 
     Args:
@@ -81,50 +103,42 @@ class MochiCausalConv3d(nn.Module):
             padding=(0, height_pad, width_pad),
             padding_mode=padding_mode,
         )
-        self.time_kernel_size = time_kernel_size
 
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        context_size = self.time_kernel_size - 1
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        time_kernel_size = self.conv.kernel_size[0]
+        context_size = time_kernel_size - 1
-        time_casual_padding = (0, 0, 0, 0, context_size, 0)
-
-        inputs = F.pad(inputs, time_casual_padding, mode=self.padding_mode)
+        time_causal_padding = (0, 0, 0, 0, context_size, 0)
+        hidden_states = F.pad(hidden_states, time_causal_padding, mode=self.padding_mode)
 
         # Memory-efficient chunked operation
-        memory_count = torch.prod(torch.tensor(inputs.shape)).item() * 2 / 1024**3
+        memory_count = torch.prod(torch.tensor(hidden_states.shape)).item() * 2 / 1024**3
         if memory_count > 2:
             part_num = int(memory_count / 2) + 1
-            k = self.time_kernel_size
-            input_idx = torch.arange(context_size, inputs.size(2))
-            input_chunks_idx = torch.split(input_idx, input_idx.size(0) // part_num)
+            num_frames = hidden_states.shape[2]
+            frames_idx = torch.arange(context_size, num_frames)
+            frames_chunks_idx = torch.chunk(frames_idx, part_num, dim=0)
 
-            # Compute output size
-            B, _, T_in, H_in, W_in = inputs.shape
-            output_size = (
-                B,
-                self.conv.out_channels,
-                T_in - k + 1,
-                H_in // self.conv.stride[1],
-                W_in // self.conv.stride[2],
-            )
-            output = torch.empty(output_size, dtype=inputs.dtype, device=inputs.device)
-            for input_chunk_idx in input_chunks_idx:
-                input_s = input_chunk_idx[0] - k + 1
-                input_e = input_chunk_idx[-1] + 1
-                input_chunk = inputs[:, :, input_s:input_e, :, :]
-                output_chunk = self.conv(input_chunk)
+            output_chunks = []
+            for frames_chunk_idx in frames_chunks_idx:
+                frames_s = frames_chunk_idx[0] - context_size
+                frames_e = frames_chunk_idx[-1] + 1
+                frames_chunk = hidden_states[:, :, frames_s:frames_e, :, :]
+                output_chunk = self.conv(frames_chunk)
+                output_chunks.append(output_chunk)  # Append each output chunk to the list
 
-                output_s = input_s
-                output_e = output_s + output_chunk.size(2)
-                output[:, :, output_s:output_e, :, :] = output_chunk
+            # Concatenate all output chunks along the temporal dimension
+            hidden_states = torch.cat(output_chunks, dim=2)
 
-            return output
+            return hidden_states
         else:
-            return self.conv(inputs)
+            return self.conv(hidden_states)
 
 
-class MochiGroupNorm3D(nn.Module):
+class MochiChunkedGroupNorm3D(nn.Module):
     r"""
     Group normalization applied per-frame.
 
@@ -134,10 +148,13 @@ class MochiGroupNorm3D(nn.Module):
 
     def __init__(
         self,
+        num_channels: int,
+        num_groups: int = 32,
+        affine: bool = True,
         chunk_size: int = 8,
     ):
         super().__init__()
-        self.norm_layer = nn.GroupNorm()
+        self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine)
         self.chunk_size = chunk_size
 
     def forward(
@@ -158,3 +175,272 @@ class MochiGroupNorm3D(nn.Module):
 
         return output
 
+
+class MochiResnetBlock3D(nn.Module):
+    r"""
+    A 3D ResNet block used in the Mochi model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`, *optional*):
+            Number of output channels. If None, defaults to `in_channels`.
+        non_linearity (`str`, defaults to `"swish"`):
+            Activation function to use.
+ """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + non_linearity: str = "swish", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.in_channels = in_channels + self.out_channels = out_channels + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels) + self.conv1 = MochiChunkedCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1 + ) + self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels) + self.conv2 = MochiChunkedCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1 + ) + + + def forward( + self, + inputs: torch.Tensor, + ) -> torch.Tensor: + + hidden_states = inputs + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv2(hidden_states) + + hidden_states = hidden_states + inputs + return hidden_states + + +class MochiUpBlock3D(nn.Module): + r""" + An upsampling block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + temporal_expansion: int = 2, + spatial_expansion: int = 2, + ): + super().__init__() + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + resnets = [] + for i in range(num_layers): + resnets.append( + MochiResnetBlock3D( + in_channels=in_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.proj = Conv1x1( + in_channels, + out_channels * temporal_expansion * (spatial_expansion**2), + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + r"""Forward method of the `MochiUpBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + ) + else: + hidden_states = resnet(hidden_states) + + hidden_states = self.proj(hidden_states) + + # Calculate new shape + B, C, T, H, W = hidden_states.shape + st = self.temporal_expansion + sh = self.spatial_expansion + sw = self.spatial_expansion + new_C = C // (st * sh * sw) + + # Reshape and permute + hidden_states = hidden_states.view(B, new_C, st, sh, sw, T, H, W) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4) + hidden_states = hidden_states.contiguous().view(B, new_C, T * st, H * sh, W * sw) + + if self.temporal_expansion > 1: + print(f"x: {hidden_states.shape}") + # Drop the first self.temporal_expansion - 1 frames. + hidden_states = hidden_states[:, :, self.temporal_expansion - 1 :] + print(f"x: {hidden_states.shape}") + + return hidden_states + + +class MochiMidBlock3D(nn.Module): + r""" + A middle block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. 
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, # 768 + num_layers: int = 3, + ): + super().__init__() + + resnets = [] + for _ in range(num_layers): + resnets.append( + MochiResnetBlock3D(in_channels=in_channels) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + r"""Forward method of the `MochiMidBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states + ) + else: + hidden_states = resnet(hidden_states) + + return hidden_states + + +class MochiDecoder3D(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, # 12 + out_channels: int, # 3 + block_out_channels: Tuple[int, ...] = (128, 256, 512, 768), + layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), + temporal_expansions: Tuple[int, ...] = (1, 2, 3), + spatial_expansions: Tuple[int, ...] = (2, 2, 2), + non_linearity: str = "swish", + ): + super().__init__() + + self.nonlinearity = get_activation(non_linearity) + + self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1)) + self.block_in = MochiMidBlock3D( + in_channels=block_out_channels[-1], + num_layers=layers_per_block[-1], + ) + self.up_blocks = nn.ModuleList([]) + for i in range(len(block_out_channels) - 1): + up_block = MochiUpBlock3D( + in_channels=block_out_channels[-i - 1], + out_channels=block_out_channels[-i - 2], + num_layers=layers_per_block[-i - 2], + temporal_expansion=temporal_expansions[-i - 1], + spatial_expansion=spatial_expansions[-i - 1], + ) + self.up_blocks.append(up_block) + self.block_out = MochiMidBlock3D( + in_channels=block_out_channels[0], + num_layers=layers_per_block[0], + ) + self.conv_out = Conv1x1(block_out_channels[0], out_channels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""Forward method of the `MochiDecoder3D` class.""" + + print(f"hidden_states: {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") + hidden_states = self.conv_in(hidden_states) + print(f"hidden_states (after conv_in): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") + + + # 1. Mid + hidden_states = self.block_in(hidden_states) + print(f"hidden_states (after block_in): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") + # 2. Up + for i, up_block in enumerate(self.up_blocks): + hidden_states = up_block(hidden_states) + print(f"hidden_states (after up_block {i}): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") + # 3. Post-process + hidden_states = self.block_out(hidden_states) + print(f"hidden_states (after block_out): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv_out(hidden_states) + print(f"hidden_states (after conv_out): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") + + return hidden_states + \ No newline at end of file