commit 0a6189eb95
parent 85a9825449
Author: yiyixuxu
Date: 2024-10-25 01:02:46 +02:00


@@ -40,7 +40,29 @@ import torch.nn as nn
 import torch.nn.functional as F

-class MochiCausalConv3d(nn.Module):
+# YiYi to-do: replace this with nn.Conv3d
+class Conv1x1(nn.Linear):
+    """A 1x1 Conv implemented with a linear layer."""
+
+    def __init__(self, in_features: int, out_features: int, *args, **kwargs):
+        super().__init__(in_features, out_features, *args, **kwargs)
+
+    def forward(self, x: torch.Tensor):
+        """Forward pass.
+
+        Args:
+            x: Input tensor. Shape: [B, C, *] or [B, *, C].
+
+        Returns:
+            x: Output tensor. Shape: [B, C', *] or [B, *, C'].
+        """
+        x = x.movedim(1, -1)
+        x = super().forward(x)
+        x = x.movedim(-1, 1)
+        return x
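The to-do above suggests replacing Conv1x1 with nn.Conv3d. A minimal standalone sketch of why that swap is safe, assuming only that a 1x1x1 conv kernel is an [out, in] matrix (names and sizes below are illustrative, not from the commit):

import torch
import torch.nn as nn

# A Linear over the channel dim and a 1x1x1 Conv3d compute the same map
# once their parameters are shared.
linear = nn.Linear(8, 16)
conv3d = nn.Conv3d(8, 16, kernel_size=1)
with torch.no_grad():
    conv3d.weight.copy_(linear.weight.view(16, 8, 1, 1, 1))
    conv3d.bias.copy_(linear.bias)

x = torch.randn(2, 8, 5, 4, 4)  # [B, C, T, H, W]
out_linear = linear(x.movedim(1, -1)).movedim(-1, 1)
assert torch.allclose(out_linear, conv3d(x), atol=1e-6)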
+class MochiChunkedCausalConv3d(nn.Module):
     r"""A 3D causal convolution layer that pads the input tensor to ensure causality in the Mochi model.

     Args:
@@ -81,50 +103,42 @@ class MochiCausalConv3d(nn.Module):
             padding=(0, height_pad, width_pad),
             padding_mode=padding_mode,
         )
-        self.time_kernel_size = time_kernel_size

-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        context_size = self.time_kernel_size - 1
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        time_kernel_size = self.conv.kernel_size[0]
+        context_size = time_kernel_size - 1
         time_causal_padding = (0, 0, 0, 0, context_size, 0)
-        inputs = F.pad(inputs, time_causal_padding, mode=self.padding_mode)
+        hidden_states = F.pad(hidden_states, time_causal_padding, mode=self.padding_mode)

         # Memory-efficient chunked operation
-        memory_count = torch.prod(torch.tensor(inputs.shape)).item() * 2 / 1024**3
+        # Estimated tensor size in GiB, assuming 2 bytes per element (fp16)
+        memory_count = torch.prod(torch.tensor(hidden_states.shape)).item() * 2 / 1024**3
         if memory_count > 2:
             part_num = int(memory_count / 2) + 1
-            k = self.time_kernel_size
-            input_idx = torch.arange(context_size, inputs.size(2))
-            input_chunks_idx = torch.split(input_idx, input_idx.size(0) // part_num)
+            num_frames = hidden_states.shape[2]
+            frames_idx = torch.arange(context_size, num_frames)
+            frames_chunks_idx = torch.chunk(frames_idx, part_num, dim=0)

-            # Compute output size
-            B, _, T_in, H_in, W_in = inputs.shape
-            output_size = (
-                B,
-                self.conv.out_channels,
-                T_in - k + 1,
-                H_in // self.conv.stride[1],
-                W_in // self.conv.stride[2],
-            )
-            output = torch.empty(output_size, dtype=inputs.dtype, device=inputs.device)
-            for input_chunk_idx in input_chunks_idx:
-                input_s = input_chunk_idx[0] - k + 1
-                input_e = input_chunk_idx[-1] + 1
-                input_chunk = inputs[:, :, input_s:input_e, :, :]
-                output_chunk = self.conv(input_chunk)
-                output_s = input_s
-                output_e = output_s + output_chunk.size(2)
-                output[:, :, output_s:output_e, :, :] = output_chunk
-            return output
+            output_chunks = []
+            for frames_chunk_idx in frames_chunks_idx:
+                frames_s = frames_chunk_idx[0] - context_size
+                frames_e = frames_chunk_idx[-1] + 1
+                frames_chunk = hidden_states[:, :, frames_s:frames_e, :, :]
+                output_chunk = self.conv(frames_chunk)
+                output_chunks.append(output_chunk)  # Append each output chunk to the list
+
+            # Concatenate all output chunks along the temporal dimension
+            hidden_states = torch.cat(output_chunks, dim=2)
+            return hidden_states
         else:
-            return self.conv(inputs)
+            return self.conv(hidden_states)
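Padding only the past (context_size frames at the front) makes output frame t depend on input frames <= t; the chunked path then re-slices the padded input with context_size frames of overlap, so every chunk's convolution sees its full receptive field. A standalone sketch of that equivalence, with kernel and tensor sizes chosen arbitrarily for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv3d(4, 4, kernel_size=(3, 3, 3), padding=(0, 1, 1))
context = conv.kernel_size[0] - 1  # temporal context = kernel_size - 1

x = torch.randn(1, 4, 9, 8, 8)  # [B, C, T, H, W]
x = F.pad(x, (0, 0, 0, 0, context, 0))  # causal: pad the past only

full = conv(x)

# Chunk the valid frame indices, then re-slice the input with `context` overlap.
frames_idx = torch.arange(context, x.shape[2])
chunks = [
    conv(x[:, :, idx[0] - context : idx[-1] + 1])
    for idx in torch.chunk(frames_idx, 3)
]
assert torch.allclose(full, torch.cat(chunks, dim=2), atol=1e-5)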
-class MochiGroupNorm3D(nn.Module):
+class MochiChunkedGroupNorm3D(nn.Module):
     r"""
     Group normalization applied per-frame.
@@ -134,10 +148,13 @@ class MochiGroupNorm3D(nn.Module):
     def __init__(
         self,
         num_channels: int,
+        num_groups: int = 32,
+        affine: bool = True,
+        chunk_size: int = 8,
     ):
         super().__init__()
-        self.norm_layer = nn.GroupNorm()
+        self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine)
+        self.chunk_size = chunk_size

     def forward(
@@ -158,3 +175,272 @@ class MochiGroupNorm3D(nn.Module):
         return output
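The forward body is elided by the diff, but chunking is sound for any per-frame GroupNorm: GroupNorm statistics are computed per sample, so normalizing a time-folded batch in chunks is exactly equivalent to normalizing it at once. A sketch, assuming the forward folds time into the batch dimension:

import torch
import torch.nn as nn

norm = nn.GroupNorm(num_groups=4, num_channels=8)
x = torch.randn(2, 8, 7, 4, 4)  # [B, C, T, H, W]

# Fold time into batch so each frame is normalized independently: [B*T, C, H, W]
frames = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
full = norm(frames)

# Per-sample statistics mean chunking the folded batch changes nothing.
chunked = torch.cat([norm(c) for c in frames.split(8, dim=0)], dim=0)
assert torch.allclose(full, chunked, atol=1e-6)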
+class MochiResnetBlock3D(nn.Module):
+    r"""
+    A 3D ResNet block used in the Mochi model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`, *optional*):
+            Number of output channels. If None, defaults to `in_channels`. The residual connection
+            adds the input back without a shortcut projection, so this must equal `in_channels`.
+        non_linearity (`str`, defaults to `"swish"`):
+            Activation function to use.
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        non_linearity: str = "swish",
+    ):
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.nonlinearity = get_activation(non_linearity)
+
+        self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels)
+        self.conv1 = MochiChunkedCausalConv3d(
+            in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1
+        )
+        self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels)
+        self.conv2 = MochiChunkedCausalConv3d(
+            in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1
+        )
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+    ) -> torch.Tensor:
+        hidden_states = inputs
+
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv1(hidden_states)
+
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        hidden_states = hidden_states + inputs
+
+        return hidden_states
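A quick shape-level smoke test of the pre-norm residual layout above (sizes are arbitrary; assumes the classes and elided helpers in this file, such as get_activation, are in scope):

import torch

block = MochiResnetBlock3D(in_channels=64)
x = torch.randn(1, 64, 5, 16, 16)  # [B, C, T, H, W]
# norm -> act -> conv twice, then the input is added back, so shape is preserved.
assert block(x).shape == x.shape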
+class MochiUpBlock3D(nn.Module):
+    r"""
+    An upsampling block used in the Mochi model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`):
+            Number of output channels.
+        num_layers (`int`, defaults to `1`):
+            Number of resnet blocks in the block.
+        temporal_expansion (`int`, defaults to `2`):
+            Temporal expansion factor for the depth-to-space upsampling.
+        spatial_expansion (`int`, defaults to `2`):
+            Spatial expansion factor for the depth-to-space upsampling.
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_layers: int = 1,
+        temporal_expansion: int = 2,
+        spatial_expansion: int = 2,
+    ):
+        super().__init__()
+
+        self.temporal_expansion = temporal_expansion
+        self.spatial_expansion = spatial_expansion
+
+        resnets = []
+        for _ in range(num_layers):
+            resnets.append(
+                MochiResnetBlock3D(
+                    in_channels=in_channels,
+                )
+            )
+        self.resnets = nn.ModuleList(resnets)
+
+        self.proj = Conv1x1(
+            in_channels,
+            out_channels * temporal_expansion * (spatial_expansion**2),
+        )
+
+        self.gradient_checkpointing = False
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        r"""Forward method of the `MochiUpBlock3D` class."""
+
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def create_forward(*inputs):
+                        return module(*inputs)
+
+                    return create_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                )
+            else:
+                hidden_states = resnet(hidden_states)
+
+        hidden_states = self.proj(hidden_states)
+
+        # Calculate new shape
+        B, C, T, H, W = hidden_states.shape
+        st = self.temporal_expansion
+        sh = self.spatial_expansion
+        sw = self.spatial_expansion
+        new_C = C // (st * sh * sw)
+
+        # Reshape and permute
+        hidden_states = hidden_states.view(B, new_C, st, sh, sw, T, H, W)
+        hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4)
+        hidden_states = hidden_states.contiguous().view(B, new_C, T * st, H * sh, W * sw)
+
+        if self.temporal_expansion > 1:
+            # Drop the first self.temporal_expansion - 1 frames.
+            hidden_states = hidden_states[:, :, self.temporal_expansion - 1 :]
+
+        return hidden_states
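The view/permute/view triple is a 3D depth-to-space (pixel shuffle): each output location pulls its (st, sh, sw) factors out of the channel dimension, and the first st - 1 reconstructed frames are dropped to keep the output causal. The same rearrangement written with einops, as a standalone cross-check (einops is used only for this illustration, not by the commit):

import torch
from einops import rearrange

B, C, st, sh, sw, T, H, W = 1, 3, 2, 2, 2, 4, 5, 5
x = torch.randn(B, C * st * sh * sw, T, H, W)

# Manual route, as in the forward above.
y = x.view(B, C, st, sh, sw, T, H, W)
y = y.permute(0, 1, 5, 2, 6, 3, 7, 4)
y = y.contiguous().view(B, C, T * st, H * sh, W * sw)

# Identical rearrangement in one einops call.
z = rearrange(x, "b (c st sh sw) t h w -> b c (t st) (h sh) (w sw)", st=st, sh=sh, sw=sw)
assert torch.equal(y, z)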
+class MochiMidBlock3D(nn.Module):
+    r"""
+    A middle block used in the Mochi model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        num_layers (`int`, defaults to `3`):
+            Number of resnet blocks in the block.
+    """
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        in_channels: int,  # 768
+        num_layers: int = 3,
+    ):
+        super().__init__()
+
+        resnets = []
+        for _ in range(num_layers):
+            resnets.append(MochiResnetBlock3D(in_channels=in_channels))
+        self.resnets = nn.ModuleList(resnets)
+
+        self.gradient_checkpointing = False
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        r"""Forward method of the `MochiMidBlock3D` class."""
+
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def create_forward(*inputs):
+                        return module(*inputs)
+
+                    return create_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states
+                )
+            else:
+                hidden_states = resnet(hidden_states)
+
+        return hidden_states
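The checkpointing branch trades compute for memory: each resnet's activations are discarded during the forward pass and recomputed during backward. A minimal standalone illustration of the same torch.utils.checkpoint mechanism (module and sizes are placeholders):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 32))
x = torch.randn(4, 32, requires_grad=True)

# Intermediate activations inside `layer` are recomputed on backward, not stored.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
assert x.grad is not None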
+class MochiDecoder3D(nn.Module):
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        in_channels: int,  # 12
+        out_channels: int,  # 3
+        block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
+        layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
+        temporal_expansions: Tuple[int, ...] = (1, 2, 3),
+        spatial_expansions: Tuple[int, ...] = (2, 2, 2),
+        non_linearity: str = "swish",
+    ):
+        super().__init__()
+
+        self.nonlinearity = get_activation(non_linearity)
+
+        self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1))
+        self.block_in = MochiMidBlock3D(
+            in_channels=block_out_channels[-1],
+            num_layers=layers_per_block[-1],
+        )
+
+        self.up_blocks = nn.ModuleList([])
+        for i in range(len(block_out_channels) - 1):
+            up_block = MochiUpBlock3D(
+                in_channels=block_out_channels[-i - 1],
+                out_channels=block_out_channels[-i - 2],
+                num_layers=layers_per_block[-i - 2],
+                temporal_expansion=temporal_expansions[-i - 1],
+                spatial_expansion=spatial_expansions[-i - 1],
+            )
+            self.up_blocks.append(up_block)
+
+        self.block_out = MochiMidBlock3D(
+            in_channels=block_out_channels[0],
+            num_layers=layers_per_block[0],
+        )
+        self.conv_out = Conv1x1(block_out_channels[0], out_channels)
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        r"""Forward method of the `MochiDecoder3D` class."""
+
+        hidden_states = self.conv_in(hidden_states)
+
+        # 1. Mid
+        hidden_states = self.block_in(hidden_states)
+
+        # 2. Up
+        for up_block in self.up_blocks:
+            hidden_states = up_block(hidden_states)
+
+        # 3. Post-process
+        hidden_states = self.block_out(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        return hidden_states
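Putting the decoder together: each of the three up blocks doubles H and W, while T expands by 3, then 2, then 1, with every expanding block dropping its first temporal_expansion - 1 frames. So 3 latent frames become (3*3 - 2) = 7, then (7*2 - 1) = 13. A hypothetical shape check under the default config (parts of this file are elided by the diff, so treat this as a sketch rather than a guaranteed-runnable test):

import torch

decoder = MochiDecoder3D(in_channels=12, out_channels=3)
latents = torch.randn(1, 12, 3, 8, 8)  # [B, C, T, H, W]
frames = decoder(latents)
assert frames.shape == (1, 3, 13, 64, 64)  # 8x spatial, 3 -> 13 frames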