
[refactor] remove conv_cache from CogVideoX VAE (#9524)

* remove conv cache from the layer and pass as arg instead

* make style

* yiyi's cleaner implementation

Co-Authored-By: YiYi Xu <yixu310@gmail.com>

* sayak's compiled implementation

Co-Authored-By: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Author: Aryan
Date: 2024-09-28 17:09:30 +05:30
Committed by: GitHub
parent 11542431a5
commit bd4df2856a
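
In short, the commit removes the stateful `conv_cache` attribute from `CogVideoXCausalConv3d` (along with `_clear_fake_context_parallel_cache`) and instead threads the cache through every `forward` as an argument: each layer returns `(output, new_conv_cache)`, and composite blocks collect sub-caches in a dict keyed by module name. Keeping `forward` free of module-state mutation is what makes the VAE friendlier to `torch.compile`; the `memory_count` change in `CogVideoXSafeConv3d` serves the same goal, since `torch.prod(torch.tensor(...)).item()` can force a graph break while plain Python arithmetic on the shape does not. A minimal sketch of the pattern, with illustrative names rather than the actual diffusers classes:

# Sketch of the refactor pattern (hypothetical names, not the real classes).
from typing import Optional, Tuple

import torch
import torch.nn as nn


class CausalConvSketch(nn.Module):
    def __init__(self, channels: int, time_kernel_size: int = 3):
        super().__init__()
        self.time_kernel_size = time_kernel_size
        # Convolve only along the time axis so the cache logic stays visible.
        self.conv = nn.Conv3d(channels, channels, kernel_size=(time_kernel_size, 1, 1))

    def forward(
        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Prepend cached frames from the previous chunk; on the first chunk,
        # replicate the first frame instead (causal padding).
        cached = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (self.time_kernel_size - 1)
        inputs = torch.cat(cached + [inputs], dim=2)
        # The trailing frames seed the cache for the next chunk; the module
        # itself holds no state between calls.
        new_conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
        return self.conv(inputs), new_conv_cache


layer = CausalConvSketch(channels=4)
video = torch.randn(1, 4, 8, 16, 16)  # (batch, channels, frames, height, width)

conv_cache = None
outputs = []
for chunk in video.split(4, dim=2):  # frame-batched processing, as in encode/decode
    out, conv_cache = layer(chunk, conv_cache)
    outputs.append(out)
result = torch.cat(outputs, dim=2)  # 8 output frames, same as one unchunked pass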


@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d):
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
memory_count = (
(input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
)
# Set to 2GB, suitable for CuDNN
if memory_count > 2:
@@ -115,34 +117,24 @@ class CogVideoXCausalConv3d(nn.Module):
             dilation=dilation,
         )

-        self.conv_cache = None
-
-    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def fake_context_parallel_forward(
+        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         kernel_size = self.time_kernel_size
         if kernel_size > 1:
-            cached_inputs = (
-                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
-            )
+            cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
             inputs = torch.cat(cached_inputs + [inputs], dim=2)
         return inputs

-    def _clear_fake_context_parallel_cache(self):
-        del self.conv_cache
-        self.conv_cache = None
-
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        inputs = self.fake_context_parallel_forward(inputs)
-
-        self._clear_fake_context_parallel_cache()
-        # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
-        # hundred megabytes and so let's not do it for now
-        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
+        inputs = self.fake_context_parallel_forward(inputs, conv_cache)
+        conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()

         padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
         inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

         output = self.conv(inputs)
-        return output
+        return output, conv_cache


 class CogVideoXSpatialNorm3D(nn.Module):
@@ -172,7 +164,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
         self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)

-    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
+    ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         if f.shape[2] > 1 and f.shape[2] % 2 == 1:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
@@ -183,9 +180,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         else:
             zq = F.interpolate(zq, size=f.shape[-3:])

+        conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
+        conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
+
         norm_f = self.norm_layer(f)
-        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
-        return new_f
+        new_f = norm_f * conv_y + conv_b
+        return new_f, new_conv_cache


 class CogVideoXResnetBlock3D(nn.Module):
@@ -236,6 +236,7 @@ class CogVideoXResnetBlock3D(nn.Module):
         self.out_channels = out_channels
         self.nonlinearity = get_activation(non_linearity)
         self.use_conv_shortcut = conv_shortcut
+        self.spatial_norm_dim = spatial_norm_dim

         if spatial_norm_dim is None:
             self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
@@ -279,34 +280,43 @@ class CogVideoXResnetBlock3D(nn.Module):
         inputs: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         hidden_states = inputs

         if zq is not None:
-            hidden_states = self.norm1(hidden_states, zq)
+            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
         else:
             hidden_states = self.norm1(hidden_states)

         hidden_states = self.nonlinearity(hidden_states)
-        hidden_states = self.conv1(hidden_states)
+        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))

         if temb is not None:
             hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]

         if zq is not None:
-            hidden_states = self.norm2(hidden_states, zq)
+            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
         else:
             hidden_states = self.norm2(hidden_states)

         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states)
+        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))

         if self.in_channels != self.out_channels:
-            inputs = self.conv_shortcut(inputs)
+            if self.use_conv_shortcut:
+                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
+                    inputs, conv_cache=conv_cache.get("conv_shortcut")
+                )
+            else:
+                inputs = self.conv_shortcut(inputs)

         hidden_states = hidden_states + inputs
-        return hidden_states
+        return hidden_states, new_conv_cache


 class CogVideoXDownBlock3D(nn.Module):
@@ -392,8 +402,16 @@ class CogVideoXDownBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXDownBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:

                 def create_custom_forward(module):
@@ -402,17 +420,23 @@ class CogVideoXDownBlock3D(nn.Module):
                     return create_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)

-        return hidden_states
+        return hidden_states, new_conv_cache


 class CogVideoXMidBlock3D(nn.Module):
@@ -480,8 +504,16 @@ class CogVideoXMidBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXMidBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:

                 def create_custom_forward(module):
@@ -490,13 +522,15 @@ class CogVideoXMidBlock3D(nn.Module):
                     return create_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )

-        return hidden_states
+        return hidden_states, new_conv_cache


 class CogVideoXUpBlock3D(nn.Module):
@@ -584,9 +618,16 @@ class CogVideoXUpBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
         r"""Forward method of the `CogVideoXUpBlock3D` class."""
-        for resnet in self.resnets:
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:

                 def create_custom_forward(module):
@@ -595,17 +636,23 @@ class CogVideoXUpBlock3D(nn.Module):
                     return create_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
                 hidden_states = upsampler(hidden_states)

-        return hidden_states
+        return hidden_states, new_conv_cache


 class CogVideoXEncoder3D(nn.Module):
@@ -705,9 +752,18 @@ class CogVideoXEncoder3D(nn.Module):
         self.gradient_checkpointing = False

-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXEncoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

         if self.training and self.gradient_checkpointing:
@@ -718,28 +774,44 @@ class CogVideoXEncoder3D(nn.Module):
                 return custom_forward

             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block), hidden_states, temb, None
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block),
+                    hidden_states,
+                    temb,
+                    None,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )

             # 2. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, None
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                None,
+                conv_cache=conv_cache.get("mid_block"),
             )
         else:
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = down_block(hidden_states, temb, None)
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = down_block(
+                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+                )

             # 2. Mid
-            hidden_states = self.mid_block(hidden_states, temb, None)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
+            )

         # 3. Post-process
         hidden_states = self.norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache


 class CogVideoXDecoder3D(nn.Module):
@@ -846,9 +918,18 @@ class CogVideoXDecoder3D(nn.Module):
         self.gradient_checkpointing = False

-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXDecoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

         if self.training and self.gradient_checkpointing:
@@ -859,28 +940,45 @@ class CogVideoXDecoder3D(nn.Module):
                 return custom_forward

             # 1. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, sample
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                sample,
+                conv_cache=conv_cache.get("mid_block"),
             )

             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block), hidden_states, temb, sample
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    hidden_states,
+                    temb,
+                    sample,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
         else:
             # 1. Mid
-            hidden_states = self.mid_block(hidden_states, temb, sample)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
+            )

             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = up_block(hidden_states, temb, sample)
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = up_block(
+                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
+                )

         # 3. Post-process
-        hidden_states = self.norm_out(hidden_states, sample)
+        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
+            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
+        )
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache


 class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@@ -1019,12 +1117,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
             module.gradient_checkpointing = value

-    def _clear_fake_context_parallel_cache(self):
-        for name, module in self.named_modules():
-            if isinstance(module, CogVideoXCausalConv3d):
-                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
-                module._clear_fake_context_parallel_cache()
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -1091,20 +1183,20 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         frame_batch_size = self.num_sample_frames_batch_size
         # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
         num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+        conv_cache = None
         enc = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
             x_intermediate = x[:, :, start_frame:end_frame]
-            x_intermediate = self.encoder(x_intermediate)
+            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
             if self.quant_conv is not None:
                 x_intermediate = self.quant_conv(x_intermediate)
             enc.append(x_intermediate)

-        self._clear_fake_context_parallel_cache()
         enc = torch.cat(enc, dim=2)

         return enc

     @apply_forward_hook
@@ -1143,7 +1235,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         frame_batch_size = self.num_latent_frames_batch_size
         num_batches = num_frames // frame_batch_size
+        conv_cache = None
         dec = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
@@ -1151,10 +1245,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             z_intermediate = z[:, :, start_frame:end_frame]
             if self.post_quant_conv is not None:
                 z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate = self.decoder(z_intermediate)
+            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
             dec.append(z_intermediate)

-        self._clear_fake_context_parallel_cache()
         dec = torch.cat(dec, dim=2)

         if not return_dict:
@@ -1238,7 +1331,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             for j in range(0, width, overlap_width):
                 # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
                 num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1250,11 +1345,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                         i : i + self.tile_sample_min_height,
                         j : j + self.tile_sample_min_width,
                     ]
-                    tile = self.encoder(tile)
+                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                     if self.quant_conv is not None:
                         tile = self.quant_conv(tile)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
@@ -1315,7 +1410,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             row = []
             for j in range(0, width, overlap_width):
                 num_batches = num_frames // frame_batch_size
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1329,9 +1426,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                     ]
                     if self.post_quant_conv is not None:
                         tile = self.post_quant_conv(tile)
-                    tile = self.decoder(tile)
+                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
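
With the cache handled functionally, a caller never has to reach into the layers to reset state between videos: each `encode`/`decode` call starts from `conv_cache = None` and threads it across its own frame batches, exactly as the loops above do. A hedged usage sketch (the checkpoint id, dtype, and latent shape are assumptions, not part of this commit):

import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

# The stateless forward also makes the decoder a candidate for compilation,
# since there is no longer any module-state mutation inside the graph.
vae.decoder = torch.compile(vae.decoder)

# (batch, latent_channels, latent_frames, height // 8, width // 8)
latents = torch.randn(1, 16, 3, 60, 90, dtype=torch.float16, device="cuda")
with torch.no_grad():
    video = vae.decode(latents).sample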