mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
remove unnecessary call to F.pad (#10620)
* rewrite memory count without implicitly using dimensions by @ic-synth
* replace F.pad by built-in padding in Conv3D
* in-place sums to reduce memory allocations
* fixed trailing whitespace
* file reformatted
* in-place sums
* simpler in-place expressions
* removed in-place sum, may affect backward propagation logic
* removed in-place sum, may affect backward propagation logic
* removed in-place sum, may affect backward propagation logic
* reverted change
This commit is contained in:
@@ -105,6 +105,7 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
self.width_pad = width_pad
|
||||
self.time_pad = time_pad
|
||||
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
|
||||
self.const_padding_conv3d = (0, self.width_pad, self.height_pad)
|
||||
|
||||
self.temporal_dim = 2
|
||||
self.time_kernel_size = time_kernel_size
|
||||
@@ -117,6 +118,8 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
|
||||
def fake_context_parallel_forward(
|
||||
@@ -137,9 +140,7 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
if self.pad_mode == "replicate":
|
||||
conv_cache = None
|
||||
else:
|
||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
||||
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
||||
|
||||
output = self.conv(inputs)
|
||||
return output, conv_cache
|
||||
|
||||
Reference in New Issue
Block a user