[refactor] remove conv_cache from CogVideoX VAE (#9524)
* remove conv cache from the layer and pass as arg instead
* make style
* yiyi's cleaner implementation
* sayak's compiled implementation

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
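The refactor in one picture: `CogVideoXCausalConv3d` used to keep its temporal context in a `self.conv_cache` attribute that persisted between calls; after this change the cache is an explicit `forward()` argument and return value, so the layers are stateless and easier to trace with `torch.compile`. A minimal sketch of the before/after pattern with a toy 1D layer (hypothetical names, not the diffusers class):

import torch
import torch.nn as nn
from typing import Optional, Tuple

class TinyCausalConv1d(nn.Module):
    """Toy stand-in for CogVideoXCausalConv3d's cache handling (1D for brevity)."""

    def __init__(self, channels: int, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(channels, channels, kernel_size)

    def forward(
        self, x: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if conv_cache is None:
            # First chunk: replicate the first frame, like the real layer does.
            conv_cache = x[:, :, :1].repeat(1, 1, self.kernel_size - 1)
        x = torch.cat([conv_cache, x], dim=2)
        new_cache = x[:, :, -(self.kernel_size - 1) :].clone()
        return self.conv(x), new_cache

layer = TinyCausalConv1d(channels=8)
cache = None
for chunk in torch.randn(1, 8, 12).chunk(3, dim=2):
    out, cache = layer(chunk, conv_cache=cache)  # the caller owns the cache now

Because the caller owns the cache, nothing has to be cleared between videos; dropping the reference is enough.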
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d):
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+        memory_count = (
+            (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
+        )
 
         # Set to 2GB, suitable for CuDNN
        if memory_count > 2:
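The `memory_count` rewrite belongs to the "compiled implementation" commit: `torch.prod(torch.tensor(input.shape)).item()` builds a tensor and synchronizes on `.item()`, which forces a graph break under `torch.compile`, while multiplying the plain Python ints from `input.shape` stays traceable. An equivalent standalone sketch (the function name and the fp16 assumption are illustrative):

import torch

def estimate_gib(x: torch.Tensor, bytes_per_element: int = 2) -> float:
    # Plain-int arithmetic: no tensor round-trip, no .item() sync point.
    b, c, t, h, w = x.shape
    return b * c * t * h * w * bytes_per_element / 1024**3

x = torch.empty(1, 16, 9, 64, 64, dtype=torch.float16)
print(estimate_gib(x))  # the hardcoded `* 2` above assumes 2-byte (fp16/bf16) elements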
@@ -115,34 +117,24 @@ class CogVideoXCausalConv3d(nn.Module):
             dilation=dilation,
         )
 
-        self.conv_cache = None
-
-    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def fake_context_parallel_forward(
+        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         kernel_size = self.time_kernel_size
         if kernel_size > 1:
-            cached_inputs = (
-                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
-            )
+            cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
             inputs = torch.cat(cached_inputs + [inputs], dim=2)
         return inputs
 
-    def _clear_fake_context_parallel_cache(self):
-        del self.conv_cache
-        self.conv_cache = None
-
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        inputs = self.fake_context_parallel_forward(inputs)
-
-        self._clear_fake_context_parallel_cache()
-        # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
-        # hundred megabytes and so let's not do it for now
-        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
+        inputs = self.fake_context_parallel_forward(inputs, conv_cache)
+        conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
 
         padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
         inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
 
         output = self.conv(inputs)
-        return output
+        return output, conv_cache
 
 
 class CogVideoXSpatialNorm3D(nn.Module):
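The contract the new `forward` preserves: running a causal temporal convolution chunk by chunk while threading `conv_cache` must match a single full pass, because the cache carries the last `kernel_size - 1` frames of the previous (already padded) chunk. A self-contained check of that property with a plain 1D convolution (toy shapes, not the 3D layer itself):

import torch
import torch.nn.functional as F

def causal_conv(x, weight, cache=None):
    # x: (B, C, T); replicate the first frame when there is no cache yet,
    # exactly like fake_context_parallel_forward does on the first chunk.
    k = weight.shape[-1]
    if cache is None:
        cache = x[:, :, :1].repeat(1, 1, k - 1)
    x = torch.cat([cache, x], dim=2)
    return F.conv1d(x, weight), x[:, :, -(k - 1) :].clone()

weight = torch.randn(4, 4, 3)
video = torch.randn(1, 4, 8)

full, _ = causal_conv(video, weight)        # one pass over all 8 frames

cache, outs = None, []
for chunk in video.chunk(2, dim=2):         # two 4-frame chunks
    out, cache = causal_conv(chunk, weight, cache)
    outs.append(out)

assert torch.allclose(full, torch.cat(outs, dim=2), atol=1e-6)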
@@ -172,7 +164,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
         self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
 
-    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
+    ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         if f.shape[2] > 1 and f.shape[2] % 2 == 1:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
@@ -183,9 +180,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         else:
             zq = F.interpolate(zq, size=f.shape[-3:])
 
+        conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
+        conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
+
         norm_f = self.norm_layer(f)
-        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
-        return new_f
+        new_f = norm_f * conv_y + conv_b
+        return new_f, new_conv_cache
 
 
 class CogVideoXResnetBlock3D(nn.Module):
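From this point on, every composite module repeats one bookkeeping pattern: fresh child caches are collected into `new_conv_cache` under stable string keys, and the previous chunk's entries are read back with `conv_cache.get(key)`, so a missing key simply means "first chunk" for that child. The pattern in miniature (hypothetical children that follow the same `(output, cache)` convention):

from typing import Dict, Optional, Tuple

import torch

def composite_forward(
    children,  # assumed: each child returns (output, fresh_cache)
    x: torch.Tensor,
    conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    new_conv_cache: Dict[str, torch.Tensor] = {}
    conv_cache = conv_cache or {}
    for i, child in enumerate(children):
        key = f"resnet_{i}"  # stable key => .get(key) is None on the first chunk
        x, new_conv_cache[key] = child(x, conv_cache=conv_cache.get(key))
    return x, new_conv_cache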
@@ -236,6 +236,7 @@ class CogVideoXResnetBlock3D(nn.Module):
         self.out_channels = out_channels
         self.nonlinearity = get_activation(non_linearity)
         self.use_conv_shortcut = conv_shortcut
+        self.spatial_norm_dim = spatial_norm_dim
 
         if spatial_norm_dim is None:
             self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
@@ -279,34 +280,43 @@ class CogVideoXResnetBlock3D(nn.Module):
         inputs: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         hidden_states = inputs
 
         if zq is not None:
-            hidden_states = self.norm1(hidden_states, zq)
+            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
         else:
             hidden_states = self.norm1(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
-        hidden_states = self.conv1(hidden_states)
+        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
 
         if temb is not None:
             hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
 
         if zq is not None:
-            hidden_states = self.norm2(hidden_states, zq)
+            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
         else:
             hidden_states = self.norm2(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states)
+        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
 
         if self.in_channels != self.out_channels:
-            inputs = self.conv_shortcut(inputs)
+            if self.use_conv_shortcut:
+                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
+                    inputs, conv_cache=conv_cache.get("conv_shortcut")
+                )
+            else:
+                inputs = self.conv_shortcut(inputs)
 
         hidden_states = hidden_states + inputs
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDownBlock3D(nn.Module):
@@ -392,8 +402,16 @@ class CogVideoXDownBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXDownBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -402,17 +420,23 @@ class CogVideoXDownBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXMidBlock3D(nn.Module):
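Gradient checkpointing keeps working here because `torch.utils.checkpoint.checkpoint` passes through whatever the wrapped callable returns, tuples included, and keyword arguments such as `conv_cache=` are forwarded to the callable in non-reentrant mode. A minimal sketch (toy block; assumes a PyTorch version with `use_reentrant=False` support):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

lin = nn.Linear(8, 8)

def block(x, conv_cache=None):
    # Stand-in for a resnet that returns (output, fresh cache).
    y = torch.relu(lin(x if conv_cache is None else x + conv_cache))
    return y, y.detach().clone()

x = torch.randn(2, 8, requires_grad=True)
y1, cache = checkpoint(block, x, use_reentrant=False)                     # first chunk
y2, cache = checkpoint(block, x, conv_cache=cache, use_reentrant=False)  # kwarg is forwarded
(y1.sum() + y2.sum()).backward()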
@@ -480,8 +504,16 @@ class CogVideoXMidBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXMidBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -490,13 +522,15 @@ class CogVideoXMidBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXUpBlock3D(nn.Module):
@@ -584,9 +618,16 @@ class CogVideoXUpBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
         r"""Forward method of the `CogVideoXUpBlock3D` class."""
-        for resnet in self.resnets:
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -595,17 +636,23 @@ class CogVideoXUpBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
                 hidden_states = upsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXEncoder3D(nn.Module):
@@ -705,9 +752,18 @@ class CogVideoXEncoder3D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXEncoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if self.training and self.gradient_checkpointing:
@@ -718,28 +774,44 @@ class CogVideoXEncoder3D(nn.Module):
                 return custom_forward
 
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block), hidden_states, temb, None
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block),
+                    hidden_states,
+                    temb,
+                    None,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
 
             # 2. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, None
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                None,
+                conv_cache=conv_cache.get("mid_block"),
             )
         else:
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = down_block(hidden_states, temb, None)
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = down_block(
+                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
             # 2. Mid
-            hidden_states = self.mid_block(hidden_states, temb, None)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
+            )
 
         # 3. Post-process
         hidden_states = self.norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDecoder3D(nn.Module):
@@ -846,9 +918,18 @@ class CogVideoXDecoder3D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXDecoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if self.training and self.gradient_checkpointing:
@@ -859,28 +940,45 @@ class CogVideoXDecoder3D(nn.Module):
                 return custom_forward
 
             # 1. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, sample
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                sample,
+                conv_cache=conv_cache.get("mid_block"),
            )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block), hidden_states, temb, sample
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    hidden_states,
+                    temb,
+                    sample,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
         else:
             # 1. Mid
-            hidden_states = self.mid_block(hidden_states, temb, sample)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
+            )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = up_block(hidden_states, temb, sample)
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = up_block(
+                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         # 3. Post-process
-        hidden_states = self.norm_out(hidden_states, sample)
+        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
+            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
+        )
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@@ -1019,12 +1117,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
             module.gradient_checkpointing = value
 
-    def _clear_fake_context_parallel_cache(self):
-        for name, module in self.named_modules():
-            if isinstance(module, CogVideoXCausalConv3d):
-                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
-                module._clear_fake_context_parallel_cache()
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -1091,20 +1183,20 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         frame_batch_size = self.num_sample_frames_batch_size
         # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
         num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+        conv_cache = None
         enc = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
             x_intermediate = x[:, :, start_frame:end_frame]
-            x_intermediate = self.encoder(x_intermediate)
+            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
             if self.quant_conv is not None:
                 x_intermediate = self.quant_conv(x_intermediate)
             enc.append(x_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         enc = torch.cat(enc, dim=2)
-
         return enc
 
     @apply_forward_hook
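The batching arithmetic expects `1`, `frame_batch_size * k`, or `frame_batch_size * k + 1` frames; the first batch absorbs the odd leading frame and every later batch shifts by `remaining_frames`. Worked through with assumed values:

frame_batch_size = 8  # assumed; the real value comes from the model config
for num_frames in (1, 16, 17):
    num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
    remaining_frames = num_frames % frame_batch_size
    spans = []
    for i in range(num_batches):
        start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end_frame = frame_batch_size * (i + 1) + remaining_frames
        spans.append((start_frame, end_frame))
    print(num_frames, spans)
# 1  -> [(0, 9)]          (slicing clamps to the single available frame)
# 16 -> [(0, 8), (8, 16)]
# 17 -> [(0, 9), (9, 17)]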
@@ -1143,7 +1235,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         frame_batch_size = self.num_latent_frames_batch_size
         num_batches = num_frames // frame_batch_size
+        conv_cache = None
         dec = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
@@ -1151,10 +1245,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             z_intermediate = z[:, :, start_frame:end_frame]
             if self.post_quant_conv is not None:
                 z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate = self.decoder(z_intermediate)
+            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
             dec.append(z_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         dec = torch.cat(dec, dim=2)
 
         if not return_dict:
@@ -1238,7 +1331,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             for j in range(0, width, overlap_width):
                 # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
                 num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1250,11 +1345,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                         i : i + self.tile_sample_min_height,
                         j : j + self.tile_sample_min_width,
                     ]
-                    tile = self.encoder(tile)
+                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                     if self.quant_conv is not None:
                         tile = self.quant_conv(tile)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
 
@@ -1315,7 +1410,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
num_batches = num_frames // frame_batch_size
|
||||
conv_cache = None
|
||||
time = []
|
||||
|
||||
for k in range(num_batches):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
||||
@@ -1329,9 +1426,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                     ]
                     if self.post_quant_conv is not None:
                         tile = self.post_quant_conv(tile)
-                    tile = self.decoder(tile)
+                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
 
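None of this changes the public API: the `encode`/`decode` signatures and outputs are as before, only the cache plumbing moved inside the frame-batch loops. A usage sketch against the refactored VAE (checkpoint id, dtype, and sizes assumed from the CogVideoX-2b release):

import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_tiling()  # tiled encode/decode still works; the cache is now loop-local

video = torch.randn(1, 3, 9, 256, 256, dtype=torch.float16, device="cuda")  # 9 = 8*k + 1
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    recon = vae.decode(latents).sample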