mirror of https://github.com/huggingface/diffusers.git
commit 13ac40ed8e
parent ebe683432f
Author: patil-suraj
Date:   2022-06-30 12:21:04 +02:00

2 changed files with 7 additions and 9 deletions

@@ -603,7 +603,7 @@ class ResnetBlockBigGANpp(nn.Module):
 
         self.Dropout_0 = nn.Dropout(dropout)
         self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
         if in_ch != out_ch or up or down:
-            #1x1 convolution with DDPM initialization.
+            # 1x1 convolution with DDPM initialization.
             self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
         self.skip_rescale = skip_rescale
@@ -757,9 +757,7 @@ class RearrangeDim(nn.Module):
 
 def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
     """nXn convolution with DDPM initialization."""
-    conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
-    )
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
     conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv
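
Both files define the same conv2d helper, whose weights come from a variance_scaling initializer defined elsewhere in the source. For readers without the full file, here is a minimal sketch of a DDPM-style variance-scaling initializer consistent with the call pattern variance_scaling(init_scale)(conv.weight.data.shape) above; the defaults, fan computation, and distribution handling are assumptions, not the repository's exact code.

import numpy as np
import torch


def variance_scaling(scale=1.0, mode="fan_avg", distribution="uniform"):
    """Return an init function mapping a weight shape to an initialized tensor.

    Only the call pattern variance_scaling(init_scale)(shape) is taken from the
    diff above; everything else here is an assumption.
    """
    scale = max(scale, 1e-10)  # floor so init_scale=0 still yields a (near-zero) init

    def init(shape, dtype=torch.float32, device="cpu"):
        # Conv weights are shaped (out_ch, in_ch, kh, kw).
        receptive_field = int(np.prod(shape[2:])) if len(shape) > 2 else 1
        fan_in = shape[1] * receptive_field
        fan_out = shape[0] * receptive_field
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        else:  # "fan_avg"
            denominator = (fan_in + fan_out) / 2
        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * variance ** 0.5
        # Uniform on (-bound, bound) has variance bound**2 / 3.
        bound = (3 * variance) ** 0.5
        return (torch.rand(*shape, dtype=dtype, device=device) * 2 - 1) * bound

    return init

Because conv2d also zeroes the bias, passing init_scale=0 (a common DDPM convention for output and residual layers) makes the layer's initial output essentially zero under this sketch.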

@@ -289,9 +289,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
 
 def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
     """nXn convolution with DDPM initialization."""
-    conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
-    )
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
     conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv
@@ -336,7 +334,7 @@ class Combine(nn.Module):
 
     def __init__(self, dim1, dim2, method="cat"):
         super().__init__()
-        #1x1 convolution with DDPM initialization.
+        # 1x1 convolution with DDPM initialization.
         self.Conv_0 = conv2d(dim1, dim2, kernel_size=1, padding=0)
         self.method = method
 
@@ -602,7 +600,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
             else:
                 if progressive == "output_skip":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1))
+                    modules.append(
+                        conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
+                    )
                     pyramid_ch = channels
                 elif progressive == "residual":
                     modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
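
A hypothetical usage sketch mirroring the call sites in this diff; the channel sizes are made up, and init_scale=0.0 for the output conv follows the usual DDPM/score-SDE convention rather than anything stated in the diff:

# 1x1 projection, as in ResnetBlockBigGANpp.Conv_2 and Combine.Conv_0 above
skip_proj = conv2d(64, 128, kernel_size=1, padding=0)

# 3x3 output conv, as in the NCSNpp "output_skip" branch above
to_rgb = conv2d(128, 3, bias=True, init_scale=0.0, kernel_size=3, padding=1)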