diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 8d87786991..34963251af 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -103,7 +103,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -111,18 +111,29 @@ class Downsample(nn.Module): self.dims = dims self.padding = padding stride = 2 if dims != 3 else (1, 2, 2) + self.name = name + if use_conv: - self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) + conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels - self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride) + conv = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + if name == "conv": + self.conv = conv + else: + self.op = conv def forward(self, x): assert x.shape[1] == self.channels if self.use_conv and self.padding == 0 and self.dims == 2: pad = (0, 1, 0, 1) x = F.pad(x, pad, mode="constant", value=0) - return self.down(x) + + if self.name == "conv": + return self.conv(x) + else: + return self.op(x) # TODO (patil-suraj): needs test