mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
add Downsample
This commit is contained in:
@@ -103,7 +103,7 @@ class Downsample(nn.Module):
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
||||
def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
@@ -111,18 +111,29 @@ class Downsample(nn.Module):
|
||||
self.dims = dims
|
||||
self.padding = padding
|
||||
stride = 2 if dims != 3 else (1, 2, 2)
|
||||
self.name = name
|
||||
|
||||
if use_conv:
|
||||
self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
||||
conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
||||
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.op = conv
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.use_conv and self.padding == 0 and self.dims == 2:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
return self.down(x)
|
||||
|
||||
if self.name == "conv":
|
||||
return self.conv(x)
|
||||
else:
|
||||
return self.op(x)
|
||||
|
||||
|
||||
# TODO (patil-suraj): needs test
|
||||
|
||||
Reference in New Issue
Block a user