From 85a98254492ef8fe01659e102cea8cf3eb37d0bf Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 24 Oct 2024 18:12:44 +0200 Subject: [PATCH 1/5] init --- .../autoencoders/autoencoder_kl_mochi.py | 160 ++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_mochi.py diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py new file mode 100644 index 0000000000..9afe9ac069 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -0,0 +1,160 @@ +# Copyright 2024 The Mochi team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..downsampling import CogVideoXDownsample3D +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..upsampling import CogVideoXUpsample3D +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MochiCausalConv3d(nn.Module): + r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. + + Args: + in_channels (`int`): Number of channels in the input tensor. + out_channels (`int`): Number of output channels produced by the convolution. + kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel. + stride (`int` or `Tuple[int, int, int]`, defaults to `1`): Stride of the convolution. + pad_mode (`str`, defaults to `"constant"`): Padding mode. 
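The causality of this layer comes entirely from front-padding the clip with `time_kernel_size - 1` frames (in the configured padding mode) before the convolution, so an output frame never sees future inputs. A minimal standalone sketch of that padding idea, using plain `torch.nn` layers and illustrative shapes rather than the class defined in this patch:

import torch
import torch.nn.functional as F
from torch import nn

# Front-pad (k_t - 1) frames so frame t only sees frames <= t.
# Shapes are (batch, channels, frames, height, width).
conv = nn.Conv3d(3, 8, kernel_size=(3, 3, 3), padding=(0, 1, 1))
video = torch.randn(1, 3, 5, 16, 16)

context = conv.kernel_size[0] - 1                       # 2 frames of left context
padded = F.pad(video, (0, 0, 0, 0, context, 0), mode="replicate")
out = conv(padded)                                      # (1, 8, 5, 16, 16): frame count preserved

# Causality check: perturbing the last input frame leaves earlier outputs untouched.
video2 = video.clone()
video2[:, :, -1] += 1.0
out2 = conv(F.pad(video2, (0, 0, 0, 0, context, 0), mode="replicate"))
assert torch.allclose(out[:, :, :-1], out2[:, :, :-1])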
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]], + padding_mode: str = "replicate", + ): + super().__init__() + + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + self.padding_mode = padding_mode + height_pad = (height_kernel_size - 1) // 2 + width_pad = (width_kernel_size - 1) // 2 + + self.conv = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=(1, 1, 1), + padding=(0, height_pad, width_pad), + padding_mode=padding_mode, + ) + self.time_kernel_size = time_kernel_size + + + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + context_size = self.time_kernel_size - 1 + time_casual_padding = (0, 0, 0, 0, context_size, 0) + + inputs = F.pad(inputs, time_casual_padding, mode=self.padding_mode) + + # Memory-efficient chunked operation + memory_count = torch.prod(torch.tensor(inputs.shape)).item() * 2 / 1024**3 + if memory_count > 2: + part_num = int(memory_count / 2) + 1 + k = self.time_kernel_size + input_idx = torch.arange(context_size, inputs.size(2)) + input_chunks_idx = torch.split(input_idx, input_idx.size(0) // part_num) + + # Compute output size + B, _, T_in, H_in, W_in = inputs.shape + output_size = ( + B, + self.conv.out_channels, + T_in - k + 1, + H_in // self.conv.stride[1], + W_in // self.conv.stride[2], + ) + output = torch.empty(output_size, dtype=inputs.dtype, device=inputs.device) + for input_chunk_idx in input_chunks_idx: + input_s = input_chunk_idx[0] - k + 1 + input_e = input_chunk_idx[-1] + 1 + input_chunk = inputs[:, :, input_s:input_e, :, :] + output_chunk = self.conv(input_chunk) + + output_s = input_s + output_e = output_s + output_chunk.size(2) + output[:, :, output_s:output_e, :, :] = output_chunk + + return output + else: + return self.conv(inputs) + + +class MochiGroupNorm3D(nn.Module): + r""" + Group normalization applied per-frame. 
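Concretely, the frame axis is folded into the batch axis so that normalization statistics are computed per frame, and the frames are processed in chunks to bound peak memory. A rough standalone sketch of that idea with illustrative sizes, independent of the class below:

import torch
from torch import nn

# Per-frame GroupNorm on a (B, C, T, H, W) video tensor, applied in chunks.
norm = nn.GroupNorm(num_groups=32, num_channels=64)
x = torch.randn(2, 64, 7, 32, 32)

b, c, t, h, w = x.shape
frames = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)    # each frame becomes a sample

chunk_size = 8
out = torch.cat([norm(chunk) for chunk in frames.split(chunk_size, dim=0)], dim=0)
out = out.view(b, t, c, h, w).permute(0, 2, 1, 3, 4)         # back to (B, C, T, H, W)

# Chunking is exact because GroupNorm statistics are per-sample.
reference = norm(frames).view(b, t, c, h, w).permute(0, 2, 1, 3, 4)
assert torch.allclose(out, reference, atol=1e-6)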
+ + Args: + + """ + + def __init__( + self, + chunk_size: int = 8, + ): + super().__init__() + self.norm_layer = nn.GroupNorm() + self.chunk_size = chunk_size + + def forward( + self, x: torch.Tensor = None + ) -> torch.Tensor: + + batch_size, channels, num_frames, height, width = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + num_chunks = (batch_size * num_frames + self.chunk_size - 1) // self.chunk_size + + output = torch.cat( + [self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)], + dim=0 + ) + output = output.view(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + + return output + + From 0a6189eb952b800ddb1c24da245d75490ac95478 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 25 Oct 2024 01:02:46 +0200 Subject: [PATCH 2/5] add --- .../autoencoders/autoencoder_kl_mochi.py | 350 ++++++++++++++++-- 1 file changed, 318 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py index 9afe9ac069..4f427f3ce1 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -40,7 +40,29 @@ import torch.nn as nn import torch.nn.functional as F -class MochiCausalConv3d(nn.Module): +# YiYi to-do: replace this with nn.Conv3d +class Conv1x1(nn.Linear): + """*1x1 Conv implemented with a linear layer.""" + + def __init__(self, in_features: int, out_features: int, *args, **kwargs): + super().__init__(in_features, out_features, *args, **kwargs) + + def forward(self, x: torch.Tensor): + """Forward pass. + + Args: + x: Input tensor. Shape: [B, C, *] or [B, *, C]. + + Returns: + x: Output tensor. Shape: [B, C', *] or [B, *, C']. + """ + x = x.movedim(1, -1) + x = super().forward(x) + x = x.movedim(-1, 1) + return x + + +class MochiChunkedCausalConv3d(nn.Module): r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. Args: @@ -81,50 +103,42 @@ class MochiCausalConv3d(nn.Module): padding=(0, height_pad, width_pad), padding_mode=padding_mode, ) - self.time_kernel_size = time_kernel_size - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - context_size = self.time_kernel_size - 1 + def forward(self, hidden_states : torch.Tensor) -> torch.Tensor: + time_kernel_size = self.conv.kernel_size[0] + context_size = time_kernel_size - 1 time_casual_padding = (0, 0, 0, 0, context_size, 0) - - inputs = F.pad(inputs, time_casual_padding, mode=self.padding_mode) + hidden_states = F.pad(hidden_states, time_casual_padding, mode=self.padding_mode) # Memory-efficient chunked operation - memory_count = torch.prod(torch.tensor(inputs.shape)).item() * 2 / 1024**3 + memory_count = torch.prod(torch.tensor(hidden_states.shape)).item() * 2 / 1024**3 + # YiYI Notes: testing only!! 
please remove + memory_count = 3 if memory_count > 2: part_num = int(memory_count / 2) + 1 - k = self.time_kernel_size - input_idx = torch.arange(context_size, inputs.size(2)) - input_chunks_idx = torch.split(input_idx, input_idx.size(0) // part_num) + num_frames = hidden_states.shape[2] + frames_idx = torch.arange(context_size, num_frames) + frames_chunks_idx = torch.chunk(frames_idx, part_num, dim=0) - # Compute output size - B, _, T_in, H_in, W_in = inputs.shape - output_size = ( - B, - self.conv.out_channels, - T_in - k + 1, - H_in // self.conv.stride[1], - W_in // self.conv.stride[2], - ) - output = torch.empty(output_size, dtype=inputs.dtype, device=inputs.device) - for input_chunk_idx in input_chunks_idx: - input_s = input_chunk_idx[0] - k + 1 - input_e = input_chunk_idx[-1] + 1 - input_chunk = inputs[:, :, input_s:input_e, :, :] - output_chunk = self.conv(input_chunk) + output_chunks = [] + for frames_chunk_idx in frames_chunks_idx: + frames_s = frames_chunk_idx[0] - context_size + frames_e = frames_chunk_idx[-1] + 1 + frames_chunk = hidden_states[:, :, frames_s:frames_e, :, :] + output_chunk = self.conv(frames_chunk) + output_chunks.append(output_chunk) # Append each output chunk to the list - output_s = input_s - output_e = output_s + output_chunk.size(2) - output[:, :, output_s:output_e, :, :] = output_chunk + # Concatenate all output chunks along the temporal dimension + hidden_states = torch.cat(output_chunks, dim=2) - return output + return hidden_states else: - return self.conv(inputs) + return self.conv(hidden_states) -class MochiGroupNorm3D(nn.Module): +class MochiChunkedGroupNorm3D(nn.Module): r""" Group normalization applied per-frame. @@ -134,10 +148,13 @@ class MochiGroupNorm3D(nn.Module): def __init__( self, + num_channels: int, + num_groups: int = 32, + affine: bool = True, chunk_size: int = 8, ): super().__init__() - self.norm_layer = nn.GroupNorm() + self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine) self.chunk_size = chunk_size def forward( @@ -158,3 +175,272 @@ class MochiGroupNorm3D(nn.Module): return output +class MochiResnetBlock3D(nn.Module): + r""" + A 3D ResNet block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + spatial_norm_dim (`int`, *optional*): + The dimension to use for spatial norm if it is to be used instead of group norm. + pad_mode (str, defaults to `"first"`): + Padding mode. 
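Stepping back to the `Conv1x1` helper added at the top of this patch and its to-do note about replacing it with `nn.Conv3d`: a `Linear` applied over the channel dimension is numerically the same as a 3D convolution with a 1x1x1 kernel once the weights are shared, which a quick sketch (illustrative channel counts) can confirm:

import torch
from torch import nn

# A Linear over the channel dim (what Conv1x1.forward does via movedim)
# matches an nn.Conv3d with a 1x1x1 kernel carrying the same weights.
linear = nn.Linear(8, 16)
conv = nn.Conv3d(8, 16, kernel_size=1)
with torch.no_grad():
    conv.weight.copy_(linear.weight.view(16, 8, 1, 1, 1))
    conv.bias.copy_(linear.bias)

x = torch.randn(2, 8, 3, 4, 4)                        # (B, C, T, H, W)
y_linear = linear(x.movedim(1, -1)).movedim(-1, 1)
y_conv = conv(x)
assert torch.allclose(y_linear, y_conv, atol=1e-5)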
+ """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + non_linearity: str = "swish", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.in_channels = in_channels + self.out_channels = out_channels + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels) + self.conv1 = MochiChunkedCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1 + ) + self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels) + self.conv2 = MochiChunkedCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1 + ) + + + def forward( + self, + inputs: torch.Tensor, + ) -> torch.Tensor: + + hidden_states = inputs + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv2(hidden_states) + + hidden_states = hidden_states + inputs + return hidden_states + + +class MochiUpBlock3D(nn.Module): + r""" + An upsampling block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + temporal_expansion: int = 2, + spatial_expansion: int = 2, + ): + super().__init__() + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + resnets = [] + for i in range(num_layers): + resnets.append( + MochiResnetBlock3D( + in_channels=in_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.proj = Conv1x1( + in_channels, + out_channels * temporal_expansion * (spatial_expansion**2), + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + r"""Forward method of the `MochiUpBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + ) + else: + hidden_states = resnet(hidden_states) + + hidden_states = self.proj(hidden_states) + + # Calculate new shape + B, C, T, H, W = hidden_states.shape + st = self.temporal_expansion + sh = self.spatial_expansion + sw = self.spatial_expansion + new_C = C // (st * sh * sw) + + # Reshape and permute + hidden_states = hidden_states.view(B, new_C, st, sh, sw, T, H, W) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4) + hidden_states = hidden_states.contiguous().view(B, new_C, T * st, H * sh, W * sw) + + if self.temporal_expansion > 1: + print(f"x: {hidden_states.shape}") + # Drop the first self.temporal_expansion - 1 frames. + hidden_states = hidden_states[:, :, self.temporal_expansion - 1 :] + print(f"x: {hidden_states.shape}") + + return hidden_states + + +class MochiMidBlock3D(nn.Module): + r""" + A middle block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. 
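The channel-to-space/time expansion in `MochiUpBlock3D` above is essentially a 3D pixel shuffle: the 1x1 projection multiplies the channel count by `temporal_expansion * spatial_expansion**2`, and the view/permute redistributes those channels into extra frames and pixels, after which the first `temporal_expansion - 1` frames are dropped. A minimal sketch of just that rearrangement, with arbitrary illustrative shapes:

import torch

B, C, T, H, W = 1, 8, 3, 4, 4
st, sh, sw = 2, 2, 2                                   # temporal / spatial expansion factors

x = torch.randn(B, C * st * sh * sw, T, H, W)          # output of the 1x1 projection

y = x.view(B, C, st, sh, sw, T, H, W)
y = y.permute(0, 1, 5, 2, 6, 3, 7, 4)                  # (B, C, T, st, H, sh, W, sw)
y = y.contiguous().view(B, C, T * st, H * sh, W * sw)

print(y.shape)                                          # torch.Size([1, 8, 6, 8, 8])
y = y[:, :, st - 1:]                                    # drop the first st - 1 frames, as the block does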
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, # 768 + num_layers: int = 3, + ): + super().__init__() + + resnets = [] + for _ in range(num_layers): + resnets.append( + MochiResnetBlock3D(in_channels=in_channels) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + r"""Forward method of the `MochiMidBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states + ) + else: + hidden_states = resnet(hidden_states) + + return hidden_states + + +class MochiDecoder3D(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, # 12 + out_channels: int, # 3 + block_out_channels: Tuple[int, ...] = (128, 256, 512, 768), + layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), + temporal_expansions: Tuple[int, ...] = (1, 2, 3), + spatial_expansions: Tuple[int, ...] = (2, 2, 2), + non_linearity: str = "swish", + ): + super().__init__() + + self.nonlinearity = get_activation(non_linearity) + + self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1)) + self.block_in = MochiMidBlock3D( + in_channels=block_out_channels[-1], + num_layers=layers_per_block[-1], + ) + self.up_blocks = nn.ModuleList([]) + for i in range(len(block_out_channels) - 1): + up_block = MochiUpBlock3D( + in_channels=block_out_channels[-i - 1], + out_channels=block_out_channels[-i - 2], + num_layers=layers_per_block[-i - 2], + temporal_expansion=temporal_expansions[-i - 1], + spatial_expansion=spatial_expansions[-i - 1], + ) + self.up_blocks.append(up_block) + self.block_out = MochiMidBlock3D( + in_channels=block_out_channels[0], + num_layers=layers_per_block[0], + ) + self.conv_out = Conv1x1(block_out_channels[0], out_channels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""Forward method of the `MochiDecoder3D` class.""" + + print(f"hidden_states: {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") + hidden_states = self.conv_in(hidden_states) + print(f"hidden_states (after conv_in): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") + + + # 1. Mid + hidden_states = self.block_in(hidden_states) + print(f"hidden_states (after block_in): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") + # 2. Up + for i, up_block in enumerate(self.up_blocks): + hidden_states = up_block(hidden_states) + print(f"hidden_states (after up_block {i}): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") + # 3. 
Post-process + hidden_states = self.block_out(hidden_states) + print(f"hidden_states (after block_out): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv_out(hidden_states) + print(f"hidden_states (after conv_out): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") + + return hidden_states + \ No newline at end of file From e9e92d043d622b314353e97d52dda3a893e45fa6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 25 Oct 2024 02:02:46 +0200 Subject: [PATCH 3/5] up --- .../autoencoders/autoencoder_kl_mochi.py | 166 +++++++++--------- 1 file changed, 79 insertions(+), 87 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py index 4f427f3ce1..f332d4fb20 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -13,33 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import logging -from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation -from ..downsampling import CogVideoXDownsample3D -from ..modeling_outputs import AutoencoderKLOutput -from ..modeling_utils import ModelMixin -from ..upsampling import CogVideoXUpsample3D -from .vae import DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name -import torch -import torch.nn as nn -import torch.nn.functional as F - - # YiYi to-do: replace this with nn.Conv3d class Conv1x1(nn.Linear): """*1x1 Conv implemented with a linear layer.""" @@ -60,17 +46,18 @@ class Conv1x1(nn.Linear): x = super().forward(x) x = x.movedim(-1, 1) return x - + class MochiChunkedCausalConv3d(nn.Module): - r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. + r"""A 3D causal convolution layer that pads the input tensor to ensure causality in Mochi Model. + It also supports memory-efficient chunked 3D convolutions. Args: in_channels (`int`): Number of channels in the input tensor. out_channels (`int`): Number of output channels produced by the convolution. kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel. stride (`int` or `Tuple[int, int, int]`, defaults to `1`): Stride of the convolution. - pad_mode (`str`, defaults to `"constant"`): Padding mode. + padding_mode (`str`, defaults to `"replicate"`): Padding mode. 
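The chunked path relies on the fact that a stride-1 temporal convolution with kernel size `k` only needs `k - 1` frames of left context: the already causally padded input can be split along time into pieces that each carry that context, and the per-piece outputs concatenate to exactly the one-shot result. A standalone sketch of that equivalence with illustrative sizes, separate from the method below:

import torch
from torch import nn

conv = nn.Conv3d(4, 4, kernel_size=(3, 3, 3), padding=(0, 1, 1))
x = torch.randn(1, 4, 20, 8, 8)                          # already causally padded along time

k = conv.kernel_size[0]
context = k - 1
chunks = []
for start in range(context, x.shape[2], 6):              # ~6 output frames per chunk
    end = min(start + 6, x.shape[2])
    chunks.append(conv(x[:, :, start - context : end]))  # re-feed the k-1 context frames

chunked = torch.cat(chunks, dim=2)
assert torch.allclose(chunked, conv(x), atol=1e-5)        # identical to the full convolution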
""" def __init__( @@ -88,7 +75,7 @@ class MochiChunkedCausalConv3d(nn.Module): if isinstance(stride, int): stride = (stride,) * 3 - time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + _, height_kernel_size, width_kernel_size = kernel_size self.padding_mode = padding_mode height_pad = (height_kernel_size - 1) // 2 @@ -104,18 +91,17 @@ class MochiChunkedCausalConv3d(nn.Module): padding_mode=padding_mode, ) - - - def forward(self, hidden_states : torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: time_kernel_size = self.conv.kernel_size[0] context_size = time_kernel_size - 1 time_casual_padding = (0, 0, 0, 0, context_size, 0) hidden_states = F.pad(hidden_states, time_casual_padding, mode=self.padding_mode) - + # Memory-efficient chunked operation memory_count = torch.prod(torch.tensor(hidden_states.shape)).item() * 2 / 1024**3 # YiYI Notes: testing only!! please remove memory_count = 3 + # YiYI Notes: this number 2 should be a config: max_memory_chunk_size (2 is 2GB) if memory_count > 2: part_num = int(memory_count / 2) + 1 num_frames = hidden_states.shape[2] @@ -131,7 +117,7 @@ class MochiChunkedCausalConv3d(nn.Module): output_chunks.append(output_chunk) # Append each output chunk to the list # Concatenate all output chunks along the temporal dimension - hidden_states = torch.cat(output_chunks, dim=2) + hidden_states = torch.cat(output_chunks, dim=2) return hidden_states else: @@ -140,9 +126,14 @@ class MochiChunkedCausalConv3d(nn.Module): class MochiChunkedGroupNorm3D(nn.Module): r""" - Group normalization applied per-frame. + Applies per-frame group normalization for 5D video inputs. It also supports memory-efficient chunked group + normalization. Args: + num_channels (int): Number of channels expected in input + num_groups (int, optional): Number of groups to separate the channels into. Default: 32 + affine (bool, optional): If True, this module has learnable affine parameters. Default: True + chunk_size (int, optional): Size of each chunk for processing. Default: 8 """ @@ -157,49 +148,27 @@ class MochiChunkedGroupNorm3D(nn.Module): self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine) self.chunk_size = chunk_size - def forward( - self, x: torch.Tensor = None - ) -> torch.Tensor: - + def forward(self, x: torch.Tensor = None) -> torch.Tensor: batch_size, channels, num_frames, height, width = x.shape x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - - num_chunks = (batch_size * num_frames + self.chunk_size - 1) // self.chunk_size - - output = torch.cat( - [self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)], - dim=0 - ) + + output = torch.cat([self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)], dim=0) output = output.view(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) - + return output class MochiResnetBlock3D(nn.Module): r""" - A 3D ResNet block used in the CogVideoX model. + A 3D ResNet block used in the Mochi model. Args: in_channels (`int`): Number of input channels. out_channels (`int`, *optional*): Number of output channels. If None, defaults to `in_channels`. - dropout (`float`, defaults to `0.0`): - Dropout rate. - temb_channels (`int`, defaults to `512`): - Number of time embedding channels. - groups (`int`, defaults to `32`): - Number of groups to separate the channels into for group normalization. - eps (`float`, defaults to `1e-6`): - Epsilon value for normalization layers. 
non_linearity (`str`, defaults to `"swish"`): Activation function to use. - conv_shortcut (bool, defaults to `False`): - Whether or not to use a convolution shortcut. - spatial_norm_dim (`int`, *optional*): - The dimension to use for spatial norm if it is to be used instead of group norm. - pad_mode (str, defaults to `"first"`): - Padding mode. """ def __init__( @@ -225,14 +194,12 @@ class MochiResnetBlock3D(nn.Module): in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1 ) - def forward( self, inputs: torch.Tensor, ) -> torch.Tensor: - hidden_states = inputs - + hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states) @@ -254,6 +221,12 @@ class MochiUpBlock3D(nn.Module): Number of input channels. out_channels (`int`, *optional*): Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet blocks in the block. + temporal_expansion (`int`, defaults to `2`): + Temporal expansion factor. + spatial_expansion (`int`, defaults to `2`): + Spatial expansion factor. """ def __init__( @@ -290,8 +263,7 @@ class MochiUpBlock3D(nn.Module): ) -> torch.Tensor: r"""Forward method of the `MochiUpBlock3D` class.""" - for i, resnet in enumerate(self.resnets): - + for resnet in self.resnets: if self.training and self.gradient_checkpointing: def create_custom_forward(module): @@ -322,10 +294,8 @@ class MochiUpBlock3D(nn.Module): hidden_states = hidden_states.contiguous().view(B, new_C, T * st, H * sh, W * sw) if self.temporal_expansion > 1: - print(f"x: {hidden_states.shape}") # Drop the first self.temporal_expansion - 1 frames. hidden_states = hidden_states[:, :, self.temporal_expansion - 1 :] - print(f"x: {hidden_states.shape}") return hidden_states @@ -337,22 +307,20 @@ class MochiMidBlock3D(nn.Module): Args: in_channels (`int`): Number of input channels. + num_layers (`int`, defaults to `3`): + Number of resnet blocks in the block. """ - _supports_gradient_checkpointing = True - def __init__( self, - in_channels: int, # 768 + in_channels: int, # 768 num_layers: int = 3, ): super().__init__() resnets = [] for _ in range(num_layers): - resnets.append( - MochiResnetBlock3D(in_channels=in_channels) - ) + resnets.append(MochiResnetBlock3D(in_channels=in_channels)) self.resnets = nn.ModuleList(resnets) self.gradient_checkpointing = False @@ -363,7 +331,7 @@ class MochiMidBlock3D(nn.Module): ) -> torch.Tensor: r"""Forward method of the `MochiMidBlock3D` class.""" - for i, resnet in enumerate(self.resnets): + for resnet in self.resnets: if self.training and self.gradient_checkpointing: def create_custom_forward(module): @@ -372,9 +340,7 @@ class MochiMidBlock3D(nn.Module): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states - ) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) else: hidden_states = resnet(hidden_states) @@ -382,12 +348,31 @@ class MochiMidBlock3D(nn.Module): class MochiDecoder3D(nn.Module): - _supports_gradient_checkpointing = True + r""" + The `MochiDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, *optional*): + The number of input channels. + out_channels (`int`, *optional*): + The number of output channels. 
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`): + The number of output channels for each block. + layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`): + The number of resnet blocks for each block. + temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`): + The temporal expansion factor for each of the up blocks. + spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`): + The spatial expansion factor for each of the up blocks. + non_linearity (`str`, *optional*, defaults to `"swish"`): + The non-linearity to use in the decoder. + """ def __init__( self, - in_channels: int, # 12 - out_channels: int, # 3 + in_channels: int, # 12 + out_channels: int, # 3 block_out_channels: Tuple[int, ...] = (128, 256, 512, 768), layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), temporal_expansions: Tuple[int, ...] = (1, 2, 3), @@ -418,29 +403,36 @@ class MochiDecoder3D(nn.Module): num_layers=layers_per_block[0], ) self.conv_out = Conv1x1(block_out_channels[0], out_channels) - + + self.gradient_checkpointing = False + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: r"""Forward method of the `MochiDecoder3D` class.""" - print(f"hidden_states: {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") hidden_states = self.conv_in(hidden_states) - print(f"hidden_states (after conv_in): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") - # 1. Mid - hidden_states = self.block_in(hidden_states) - print(f"hidden_states (after block_in): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") - # 2. Up - for i, up_block in enumerate(self.up_blocks): - hidden_states = up_block(hidden_states) - print(f"hidden_states (after up_block {i}): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") - # 3. 
Post-process + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.block_in), hidden_states) + + for up_block in self.up_blocks: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states) + else: + hidden_states = self.block_in(hidden_states) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states) + hidden_states = self.block_out(hidden_states) - print(f"hidden_states (after block_out): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") - + hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv_out(hidden_states) - print(f"hidden_states (after conv_out): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}") return hidden_states - \ No newline at end of file From 0b76fea5dd93d8c4ce81e6dbec0e6c5ed00e6d24 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 25 Oct 2024 04:14:23 +0200 Subject: [PATCH 4/5] up --- .../autoencoders/autoencoder_kl_mochi.py | 270 +++++++++++++++++- 1 file changed, 266 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py index f332d4fb20..7f5699508c 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -19,8 +19,12 @@ import torch import torch.nn as nn import torch.nn.functional as F +from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -175,7 +179,7 @@ class MochiResnetBlock3D(nn.Module): self, in_channels: int, out_channels: Optional[int] = None, - non_linearity: str = "swish", + act_fn: str = "swish", ): super().__init__() @@ -183,7 +187,7 @@ class MochiResnetBlock3D(nn.Module): self.in_channels = in_channels self.out_channels = out_channels - self.nonlinearity = get_activation(non_linearity) + self.nonlinearity = get_activation(act_fn) self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels) self.conv1 = MochiChunkedCausalConv3d( @@ -377,11 +381,11 @@ class MochiDecoder3D(nn.Module): layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), temporal_expansions: Tuple[int, ...] = (1, 2, 3), spatial_expansions: Tuple[int, ...] = (2, 2, 2), - non_linearity: str = "swish", + act_fn: str = "swish", ): super().__init__() - self.nonlinearity = get_activation(non_linearity) + self.nonlinearity = get_activation(act_fn) self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1)) self.block_in = MochiMidBlock3D( @@ -436,3 +440,261 @@ class MochiDecoder3D(nn.Module): hidden_states = self.conv_out(hidden_states) return hidden_states + + +class AutoencoderKLMochi(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [Mochi 1 preview](https://github.com/genmoai/models). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). 
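Once this module lands, decoding would look roughly as below. This is only a sketch: the encoder half is not part of this patch series, the random latents stand in for real ones, and the import path simply mirrors the new file's location.

import torch
from diffusers.models.autoencoders.autoencoder_kl_mochi import AutoencoderKLMochi

# Randomly initialized model with the defaults from the last patch in this series;
# real usage would load pretrained weights instead.
vae = AutoencoderKLMochi()
latents = torch.randn(1, 12, 4, 32, 32)    # (B, latent_channels, T, H, W) stand-in

with torch.no_grad():
    frames = vae.decode(latents).sample     # (B, 3, T', H', W') video tensor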
+ + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + scaling_factor (`float`, *optional*, defaults to `1.15258426`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["MochiResnetBlock3D"] + + @register_to_config + def __init__( + self, + out_channels: int = 3, + block_out_channels: Tuple[int] = (128, 256, 256, 512), + latent_channels: int = 12, + layers_per_block: int = 3, + act_fn: str = "silu", + temporal_expansions: Tuple[int, ...] = (1, 2, 3), + spatial_expansions: Tuple[int, ...] = (2, 2, 2), + latents_mean: Tuple[float, ...] = ( + -0.06730895953510081, + -0.038011381506090416, + -0.07477820912866141, + -0.05565264470995561, + 0.012767231469026969, + -0.04703542746246419, + 0.043896967884726704, + -0.09346305707025976, + -0.09918314763016893, + -0.008729793427399178, + -0.011931556316503654, + -0.0321993391887285, + ), + latents_std: Tuple[float, ...] = ( + 0.9263795028493863, + 0.9248894543193766, + 0.9393059390890617, + 0.959253732819592, + 0.8244560132752793, + 0.917259975397747, + 0.9294154431013696, + 1.3720942357788521, + 0.881393668867029, + 0.9168315692124348, + 0.9185249279345552, + 0.9274757570805041, + ), + scaling_factor: float = 1.0, + ): + super().__init__() + + self.decoder = MochiDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + temporal_expansions=temporal_expansions, + spatial_expansions=spatial_expansions, + act_fn=act_fn, + ) + + self.use_slicing = False + self.use_tiling = False + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MochiDecoder3D): + module.gradient_checkpointing = value + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_overlap_factor_height: Optional[float] = None, + tile_overlap_factor_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_overlap_factor_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. 
This is to ensure that there are + no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + tile_overlap_factor_width (`int`, *optional*): + The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there + are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_latent_min_height = int( + self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height + self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self.decoder(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self.decoder(z) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. 
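The `blend_v` / `blend_h` helpers above implement a plain linear cross-fade over the overlap between neighbouring tiles; reduced to one dimension with toy tensors, the weighting looks like this:

import torch

blend_extent = 4
a = torch.ones(1, 1, 1, 1, 8)            # left tile (all ones)
b = torch.zeros(1, 1, 1, 1, 8)           # right tile (all zeros)

# Cross-fade the last `blend_extent` columns of `a` into the first columns of `b`.
for x in range(blend_extent):
    w = x / blend_extent
    b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - w) + b[:, :, :, :, x] * w

print(b[0, 0, 0, 0])                      # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0., 0., 0., 0.])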
+ + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + time = [] + + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) + time.append(tile) + + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) From 7c55ef50009a081e435a801b00f8547bf258b0ed Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 25 Oct 2024 04:24:51 +0200 Subject: [PATCH 5/5] up --- src/diffusers/models/autoencoders/autoencoder_kl_mochi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py index 7f5699508c..ac32381ee0 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -472,9 +472,9 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin): def __init__( self, out_channels: int = 3, - block_out_channels: Tuple[int] = (128, 256, 256, 512), + block_out_channels: Tuple[int] = (128, 256, 512, 768), latent_channels: int = 12, - layers_per_block: int = 3, + layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), act_fn: str = "silu", temporal_expansions: Tuple[int, ...] = (1, 2, 3), spatial_expansions: Tuple[int, ...] = (2, 2, 2),
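For orientation, the tile geometry used by `tiled_decode` works out as follows; the sizes here are made up for illustration, since defaults for `tile_sample_min_height` and the overlap factors are not set anywhere in this section.

# Illustrative numbers only (assumed, not taken from the patch).
tile_latent_min_height = 32            # latent tile height fed to the decoder
tile_sample_min_height = 256           # decoded tile height (8x spatial upsampling)
tile_overlap_factor_height = 0.25

overlap_height = int(tile_latent_min_height * (1 - tile_overlap_factor_height))   # 24: stride between latent tiles
blend_extent_height = int(tile_sample_min_height * tile_overlap_factor_height)    # 64: rows cross-faded with the tile above
row_limit_height = tile_sample_min_height - blend_extent_height                   # 192: rows kept from each decoded tile

latent_height = 96
tile_starts = list(range(0, latent_height, overlap_height))                        # [0, 24, 48, 72]
print(tile_starts, overlap_height, blend_extent_height, row_limit_height)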