1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add Downsample

This commit is contained in:
patil-suraj
2022-06-27 18:03:41 +02:00
parent ee010726ab
commit b9de7172ba

View File

@@ -103,7 +103,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -111,18 +111,29 @@ class Downsample(nn.Module):
self.dims = dims
self.padding = padding
stride = 2 if dims != 3 else (1, 2, 2)
self.name = name
if use_conv:
self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride)
conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
if name == "conv":
self.conv = conv
else:
self.op = conv
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0 and self.dims == 2:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
return self.down(x)
if self.name == "conv":
return self.conv(x)
else:
return self.op(x)
# TODO (patil-suraj): needs test