mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge pull request #52 from huggingface/clean-unet-sde
Clean UNetNCSNpp
This commit is contained in:
@@ -579,7 +579,6 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
up=False,
|
||||
down=False,
|
||||
dropout=0.1,
|
||||
fir=False,
|
||||
fir_kernel=(1, 3, 3, 1),
|
||||
skip_rescale=True,
|
||||
init_scale=0.0,
|
||||
@@ -590,20 +589,20 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
|
||||
self.up = up
|
||||
self.down = down
|
||||
self.fir = fir
|
||||
self.fir_kernel = fir_kernel
|
||||
|
||||
self.Conv_0 = conv3x3(in_ch, out_ch)
|
||||
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
|
||||
if temb_dim is not None:
|
||||
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
||||
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
|
||||
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
|
||||
nn.init.zeros_(self.Dense_0.bias)
|
||||
|
||||
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
|
||||
self.Dropout_0 = nn.Dropout(dropout)
|
||||
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
|
||||
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
|
||||
if in_ch != out_ch or up or down:
|
||||
self.Conv_2 = conv1x1(in_ch, out_ch)
|
||||
# 1x1 convolution with DDPM initialization.
|
||||
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
|
||||
|
||||
self.skip_rescale = skip_rescale
|
||||
self.act = act
|
||||
@@ -614,19 +613,11 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
h = self.act(self.GroupNorm_0(x))
|
||||
|
||||
if self.up:
|
||||
if self.fir:
|
||||
h = upsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
h = naive_upsample_2d(h, factor=2)
|
||||
x = naive_upsample_2d(x, factor=2)
|
||||
h = upsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
elif self.down:
|
||||
if self.fir:
|
||||
h = downsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
h = naive_downsample_2d(h, factor=2)
|
||||
x = naive_downsample_2d(x, factor=2)
|
||||
h = downsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
|
||||
h = self.Conv_0(h)
|
||||
# Add bias to each feature map conditioned on the time embedding
|
||||
@@ -645,62 +636,6 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
return (x + h) / np.sqrt(2.0)
|
||||
|
||||
|
||||
# unet_score_estimation.py
|
||||
class ResnetBlockDDPMpp(nn.Module):
|
||||
"""ResBlock adapted from DDPM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
act,
|
||||
in_ch,
|
||||
out_ch=None,
|
||||
temb_dim=None,
|
||||
conv_shortcut=False,
|
||||
dropout=0.1,
|
||||
skip_rescale=False,
|
||||
init_scale=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
out_ch = out_ch if out_ch else in_ch
|
||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
|
||||
self.Conv_0 = conv3x3(in_ch, out_ch)
|
||||
if temb_dim is not None:
|
||||
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
||||
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
|
||||
nn.init.zeros_(self.Dense_0.bias)
|
||||
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
|
||||
self.Dropout_0 = nn.Dropout(dropout)
|
||||
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
|
||||
if in_ch != out_ch:
|
||||
if conv_shortcut:
|
||||
self.Conv_2 = conv3x3(in_ch, out_ch)
|
||||
else:
|
||||
self.NIN_0 = NIN(in_ch, out_ch)
|
||||
|
||||
self.skip_rescale = skip_rescale
|
||||
self.act = act
|
||||
self.out_ch = out_ch
|
||||
self.conv_shortcut = conv_shortcut
|
||||
|
||||
def forward(self, x, temb=None):
|
||||
h = self.act(self.GroupNorm_0(x))
|
||||
h = self.Conv_0(h)
|
||||
if temb is not None:
|
||||
h += self.Dense_0(self.act(temb))[:, :, None, None]
|
||||
h = self.act(self.GroupNorm_1(h))
|
||||
h = self.Dropout_0(h)
|
||||
h = self.Conv_1(h)
|
||||
if x.shape[1] != self.out_ch:
|
||||
if self.conv_shortcut:
|
||||
x = self.Conv_2(x)
|
||||
else:
|
||||
x = self.NIN_0(x)
|
||||
if not self.skip_rescale:
|
||||
return x + h
|
||||
else:
|
||||
return (x + h) / np.sqrt(2.0)
|
||||
|
||||
|
||||
# unet_rl.py
|
||||
class ResidualTemporalBlock(nn.Module):
|
||||
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
|
||||
@@ -818,32 +753,17 @@ class RearrangeDim(nn.Module):
|
||||
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
|
||||
"""1x1 convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
|
||||
"""nXn convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
|
||||
conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
|
||||
"""3x3 convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(
|
||||
in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
|
||||
)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def default_init(scale=1.0):
|
||||
"""The same initialization used in DDPM."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
return variance_scaling(scale, "fan_avg", "uniform")
|
||||
|
||||
|
||||
def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
"""Ported from JAX."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
|
||||
def _compute_fans(shape, in_axis=1, out_axis=0):
|
||||
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
|
||||
@@ -853,21 +773,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
|
||||
|
||||
def init(shape, dtype=dtype, device=device):
|
||||
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
|
||||
if mode == "fan_in":
|
||||
denominator = fan_in
|
||||
elif mode == "fan_out":
|
||||
denominator = fan_out
|
||||
elif mode == "fan_avg":
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
else:
|
||||
raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
variance = scale / denominator
|
||||
if distribution == "normal":
|
||||
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
|
||||
elif distribution == "uniform":
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
|
||||
else:
|
||||
raise ValueError("invalid distribution for variance scaling initializer")
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
|
||||
|
||||
return init
|
||||
|
||||
@@ -965,31 +873,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
|
||||
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
|
||||
|
||||
|
||||
def naive_upsample_2d(x, factor=2):
|
||||
_N, C, H, W = x.shape
|
||||
x = torch.reshape(x, (-1, C, H, 1, W, 1))
|
||||
x = x.repeat(1, 1, 1, factor, 1, factor)
|
||||
return torch.reshape(x, (-1, C, H * factor, W * factor))
|
||||
|
||||
|
||||
def naive_downsample_2d(x, factor=2):
|
||||
_N, C, H, W = x.shape
|
||||
x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
|
||||
return torch.mean(x, dim=(3, 5))
|
||||
|
||||
|
||||
class NIN(nn.Module):
|
||||
def __init__(self, in_dim, num_units, init_scale=0.1):
|
||||
super().__init__()
|
||||
self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
|
||||
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
y = contract_inner(x, self.W) + self.b
|
||||
return y.permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
def _setup_kernel(k):
|
||||
k = np.asarray(k, dtype=np.float32)
|
||||
if k.ndim == 1:
|
||||
@@ -998,17 +881,3 @@ def _setup_kernel(k):
|
||||
assert k.ndim == 2
|
||||
assert k.shape[0] == k.shape[1]
|
||||
return k
|
||||
|
||||
|
||||
def contract_inner(x, y):
|
||||
"""tensordot(x, y, 1)."""
|
||||
x_chars = list(string.ascii_lowercase[: len(x.shape)])
|
||||
y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
|
||||
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
|
||||
out_chars = x_chars[:-1] + y_chars[1:]
|
||||
return _einsum(x_chars, y_chars, out_chars, x, y)
|
||||
|
||||
|
||||
def _einsum(a, b, c, x, y):
|
||||
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
|
||||
return torch.einsum(einsum_str, x, y)
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
|
||||
import functools
|
||||
import math
|
||||
import string
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -28,116 +27,21 @@ from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import GaussianFourierProjection, get_timestep_embedding
|
||||
from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp
|
||||
from .resnet import ResnetBlockBigGANpp, downsample_2d, upfirdn2d, upsample_2d
|
||||
|
||||
|
||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
||||
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
|
||||
def _setup_kernel(k):
|
||||
k = np.asarray(k, dtype=np.float32)
|
||||
if k.ndim == 1:
|
||||
k = np.outer(k, k)
|
||||
k /= np.sum(k)
|
||||
assert k.ndim == 2
|
||||
assert k.shape[0] == k.shape[1]
|
||||
return k
|
||||
|
||||
|
||||
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
|
||||
_, channel, in_h, in_w = input.shape
|
||||
input = input.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
_, in_h, in_w, minor = input.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
||||
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
||||
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
||||
|
||||
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
||||
out = out[
|
||||
:,
|
||||
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
||||
:,
|
||||
]
|
||||
|
||||
out = out.permute(0, 3, 1, 2)
|
||||
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(
|
||||
-1,
|
||||
minor,
|
||||
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
||||
)
|
||||
out = out.permute(0, 2, 3, 1)
|
||||
out = out[:, ::down_y, ::down_x, :]
|
||||
|
||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
||||
|
||||
return out.view(-1, channel, out_h, out_w)
|
||||
|
||||
|
||||
# Function ported from StyleGAN2
|
||||
def get_weight(module, shape, weight_var="weight", kernel_init=None):
|
||||
"""Get/create weight tensor for a convolution or fully-connected layer."""
|
||||
|
||||
return module.param(weight_var, kernel_init, shape)
|
||||
|
||||
|
||||
class Conv2d(nn.Module):
|
||||
"""Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel,
|
||||
up=False,
|
||||
down=False,
|
||||
resample_kernel=(1, 3, 3, 1),
|
||||
use_bias=True,
|
||||
kernel_init=None,
|
||||
):
|
||||
super().__init__()
|
||||
assert not (up and down)
|
||||
assert kernel >= 1 and kernel % 2 == 1
|
||||
self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
|
||||
if kernel_init is not None:
|
||||
self.weight.data = kernel_init(self.weight.data.shape)
|
||||
if use_bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_ch))
|
||||
|
||||
self.up = up
|
||||
self.down = down
|
||||
self.resample_kernel = resample_kernel
|
||||
self.kernel = kernel
|
||||
self.use_bias = use_bias
|
||||
|
||||
def forward(self, x):
|
||||
if self.up:
|
||||
x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
|
||||
elif self.down:
|
||||
x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
|
||||
else:
|
||||
x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
|
||||
|
||||
if self.use_bias:
|
||||
x = x + self.bias.reshape(1, -1, 1, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def naive_upsample_2d(x, factor=2):
|
||||
_N, C, H, W = x.shape
|
||||
x = torch.reshape(x, (-1, C, H, 1, W, 1))
|
||||
x = x.repeat(1, 1, 1, factor, 1, factor)
|
||||
return torch.reshape(x, (-1, C, H * factor, W * factor))
|
||||
|
||||
|
||||
def naive_downsample_2d(x, factor=2):
|
||||
_N, C, H, W = x.shape
|
||||
x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
|
||||
return torch.mean(x, dim=(3, 5))
|
||||
|
||||
|
||||
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
|
||||
"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
|
||||
def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
|
||||
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
||||
|
||||
Args:
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
@@ -176,13 +80,13 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
|
||||
|
||||
# Determine data dimensions.
|
||||
stride = [1, 1, factor, factor]
|
||||
output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
|
||||
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
|
||||
output_padding = (
|
||||
output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
|
||||
output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW,
|
||||
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
|
||||
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
|
||||
)
|
||||
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
||||
num_groups = _shape(x, 1) // inC
|
||||
num_groups = x.shape[1] // inC
|
||||
|
||||
# Transpose weights.
|
||||
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
|
||||
@@ -190,21 +94,12 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
|
||||
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
|
||||
|
||||
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
|
||||
# Original TF code.
|
||||
# x = tf.nn.conv2d_transpose(
|
||||
# x,
|
||||
# w,
|
||||
# output_shape=output_shape,
|
||||
# strides=stride,
|
||||
# padding='VALID',
|
||||
# data_format=data_format)
|
||||
# JAX equivalent
|
||||
|
||||
return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
|
||||
|
||||
|
||||
def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
|
||||
"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
|
||||
def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
|
||||
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
||||
|
||||
Args:
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
@@ -235,138 +130,9 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
|
||||
return F.conv2d(x, w, stride=s, padding=0)
|
||||
|
||||
|
||||
def _setup_kernel(k):
|
||||
k = np.asarray(k, dtype=np.float32)
|
||||
if k.ndim == 1:
|
||||
k = np.outer(k, k)
|
||||
k /= np.sum(k)
|
||||
assert k.ndim == 2
|
||||
assert k.shape[0] == k.shape[1]
|
||||
return k
|
||||
|
||||
|
||||
def _shape(x, dim):
|
||||
return x.shape[dim]
|
||||
|
||||
|
||||
def upsample_2d(x, k=None, factor=2, gain=1):
|
||||
r"""Upsample a batch of 2D images with the given filter.
|
||||
|
||||
Args:
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
||||
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
||||
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
|
||||
multiple of the upsampling factor.
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
C]`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
||||
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]`
|
||||
"""
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if k is None:
|
||||
k = [1] * factor
|
||||
k = _setup_kernel(k) * (gain * (factor**2))
|
||||
p = k.shape[0] - factor
|
||||
return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
|
||||
|
||||
|
||||
def downsample_2d(x, k=None, factor=2, gain=1):
|
||||
r"""Downsample a batch of 2D images with the given filter.
|
||||
|
||||
Args:
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
||||
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
||||
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
||||
shape is a multiple of the downsampling factor.
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
C]`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to average pooling.
|
||||
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]`
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if k is None:
|
||||
k = [1] * factor
|
||||
k = _setup_kernel(k) * gain
|
||||
p = k.shape[0] - factor
|
||||
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
|
||||
"""1x1 convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
|
||||
"""3x3 convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(
|
||||
in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
|
||||
)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def _einsum(a, b, c, x, y):
|
||||
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
|
||||
return torch.einsum(einsum_str, x, y)
|
||||
|
||||
|
||||
def contract_inner(x, y):
|
||||
"""tensordot(x, y, 1)."""
|
||||
x_chars = list(string.ascii_lowercase[: len(x.shape)])
|
||||
y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
|
||||
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
|
||||
out_chars = x_chars[:-1] + y_chars[1:]
|
||||
return _einsum(x_chars, y_chars, out_chars, x, y)
|
||||
|
||||
|
||||
class NIN(nn.Module):
|
||||
def __init__(self, in_dim, num_units, init_scale=0.1):
|
||||
super().__init__()
|
||||
self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
|
||||
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
y = contract_inner(x, self.W) + self.b
|
||||
return y.permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
def get_act(nonlinearity):
|
||||
"""Get activation functions from the config file."""
|
||||
|
||||
if nonlinearity.lower() == "elu":
|
||||
return nn.ELU()
|
||||
elif nonlinearity.lower() == "relu":
|
||||
return nn.ReLU()
|
||||
elif nonlinearity.lower() == "lrelu":
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif nonlinearity.lower() == "swish":
|
||||
return nn.SiLU()
|
||||
else:
|
||||
raise NotImplementedError("activation function does not exist!")
|
||||
|
||||
|
||||
def default_init(scale=1.0):
|
||||
"""The same initialization used in DDPM."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
return variance_scaling(scale, "fan_avg", "uniform")
|
||||
|
||||
|
||||
def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
"""Ported from JAX."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
|
||||
def _compute_fans(shape, in_axis=1, out_axis=0):
|
||||
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
|
||||
@@ -376,31 +142,35 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
|
||||
|
||||
def init(shape, dtype=dtype, device=device):
|
||||
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
|
||||
if mode == "fan_in":
|
||||
denominator = fan_in
|
||||
elif mode == "fan_out":
|
||||
denominator = fan_out
|
||||
elif mode == "fan_avg":
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
else:
|
||||
raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
variance = scale / denominator
|
||||
if distribution == "normal":
|
||||
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
|
||||
elif distribution == "uniform":
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
|
||||
else:
|
||||
raise ValueError("invalid distribution for variance scaling initializer")
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
|
||||
|
||||
return init
|
||||
|
||||
|
||||
def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
|
||||
"""nXn convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
|
||||
conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def Linear(dim_in, dim_out):
|
||||
linear = nn.Linear(dim_in, dim_out)
|
||||
linear.weight.data = _variance_scaling()(linear.weight.shape)
|
||||
nn.init.zeros_(linear.bias)
|
||||
return linear
|
||||
|
||||
|
||||
class Combine(nn.Module):
|
||||
"""Combine information from skip connections."""
|
||||
|
||||
def __init__(self, dim1, dim2, method="cat"):
|
||||
super().__init__()
|
||||
self.Conv_0 = conv1x1(dim1, dim2)
|
||||
# 1x1 convolution with DDPM initialization.
|
||||
self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0)
|
||||
self.method = method
|
||||
|
||||
def forward(self, x, y):
|
||||
@@ -413,80 +183,40 @@ class Combine(nn.Module):
|
||||
raise ValueError(f"Method {self.method} not recognized.")
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
||||
class FirUpsample(nn.Module):
|
||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
super().__init__()
|
||||
out_ch = out_ch if out_ch else in_ch
|
||||
if not fir:
|
||||
if with_conv:
|
||||
self.Conv_0 = conv3x3(in_ch, out_ch)
|
||||
else:
|
||||
if with_conv:
|
||||
self.Conv2d_0 = Conv2d(
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel=3,
|
||||
up=True,
|
||||
resample_kernel=fir_kernel,
|
||||
use_bias=True,
|
||||
kernel_init=default_init(),
|
||||
)
|
||||
self.fir = fir
|
||||
if with_conv:
|
||||
self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
self.with_conv = with_conv
|
||||
self.fir_kernel = fir_kernel
|
||||
self.out_ch = out_ch
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
if not self.fir:
|
||||
h = F.interpolate(x, (H * 2, W * 2), "nearest")
|
||||
if self.with_conv:
|
||||
h = self.Conv_0(h)
|
||||
if self.with_conv:
|
||||
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
|
||||
else:
|
||||
if not self.with_conv:
|
||||
h = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
h = self.Conv2d_0(x)
|
||||
h = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
||||
class FirDownsample(nn.Module):
|
||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
super().__init__()
|
||||
out_ch = out_ch if out_ch else in_ch
|
||||
if not fir:
|
||||
if with_conv:
|
||||
self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
|
||||
else:
|
||||
if with_conv:
|
||||
self.Conv2d_0 = Conv2d(
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel=3,
|
||||
down=True,
|
||||
resample_kernel=fir_kernel,
|
||||
use_bias=True,
|
||||
kernel_init=default_init(),
|
||||
)
|
||||
self.fir = fir
|
||||
if with_conv:
|
||||
self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
self.fir_kernel = fir_kernel
|
||||
self.with_conv = with_conv
|
||||
self.out_ch = out_ch
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
if not self.fir:
|
||||
if self.with_conv:
|
||||
x = F.pad(x, (0, 1, 0, 1))
|
||||
x = self.Conv_0(x)
|
||||
else:
|
||||
x = F.avg_pool2d(x, 2, stride=2)
|
||||
if self.with_conv:
|
||||
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
|
||||
else:
|
||||
if not self.with_conv:
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
x = self.Conv2d_0(x)
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
|
||||
return x
|
||||
|
||||
@@ -496,63 +226,52 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
centered=False,
|
||||
image_size=1024,
|
||||
num_channels=3,
|
||||
attention_type="ddpm",
|
||||
attn_resolutions=(16,),
|
||||
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
|
||||
conditional=True,
|
||||
conv_size=3,
|
||||
dropout=0.0,
|
||||
embedding_type="fourier",
|
||||
fir=True,
|
||||
fir=True, # TODO (patil-suraj) remove this option from here and pre-trained model configs
|
||||
fir_kernel=(1, 3, 3, 1),
|
||||
fourier_scale=16,
|
||||
init_scale=0.0,
|
||||
nf=16,
|
||||
nonlinearity="swish",
|
||||
normalization="GroupNorm",
|
||||
num_res_blocks=1,
|
||||
progressive="output_skip",
|
||||
progressive_combine="sum",
|
||||
progressive_input="input_skip",
|
||||
resamp_with_conv=True,
|
||||
resblock_type="biggan",
|
||||
scale_by_sigma=True,
|
||||
skip_rescale=True,
|
||||
continuous=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_to_config(
|
||||
centered=centered,
|
||||
image_size=image_size,
|
||||
num_channels=num_channels,
|
||||
attention_type=attention_type,
|
||||
attn_resolutions=attn_resolutions,
|
||||
ch_mult=ch_mult,
|
||||
conditional=conditional,
|
||||
conv_size=conv_size,
|
||||
dropout=dropout,
|
||||
embedding_type=embedding_type,
|
||||
fir=fir,
|
||||
fir_kernel=fir_kernel,
|
||||
fourier_scale=fourier_scale,
|
||||
init_scale=init_scale,
|
||||
nf=nf,
|
||||
nonlinearity=nonlinearity,
|
||||
normalization=normalization,
|
||||
num_res_blocks=num_res_blocks,
|
||||
progressive=progressive,
|
||||
progressive_combine=progressive_combine,
|
||||
progressive_input=progressive_input,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
resblock_type=resblock_type,
|
||||
scale_by_sigma=scale_by_sigma,
|
||||
skip_rescale=skip_rescale,
|
||||
continuous=continuous,
|
||||
)
|
||||
self.act = act = get_act(nonlinearity)
|
||||
self.act = act = nn.SiLU()
|
||||
|
||||
self.nf = nf
|
||||
self.num_res_blocks = num_res_blocks
|
||||
@@ -562,7 +281,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
|
||||
self.conditional = conditional
|
||||
self.skip_rescale = skip_rescale
|
||||
self.resblock_type = resblock_type
|
||||
self.progressive = progressive
|
||||
self.progressive_input = progressive_input
|
||||
self.embedding_type = embedding_type
|
||||
@@ -585,53 +303,33 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
else:
|
||||
raise ValueError(f"embedding type {embedding_type} unknown.")
|
||||
|
||||
if conditional:
|
||||
modules.append(nn.Linear(embed_dim, nf * 4))
|
||||
modules[-1].weight.data = default_init()(modules[-1].weight.shape)
|
||||
nn.init.zeros_(modules[-1].bias)
|
||||
modules.append(nn.Linear(nf * 4, nf * 4))
|
||||
modules[-1].weight.data = default_init()(modules[-1].weight.shape)
|
||||
nn.init.zeros_(modules[-1].bias)
|
||||
modules.append(Linear(embed_dim, nf * 4))
|
||||
modules.append(Linear(nf * 4, nf * 4))
|
||||
|
||||
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
|
||||
Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
|
||||
Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
|
||||
|
||||
if progressive == "output_skip":
|
||||
self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
|
||||
self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
|
||||
elif progressive == "residual":
|
||||
pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
|
||||
pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)
|
||||
|
||||
Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
|
||||
Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
|
||||
|
||||
if progressive_input == "input_skip":
|
||||
self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
|
||||
self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
|
||||
elif progressive_input == "residual":
|
||||
pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
|
||||
pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
|
||||
|
||||
if resblock_type == "ddpm":
|
||||
ResnetBlock = functools.partial(
|
||||
ResnetBlockDDPMpp,
|
||||
act=act,
|
||||
dropout=dropout,
|
||||
init_scale=init_scale,
|
||||
skip_rescale=skip_rescale,
|
||||
temb_dim=nf * 4,
|
||||
)
|
||||
|
||||
elif resblock_type == "biggan":
|
||||
ResnetBlock = functools.partial(
|
||||
ResnetBlockBigGANpp,
|
||||
act=act,
|
||||
dropout=dropout,
|
||||
fir=fir,
|
||||
fir_kernel=fir_kernel,
|
||||
init_scale=init_scale,
|
||||
skip_rescale=skip_rescale,
|
||||
temb_dim=nf * 4,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"resblock type {resblock_type} unrecognized.")
|
||||
ResnetBlock = functools.partial(
|
||||
ResnetBlockBigGANpp,
|
||||
act=act,
|
||||
dropout=dropout,
|
||||
fir_kernel=fir_kernel,
|
||||
init_scale=init_scale,
|
||||
skip_rescale=skip_rescale,
|
||||
temb_dim=nf * 4,
|
||||
)
|
||||
|
||||
# Downsampling block
|
||||
|
||||
@@ -639,7 +337,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
if progressive_input != "none":
|
||||
input_pyramid_ch = channels
|
||||
|
||||
modules.append(conv3x3(channels, nf))
|
||||
modules.append(Conv2d(channels, nf, kernel_size=3, padding=1))
|
||||
hs_c = [nf]
|
||||
|
||||
in_ch = nf
|
||||
@@ -655,10 +353,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
hs_c.append(in_ch)
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
if resblock_type == "ddpm":
|
||||
modules.append(Downsample(in_ch=in_ch))
|
||||
else:
|
||||
modules.append(ResnetBlock(down=True, in_ch=in_ch))
|
||||
modules.append(ResnetBlock(down=True, in_ch=in_ch))
|
||||
|
||||
if progressive_input == "input_skip":
|
||||
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
|
||||
@@ -691,18 +386,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
if i_level == self.num_resolutions - 1:
|
||||
if progressive == "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
||||
modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
|
||||
pyramid_ch = channels
|
||||
elif progressive == "residual":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, in_ch, bias=True))
|
||||
modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
|
||||
pyramid_ch = in_ch
|
||||
else:
|
||||
raise ValueError(f"{progressive} is not a valid name.")
|
||||
else:
|
||||
if progressive == "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
|
||||
modules.append(
|
||||
Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
|
||||
)
|
||||
pyramid_ch = channels
|
||||
elif progressive == "residual":
|
||||
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
|
||||
@@ -711,16 +408,13 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
raise ValueError(f"{progressive} is not a valid name")
|
||||
|
||||
if i_level != 0:
|
||||
if resblock_type == "ddpm":
|
||||
modules.append(Upsample(in_ch=in_ch))
|
||||
else:
|
||||
modules.append(ResnetBlock(in_ch=in_ch, up=True))
|
||||
modules.append(ResnetBlock(in_ch=in_ch, up=True))
|
||||
|
||||
assert not hs_c
|
||||
|
||||
if progressive != "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
||||
modules.append(Conv2d(in_ch, channels, init_scale=init_scale))
|
||||
|
||||
self.all_modules = nn.ModuleList(modules)
|
||||
|
||||
@@ -751,9 +445,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
else:
|
||||
temb = None
|
||||
|
||||
if not self.config.centered:
|
||||
# If input data is in [0, 1]
|
||||
x = 2 * x - 1.0
|
||||
# If input data is in [0, 1]
|
||||
x = 2 * x - 1.0
|
||||
|
||||
# Downsampling block
|
||||
input_pyramid = None
|
||||
@@ -774,12 +467,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
hs.append(h)
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
if self.resblock_type == "ddpm":
|
||||
h = modules[m_idx](hs[-1])
|
||||
m_idx += 1
|
||||
else:
|
||||
h = modules[m_idx](hs[-1], temb)
|
||||
m_idx += 1
|
||||
h = modules[m_idx](hs[-1], temb)
|
||||
m_idx += 1
|
||||
|
||||
if self.progressive_input == "input_skip":
|
||||
input_pyramid = self.pyramid_downsample(input_pyramid)
|
||||
@@ -851,12 +540,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
raise ValueError(f"{self.progressive} is not a valid name")
|
||||
|
||||
if i_level != 0:
|
||||
if self.resblock_type == "ddpm":
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
else:
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
|
||||
assert not hs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user