1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

better names and more cleanup

This commit is contained in:
patil-suraj
2022-06-30 13:26:05 +02:00
parent 639b861129
commit 3e2cff4da2

View File

@@ -40,12 +40,8 @@ def _setup_kernel(k):
return k
def _shape(x, dim):
return x.shape[dim]
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -84,13 +80,13 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
# Determine data dimensions.
stride = [1, 1, factor, factor]
output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_padding = (
output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW,
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
num_groups = _shape(x, 1) // inC
num_groups = x.shape[1] // inC
# Transpose weights.
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
@@ -98,21 +94,12 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
# Original TF code.
# x = tf.nn.conv2d_transpose(
# x,
# w,
# output_shape=output_shape,
# strides=stride,
# padding='VALID',
# data_format=data_format)
# JAX equivalent
return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -143,15 +130,7 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
return F.conv2d(x, w, stride=s, padding=0)
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
"""nXn convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX."""
scale = 1e-10 if scale == 0 else scale
@@ -170,13 +149,21 @@ def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, devi
return init
def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
"""nXn convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
class Combine(nn.Module):
"""Combine information from skip connections."""
def __init__(self, dim1, dim2, method="cat"):
super().__init__()
# 1x1 convolution with DDPM initialization.
self.Conv_0 = conv2d(dim1, dim2, kernel_size=1, padding=0)
self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0)
self.method = method
def forward(self, x, y):
@@ -189,38 +176,38 @@ class Combine(nn.Module):
raise ValueError(f"Method {self.method} not recognized.")
class Upsample(nn.Module):
class FirUpsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_ch = out_ch if out_ch else in_ch
if with_conv:
self.Conv2d_0 = conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
self.with_conv = with_conv
self.fir_kernel = fir_kernel
self.out_ch = out_ch
def forward(self, x):
if self.with_conv:
h = upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
else:
h = upsample_2d(x, self.fir_kernel, factor=2)
return h
class Downsample(nn.Module):
class FirDownsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_ch = out_ch if out_ch else in_ch
if with_conv:
self.Conv2d_0 = self.Conv2d_0 = conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.with_conv = with_conv
self.out_ch = out_ch
def forward(self, x):
if self.with_conv:
x = conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
else:
x = downsample_2d(x, self.fir_kernel, factor=2)
@@ -311,21 +298,21 @@ class NCSNpp(ModelMixin, ConfigMixin):
if conditional:
modules.append(nn.Linear(embed_dim, nf * 4))
modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape)
modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape)
nn.init.zeros_(modules[-1].bias)
modules.append(nn.Linear(nf * 4, nf * 4))
modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape)
modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape)
nn.init.zeros_(modules[-1].bias)
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
if progressive == "output_skip":
self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
elif progressive == "residual":
pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)
Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
if progressive_input == "input_skip":
self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
@@ -348,7 +335,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if progressive_input != "none":
input_pyramid_ch = channels
modules.append(conv2d(channels, nf, kernel_size=3, padding=1))
modules.append(Conv2d(channels, nf, kernel_size=3, padding=1))
hs_c = [nf]
in_ch = nf
@@ -397,11 +384,11 @@ class NCSNpp(ModelMixin, ConfigMixin):
if i_level == self.num_resolutions - 1:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
pyramid_ch = channels
elif progressive == "residual":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
pyramid_ch = in_ch
else:
raise ValueError(f"{progressive} is not a valid name.")
@@ -409,7 +396,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(
conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
)
pyramid_ch = channels
elif progressive == "residual":
@@ -425,7 +412,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if progressive != "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(conv2d(in_ch, channels, init_scale=init_scale))
modules.append(Conv2d(in_ch, channels, init_scale=init_scale))
self.all_modules = nn.ModuleList(modules)