mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge branch 'main' of https://github.com/huggingface/diffusers into conversion-scripts
This commit is contained in:
2
setup.py
2
setup.py
@@ -88,7 +88,7 @@ _deps = [
|
||||
"requests",
|
||||
"torch>=1.4",
|
||||
"tensorboard",
|
||||
"modelcards=0.1.4"
|
||||
"modelcards==0.1.4"
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
|
||||
@@ -14,4 +14,5 @@ deps = {
|
||||
"requests": "requests",
|
||||
"torch": "torch>=1.4",
|
||||
"tensorboard": "tensorboard",
|
||||
"modelcards": "modelcards==0.1.4",
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import string
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -79,18 +79,25 @@ class Upsample(nn.Module):
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
|
||||
conv = None
|
||||
if use_conv_transpose:
|
||||
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
|
||||
conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
||||
conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.Conv2d_0 = conv
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
@@ -103,7 +110,10 @@ class Upsample(nn.Module):
|
||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
if self.name == "conv":
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = self.Conv2d_0(x)
|
||||
|
||||
return x
|
||||
|
||||
@@ -135,6 +145,8 @@ class Downsample(nn.Module):
|
||||
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.Conv2d_0 = conv
|
||||
else:
|
||||
self.op = conv
|
||||
|
||||
@@ -146,6 +158,8 @@ class Downsample(nn.Module):
|
||||
|
||||
if self.name == "conv":
|
||||
return self.conv(x)
|
||||
elif self.name == "Conv2d_0":
|
||||
return self.Conv2d_0(x)
|
||||
else:
|
||||
return self.op(x)
|
||||
|
||||
@@ -162,110 +176,7 @@ class Downsample(nn.Module):
|
||||
|
||||
# RESNETS
|
||||
|
||||
# unet_glide.py & unet_ldm.py
|
||||
class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
|
||||
:param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
|
||||
:param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
|
||||
use_conv: if True and out_channels is specified, use a spatial
|
||||
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
|
||||
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
|
||||
downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=None,
|
||||
use_conv=False,
|
||||
use_scale_shift_norm=False,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
up=False,
|
||||
down=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.emb_channels = emb_channels
|
||||
self.dropout = dropout
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels, swish=1.0),
|
||||
nn.Identity(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
|
||||
self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
|
||||
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
Apply the block to a Tensor, conditioned on a timestep embedding.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
h = in_rest(x)
|
||||
h = self.h_upd(h)
|
||||
x = self.x_upd(x)
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
if self.use_scale_shift_norm:
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
||||
h = out_norm(h) * (1 + scale) + shift
|
||||
h = out_rest(h)
|
||||
else:
|
||||
h = h + emb_out
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
# unet.py and unet_grad_tts.py
|
||||
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -279,7 +190,12 @@ class ResnetBlock(nn.Module):
|
||||
pre_norm=True,
|
||||
eps=1e-6,
|
||||
non_linearity="swish",
|
||||
time_embedding_norm="default",
|
||||
up=False,
|
||||
down=False,
|
||||
overwrite_for_grad_tts=False,
|
||||
overwrite_for_ldm=False,
|
||||
overwrite_for_glide=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.pre_norm = pre_norm
|
||||
@@ -287,6 +203,9 @@ class ResnetBlock(nn.Module):
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.up = up
|
||||
self.down = down
|
||||
|
||||
if self.pre_norm:
|
||||
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
|
||||
@@ -294,23 +213,38 @@ class ResnetBlock(nn.Module):
|
||||
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
|
||||
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
|
||||
if time_embedding_norm == "default":
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
elif time_embedding_norm == "scale_shift":
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
|
||||
|
||||
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = nonlinearity
|
||||
elif non_linearity == "mish":
|
||||
self.nonlinearity = Mish()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
|
||||
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
|
||||
elif down:
|
||||
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
|
||||
self.is_overwritten = False
|
||||
self.overwrite_for_glide = overwrite_for_glide
|
||||
self.overwrite_for_grad_tts = overwrite_for_grad_tts
|
||||
self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
|
||||
if self.overwrite_for_grad_tts:
|
||||
dim = in_channels
|
||||
dim_out = out_channels
|
||||
@@ -324,6 +258,37 @@ class ResnetBlock(nn.Module):
|
||||
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
|
||||
else:
|
||||
self.res_conv = torch.nn.Identity()
|
||||
elif self.overwrite_for_ldm:
|
||||
dims = 2
|
||||
# eps = 1e-5
|
||||
# non_linearity = "silu"
|
||||
# overwrite_for_ldm
|
||||
channels = in_channels
|
||||
emb_channels = temb_channels
|
||||
use_scale_shift_norm = False
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels, swish=1.0),
|
||||
nn.Identity(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
||||
)
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
|
||||
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
|
||||
)
|
||||
if self.out_channels == in_channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
|
||||
def set_weights_grad_tts(self):
|
||||
self.conv1.weight.data = self.block1.block[0].weight.data
|
||||
@@ -343,30 +308,67 @@ class ResnetBlock(nn.Module):
|
||||
self.nin_shortcut.weight.data = self.res_conv.weight.data
|
||||
self.nin_shortcut.bias.data = self.res_conv.bias.data
|
||||
|
||||
def forward(self, x, temb, mask=None):
|
||||
def set_weights_ldm(self):
|
||||
self.norm1.weight.data = self.in_layers[0].weight.data
|
||||
self.norm1.bias.data = self.in_layers[0].bias.data
|
||||
|
||||
self.conv1.weight.data = self.in_layers[-1].weight.data
|
||||
self.conv1.bias.data = self.in_layers[-1].bias.data
|
||||
|
||||
self.temb_proj.weight.data = self.emb_layers[-1].weight.data
|
||||
self.temb_proj.bias.data = self.emb_layers[-1].bias.data
|
||||
|
||||
self.norm2.weight.data = self.out_layers[0].weight.data
|
||||
self.norm2.bias.data = self.out_layers[0].bias.data
|
||||
|
||||
self.conv2.weight.data = self.out_layers[-1].weight.data
|
||||
self.conv2.bias.data = self.out_layers[-1].bias.data
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut.weight.data = self.skip_connection.weight.data
|
||||
self.nin_shortcut.bias.data = self.skip_connection.bias.data
|
||||
|
||||
def forward(self, x, temb, mask=1.0):
|
||||
# TODO(Patrick) eventually this class should be split into multiple classes
|
||||
# too many if else statements
|
||||
if self.overwrite_for_grad_tts and not self.is_overwritten:
|
||||
self.set_weights_grad_tts()
|
||||
self.is_overwritten = True
|
||||
elif self.overwrite_for_ldm and not self.is_overwritten:
|
||||
self.set_weights_ldm()
|
||||
self.is_overwritten = True
|
||||
|
||||
h = x
|
||||
h = h * mask if mask is not None else h
|
||||
h = h * mask
|
||||
if self.pre_norm:
|
||||
h = self.norm1(h)
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
if self.up or self.down:
|
||||
x = self.x_upd(x)
|
||||
h = self.h_upd(h)
|
||||
|
||||
h = self.conv1(h)
|
||||
|
||||
if not self.pre_norm:
|
||||
h = self.norm1(h)
|
||||
h = self.nonlinearity(h)
|
||||
h = h * mask if mask is not None else h
|
||||
h = h * mask
|
||||
|
||||
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
if self.time_embedding_norm == "scale_shift":
|
||||
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||
|
||||
h = h * mask if mask is not None else h
|
||||
if self.pre_norm:
|
||||
h = self.norm2(h)
|
||||
h = h + h * scale + shift
|
||||
h = self.nonlinearity(h)
|
||||
elif self.time_embedding_norm == "default":
|
||||
h = h + temb
|
||||
h = h * mask
|
||||
if self.pre_norm:
|
||||
h = self.norm2(h)
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
@@ -374,14 +376,11 @@ class ResnetBlock(nn.Module):
|
||||
if not self.pre_norm:
|
||||
h = self.norm2(h)
|
||||
h = self.nonlinearity(h)
|
||||
h = h * mask if mask is not None else h
|
||||
h = h * mask
|
||||
|
||||
x = x * mask if mask is not None else x
|
||||
x = x * mask
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
@@ -394,10 +393,6 @@ class Block(torch.nn.Module):
|
||||
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
|
||||
)
|
||||
|
||||
def forward(self, x, mask):
|
||||
output = self.block(x * mask)
|
||||
return output * mask
|
||||
|
||||
|
||||
# unet_score_estimation.py
|
||||
class ResnetBlockBigGANpp(nn.Module):
|
||||
@@ -424,17 +419,29 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
self.fir = fir
|
||||
self.fir_kernel = fir_kernel
|
||||
|
||||
self.Conv_0 = conv3x3(in_ch, out_ch)
|
||||
if self.up:
|
||||
if self.fir:
|
||||
self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
|
||||
else:
|
||||
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
|
||||
elif self.down:
|
||||
if self.fir:
|
||||
self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
|
||||
else:
|
||||
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
|
||||
|
||||
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
|
||||
if temb_dim is not None:
|
||||
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
||||
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
|
||||
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
|
||||
nn.init.zeros_(self.Dense_0.bias)
|
||||
|
||||
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
|
||||
self.Dropout_0 = nn.Dropout(dropout)
|
||||
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
|
||||
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
|
||||
if in_ch != out_ch or up or down:
|
||||
self.Conv_2 = conv1x1(in_ch, out_ch)
|
||||
# 1x1 convolution with DDPM initialization.
|
||||
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
|
||||
|
||||
self.skip_rescale = skip_rescale
|
||||
self.act = act
|
||||
@@ -445,19 +452,11 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
h = self.act(self.GroupNorm_0(x))
|
||||
|
||||
if self.up:
|
||||
if self.fir:
|
||||
h = upsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
h = naive_upsample_2d(h, factor=2)
|
||||
x = naive_upsample_2d(x, factor=2)
|
||||
h = self.upsample(h)
|
||||
x = self.upsample(x)
|
||||
elif self.down:
|
||||
if self.fir:
|
||||
h = downsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
h = naive_downsample_2d(h, factor=2)
|
||||
x = naive_downsample_2d(x, factor=2)
|
||||
h = self.downsample(h)
|
||||
x = self.downsample(x)
|
||||
|
||||
h = self.Conv_0(h)
|
||||
# Add bias to each feature map conditioned on the time embedding
|
||||
@@ -476,62 +475,6 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
return (x + h) / np.sqrt(2.0)
|
||||
|
||||
|
||||
# unet_score_estimation.py
|
||||
class ResnetBlockDDPMpp(nn.Module):
|
||||
"""ResBlock adapted from DDPM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
act,
|
||||
in_ch,
|
||||
out_ch=None,
|
||||
temb_dim=None,
|
||||
conv_shortcut=False,
|
||||
dropout=0.1,
|
||||
skip_rescale=False,
|
||||
init_scale=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
out_ch = out_ch if out_ch else in_ch
|
||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
|
||||
self.Conv_0 = conv3x3(in_ch, out_ch)
|
||||
if temb_dim is not None:
|
||||
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
||||
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
|
||||
nn.init.zeros_(self.Dense_0.bias)
|
||||
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
|
||||
self.Dropout_0 = nn.Dropout(dropout)
|
||||
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
|
||||
if in_ch != out_ch:
|
||||
if conv_shortcut:
|
||||
self.Conv_2 = conv3x3(in_ch, out_ch)
|
||||
else:
|
||||
self.NIN_0 = NIN(in_ch, out_ch)
|
||||
|
||||
self.skip_rescale = skip_rescale
|
||||
self.act = act
|
||||
self.out_ch = out_ch
|
||||
self.conv_shortcut = conv_shortcut
|
||||
|
||||
def forward(self, x, temb=None):
|
||||
h = self.act(self.GroupNorm_0(x))
|
||||
h = self.Conv_0(h)
|
||||
if temb is not None:
|
||||
h += self.Dense_0(self.act(temb))[:, :, None, None]
|
||||
h = self.act(self.GroupNorm_1(h))
|
||||
h = self.Dropout_0(h)
|
||||
h = self.Conv_1(h)
|
||||
if x.shape[1] != self.out_ch:
|
||||
if self.conv_shortcut:
|
||||
x = self.Conv_2(x)
|
||||
else:
|
||||
x = self.NIN_0(x)
|
||||
if not self.skip_rescale:
|
||||
return x + h
|
||||
else:
|
||||
return (x + h) / np.sqrt(2.0)
|
||||
|
||||
|
||||
# unet_rl.py
|
||||
class ResidualTemporalBlock(nn.Module):
|
||||
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
|
||||
@@ -649,32 +592,17 @@ class RearrangeDim(nn.Module):
|
||||
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
|
||||
"""1x1 convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
|
||||
"""nXn convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
|
||||
conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
|
||||
"""3x3 convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(
|
||||
in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
|
||||
)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def default_init(scale=1.0):
|
||||
"""The same initialization used in DDPM."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
return variance_scaling(scale, "fan_avg", "uniform")
|
||||
|
||||
|
||||
def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
"""Ported from JAX."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
|
||||
def _compute_fans(shape, in_axis=1, out_axis=0):
|
||||
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
|
||||
@@ -684,21 +612,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
|
||||
|
||||
def init(shape, dtype=dtype, device=device):
|
||||
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
|
||||
if mode == "fan_in":
|
||||
denominator = fan_in
|
||||
elif mode == "fan_out":
|
||||
denominator = fan_out
|
||||
elif mode == "fan_avg":
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
else:
|
||||
raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
variance = scale / denominator
|
||||
if distribution == "normal":
|
||||
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
|
||||
elif distribution == "uniform":
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
|
||||
else:
|
||||
raise ValueError("invalid distribution for variance scaling initializer")
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
|
||||
|
||||
return init
|
||||
|
||||
@@ -796,31 +712,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
|
||||
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
|
||||
|
||||
|
||||
def naive_upsample_2d(x, factor=2):
|
||||
_N, C, H, W = x.shape
|
||||
x = torch.reshape(x, (-1, C, H, 1, W, 1))
|
||||
x = x.repeat(1, 1, 1, factor, 1, factor)
|
||||
return torch.reshape(x, (-1, C, H * factor, W * factor))
|
||||
|
||||
|
||||
def naive_downsample_2d(x, factor=2):
|
||||
_N, C, H, W = x.shape
|
||||
x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
|
||||
return torch.mean(x, dim=(3, 5))
|
||||
|
||||
|
||||
class NIN(nn.Module):
|
||||
def __init__(self, in_dim, num_units, init_scale=0.1):
|
||||
super().__init__()
|
||||
self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
|
||||
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
y = contract_inner(x, self.W) + self.b
|
||||
return y.permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
def _setup_kernel(k):
|
||||
k = np.asarray(k, dtype=np.float32)
|
||||
if k.ndim == 1:
|
||||
@@ -829,17 +720,3 @@ def _setup_kernel(k):
|
||||
assert k.ndim == 2
|
||||
assert k.shape[0] == k.shape[1]
|
||||
return k
|
||||
|
||||
|
||||
def contract_inner(x, y):
|
||||
"""tensordot(x, y, 1)."""
|
||||
x_chars = list(string.ascii_lowercase[: len(x.shape)])
|
||||
y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
|
||||
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
|
||||
out_chars = x_chars[:-1] + y_chars[1:]
|
||||
return _einsum(x_chars, y_chars, out_chars, x, y)
|
||||
|
||||
|
||||
def _einsum(a, b, c, x, y):
|
||||
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
|
||||
return torch.einsum(einsum_str, x, y)
|
||||
|
||||
@@ -34,48 +34,6 @@ def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
# class ResnetBlock(nn.Module):
|
||||
# def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
||||
# super().__init__()
|
||||
# self.in_channels = in_channels
|
||||
# out_channels = in_channels if out_channels is None else out_channels
|
||||
# self.out_channels = out_channels
|
||||
# self.use_conv_shortcut = conv_shortcut
|
||||
#
|
||||
# self.norm1 = Normalize(in_channels)
|
||||
# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
# self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
# self.norm2 = Normalize(out_channels)
|
||||
# self.dropout = torch.nn.Dropout(dropout)
|
||||
# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
# if self.in_channels != self.out_channels:
|
||||
# if self.use_conv_shortcut:
|
||||
# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
# else:
|
||||
# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
#
|
||||
# def forward(self, x, temb):
|
||||
# h = x
|
||||
# h = self.norm1(h)
|
||||
# h = nonlinearity(h)
|
||||
# h = self.conv1(h)
|
||||
#
|
||||
# h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
#
|
||||
# h = self.norm2(h)
|
||||
# h = nonlinearity(h)
|
||||
# h = self.dropout(h)
|
||||
# h = self.conv2(h)
|
||||
#
|
||||
# if self.in_channels != self.out_channels:
|
||||
# if self.use_conv_shortcut:
|
||||
# x = self.conv_shortcut(x)
|
||||
# else:
|
||||
# x = self.nin_shortcut(x)
|
||||
#
|
||||
# return x + h
|
||||
|
||||
|
||||
class UNetModel(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
|
||||
from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
|
||||
|
||||
|
||||
def convert_module_to_f16(l):
|
||||
@@ -29,19 +29,6 @@ def convert_module_to_f32(l):
|
||||
l.bias.data = l.bias.data.float()
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
@@ -101,7 +88,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
|
||||
def forward(self, x, emb, encoder_out=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, AttentionBlock):
|
||||
x = layer(x, encoder_out)
|
||||
@@ -190,14 +177,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=int(mult * model_channels),
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
out_channels=mult * model_channels,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
|
||||
overwrite_for_glide=True,
|
||||
)
|
||||
]
|
||||
ch = int(mult * model_channels)
|
||||
@@ -218,14 +206,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
|
||||
overwrite_for_glide=True,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
@@ -240,13 +229,14 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
self._feature_size += ch
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
|
||||
overwrite_for_glide=True,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
@@ -255,13 +245,14 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=transformer_dim,
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
|
||||
overwrite_for_glide=True,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
@@ -271,15 +262,16 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
for i in range(num_res_blocks + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch + ich,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=int(model_channels * mult),
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
ResnetBlock(
|
||||
in_channels=ch + ich,
|
||||
out_channels=model_channels * mult,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
|
||||
overwrite_for_glide=True,
|
||||
),
|
||||
]
|
||||
ch = int(model_channels * mult)
|
||||
if ds in attention_resolutions:
|
||||
@@ -295,14 +287,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
if level and i == num_res_blocks:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
|
||||
overwrite_for_glide=True,
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
|
||||
@@ -10,7 +10,10 @@ from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
|
||||
from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
|
||||
|
||||
|
||||
# from .resnet import ResBlock
|
||||
|
||||
|
||||
def exists(val):
|
||||
@@ -75,182 +78,6 @@ def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
# class LinearAttention(nn.Module):
|
||||
# def __init__(self, dim, heads=4, dim_head=32):
|
||||
# super().__init__()
|
||||
# self.heads = heads
|
||||
# hidden_dim = dim_head * heads
|
||||
# self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
# self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
#
|
||||
# def forward(self, x):
|
||||
# b, c, h, w = x.shape
|
||||
# qkv = self.to_qkv(x)
|
||||
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# k = k.softmax(dim=-1)
|
||||
# context = torch.einsum("bhdn,bhen->bhde", k, v)
|
||||
# out = torch.einsum("bhde,bhdn->bhen", context, q)
|
||||
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
|
||||
# return self.to_out(out)
|
||||
#
|
||||
|
||||
# class SpatialSelfAttention(nn.Module):
|
||||
# def __init__(self, in_channels):
|
||||
# super().__init__()
|
||||
# self.in_channels = in_channels
|
||||
#
|
||||
# self.norm = Normalize(in_channels)
|
||||
# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
#
|
||||
# def forward(self, x):
|
||||
# h_ = x
|
||||
# h_ = self.norm(h_)
|
||||
# q = self.q(h_)
|
||||
# k = self.k(h_)
|
||||
# v = self.v(h_)
|
||||
#
|
||||
# compute attention
|
||||
# b, c, h, w = q.shape
|
||||
# q = rearrange(q, "b c h w -> b (h w) c")
|
||||
# k = rearrange(k, "b c h w -> b c (h w)")
|
||||
# w_ = torch.einsum("bij,bjk->bik", q, k)
|
||||
#
|
||||
# w_ = w_ * (int(c) ** (-0.5))
|
||||
# w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
#
|
||||
# attend to values
|
||||
# v = rearrange(v, "b c h w -> b c (h w)")
|
||||
# w_ = rearrange(w_, "b i j -> b j i")
|
||||
# h_ = torch.einsum("bij,bjk->bik", v, w_)
|
||||
# h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
||||
# h_ = self.proj_out(h_)
|
||||
#
|
||||
# return x + h_
|
||||
#
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.heads
|
||||
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
return tensor
|
||||
|
||||
def reshape_batch_dim_to_heads(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.heads
|
||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
batch_size, sequence_length, dim = x.shape
|
||||
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q = self.reshape_heads_to_batch_dim(q)
|
||||
k = self.reshape_heads_to_batch_dim(k)
|
||||
v = self.reshape_heads_to_batch_dim(v)
|
||||
|
||||
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = mask.reshape(batch_size, -1)
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = mask[:, None, :].repeat(h, 1, 1)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
out = torch.einsum("b i j, b j d -> b i d", attn, v)
|
||||
out = self.reshape_batch_dim_to_heads(out)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
|
||||
super().__init__()
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
|
||||
standard transformer action. Finally, reshape to image
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
x = self.proj_in(x)
|
||||
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context=context)
|
||||
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
||||
|
||||
def convert_module_to_f16(l):
|
||||
"""
|
||||
Convert primitive modules to float16.
|
||||
@@ -271,19 +98,6 @@ def convert_module_to_f32(l):
|
||||
l.bias.data = l.bias.data.float()
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
@@ -327,36 +141,6 @@ def normalization(channels, swish=0.0):
|
||||
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
"""
|
||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spacial_dim: int,
|
||||
embed_dim: int,
|
||||
num_heads_channels: int,
|
||||
output_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
|
||||
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
||||
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
||||
self.num_heads = embed_dim // num_heads_channels
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, *_spatial = x.shape
|
||||
x = x.reshape(b, c, -1) # NC(HW)
|
||||
x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
||||
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
||||
x = self.qkv_proj(x)
|
||||
x = self.attention(x)
|
||||
x = self.c_proj(x)
|
||||
return x[:, :, 0]
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
||||
@@ -364,7 +148,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
|
||||
def forward(self, x, emb, context=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context)
|
||||
@@ -373,39 +157,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
return x
|
||||
|
||||
|
||||
class QKVAttention(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention and splits in a different order.
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
|
||||
T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = torch.einsum(
|
||||
"bct,bcs->bts",
|
||||
(q * scale).view(bs * self.n_heads, ch, length),
|
||||
(k * scale).view(bs * self.n_heads, ch, length),
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
def count_flops_attn(model, _x, y):
|
||||
"""
|
||||
A counter for the `thop` package to count the operations in an attention operation. Meant to be used like:
|
||||
@@ -559,14 +310,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
@@ -599,20 +350,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
|
||||
)
|
||||
Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
@@ -629,13 +367,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
out_channels=None,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
@@ -646,13 +385,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
)
|
||||
if not use_spatial_transformer
|
||||
else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
out_channels=None,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
@@ -662,15 +402,15 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
for i in range(num_res_blocks + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch + ich,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
ResnetBlock(
|
||||
in_channels=ch + ich,
|
||||
out_channels=model_channels * mult,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
),
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
@@ -697,20 +437,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
)
|
||||
if level and i == num_res_blocks:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
|
||||
)
|
||||
layers.append(Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch))
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
@@ -777,212 +504,119 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class EncoderUNetModel(nn.Module):
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
The half UNet model with attention and timestep embedding. For usage, see UNet.
|
||||
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
|
||||
standard transformer action. Finally, reshape to image
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
pool="adaptive",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
|
||||
super().__init__()
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
|
||||
)
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
self.pool = pool
|
||||
if pool == "adaptive":
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
zero_module(conv_nd(dims, ch, out_channels, 1)),
|
||||
nn.Flatten(),
|
||||
)
|
||||
elif pool == "attention":
|
||||
assert num_head_channels != -1
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
|
||||
)
|
||||
elif pool == "spatial":
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(self._feature_size, 2048),
|
||||
nn.ReLU(),
|
||||
nn.Linear(2048, self.out_channels),
|
||||
)
|
||||
elif pool == "spatial_v2":
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(self._feature_size, 2048),
|
||||
normalization(2048),
|
||||
nn.SiLU(),
|
||||
nn.Linear(2048, self.out_channels),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unexpected {pool} pooling")
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
x = self.proj_in(x)
|
||||
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context=context)
|
||||
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
Convert the torso of the model to float16.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f16)
|
||||
self.middle_block.apply(convert_module_to_f16)
|
||||
|
||||
def convert_to_fp32(self):
|
||||
"""
|
||||
Convert the torso of the model to float32.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f32)
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
|
||||
super().__init__()
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, timesteps):
|
||||
"""
|
||||
Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch
|
||||
of timesteps. :return: an [N x K] Tensor of outputs.
|
||||
"""
|
||||
emb = self.time_embed(
|
||||
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
)
|
||||
def forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
results = []
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb)
|
||||
if self.pool.startswith("spatial"):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = self.middle_block(h, emb)
|
||||
if self.pool.startswith("spatial"):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = torch.cat(results, axis=-1)
|
||||
return self.out(h)
|
||||
else:
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.heads
|
||||
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
return tensor
|
||||
|
||||
def reshape_batch_dim_to_heads(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.heads
|
||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
batch_size, sequence_length, dim = x.shape
|
||||
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q = self.reshape_heads_to_batch_dim(q)
|
||||
k = self.reshape_heads_to_batch_dim(k)
|
||||
v = self.reshape_heads_to_batch_dim(v)
|
||||
|
||||
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = mask.reshape(batch_size, -1)
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = mask[:, None, :].repeat(h, 1, 1)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
out = torch.einsum("b i j, b j d -> b i d", attn, v)
|
||||
out = self.reshape_batch_dim_to_heads(out)
|
||||
return self.to_out(out)
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import ResidualTemporalBlock
|
||||
from .resnet import Downsample, ResidualTemporalBlock, Upsample
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
@@ -18,24 +18,6 @@ class SinusoidalPosEmb(nn.Module):
|
||||
return get_timestep_embedding(x, self.dim)
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class RearrangeDim(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -114,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||
[
|
||||
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
|
||||
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
|
||||
Downsample1d(dim_out) if not is_last else nn.Identity(),
|
||||
Downsample(dim_out, use_conv=True, dims=1) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -134,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||
[
|
||||
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
|
||||
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
|
||||
Upsample1d(dim_in) if not is_last else nn.Identity(),
|
||||
Upsample(dim_in, use_conv_transpose=True, dims=1) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
|
||||
import functools
|
||||
import math
|
||||
import string
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -28,116 +27,21 @@ from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import GaussianFourierProjection, get_timestep_embedding
|
||||
from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp
|
||||
from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d
|
||||
|
||||
|
||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
||||
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
|
||||
def _setup_kernel(k):
|
||||
k = np.asarray(k, dtype=np.float32)
|
||||
if k.ndim == 1:
|
||||
k = np.outer(k, k)
|
||||
k /= np.sum(k)
|
||||
assert k.ndim == 2
|
||||
assert k.shape[0] == k.shape[1]
|
||||
return k
|
||||
|
||||
|
||||
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
|
||||
_, channel, in_h, in_w = input.shape
|
||||
input = input.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
_, in_h, in_w, minor = input.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
||||
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
||||
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
||||
|
||||
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
||||
out = out[
|
||||
:,
|
||||
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
||||
:,
|
||||
]
|
||||
|
||||
out = out.permute(0, 3, 1, 2)
|
||||
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(
|
||||
-1,
|
||||
minor,
|
||||
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
||||
)
|
||||
out = out.permute(0, 2, 3, 1)
|
||||
out = out[:, ::down_y, ::down_x, :]
|
||||
|
||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
||||
|
||||
return out.view(-1, channel, out_h, out_w)
|
||||
|
||||
|
||||
# Function ported from StyleGAN2
|
||||
def get_weight(module, shape, weight_var="weight", kernel_init=None):
|
||||
"""Get/create weight tensor for a convolution or fully-connected layer."""
|
||||
|
||||
return module.param(weight_var, kernel_init, shape)
|
||||
|
||||
|
||||
class Conv2d(nn.Module):
|
||||
"""Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel,
|
||||
up=False,
|
||||
down=False,
|
||||
resample_kernel=(1, 3, 3, 1),
|
||||
use_bias=True,
|
||||
kernel_init=None,
|
||||
):
|
||||
super().__init__()
|
||||
assert not (up and down)
|
||||
assert kernel >= 1 and kernel % 2 == 1
|
||||
self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
|
||||
if kernel_init is not None:
|
||||
self.weight.data = kernel_init(self.weight.data.shape)
|
||||
if use_bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_ch))
|
||||
|
||||
self.up = up
|
||||
self.down = down
|
||||
self.resample_kernel = resample_kernel
|
||||
self.kernel = kernel
|
||||
self.use_bias = use_bias
|
||||
|
||||
def forward(self, x):
|
||||
if self.up:
|
||||
x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
|
||||
elif self.down:
|
||||
x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
|
||||
else:
|
||||
x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
|
||||
|
||||
if self.use_bias:
|
||||
x = x + self.bias.reshape(1, -1, 1, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def naive_upsample_2d(x, factor=2):
|
||||
_N, C, H, W = x.shape
|
||||
x = torch.reshape(x, (-1, C, H, 1, W, 1))
|
||||
x = x.repeat(1, 1, 1, factor, 1, factor)
|
||||
return torch.reshape(x, (-1, C, H * factor, W * factor))
|
||||
|
||||
|
||||
def naive_downsample_2d(x, factor=2):
|
||||
_N, C, H, W = x.shape
|
||||
x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
|
||||
return torch.mean(x, dim=(3, 5))
|
||||
|
||||
|
||||
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
|
||||
"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
|
||||
def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
|
||||
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
||||
|
||||
Args:
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
@@ -176,13 +80,13 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
|
||||
|
||||
# Determine data dimensions.
|
||||
stride = [1, 1, factor, factor]
|
||||
output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
|
||||
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
|
||||
output_padding = (
|
||||
output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
|
||||
output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW,
|
||||
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
|
||||
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
|
||||
)
|
||||
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
||||
num_groups = _shape(x, 1) // inC
|
||||
num_groups = x.shape[1] // inC
|
||||
|
||||
# Transpose weights.
|
||||
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
|
||||
@@ -190,21 +94,12 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
|
||||
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
|
||||
|
||||
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
|
||||
# Original TF code.
|
||||
# x = tf.nn.conv2d_transpose(
|
||||
# x,
|
||||
# w,
|
||||
# output_shape=output_shape,
|
||||
# strides=stride,
|
||||
# padding='VALID',
|
||||
# data_format=data_format)
|
||||
# JAX equivalent
|
||||
|
||||
return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
|
||||
|
||||
|
||||
def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
|
||||
"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
|
||||
def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
|
||||
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
||||
|
||||
Args:
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
@@ -235,138 +130,9 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
|
||||
return F.conv2d(x, w, stride=s, padding=0)
|
||||
|
||||
|
||||
def _setup_kernel(k):
|
||||
k = np.asarray(k, dtype=np.float32)
|
||||
if k.ndim == 1:
|
||||
k = np.outer(k, k)
|
||||
k /= np.sum(k)
|
||||
assert k.ndim == 2
|
||||
assert k.shape[0] == k.shape[1]
|
||||
return k
|
||||
|
||||
|
||||
def _shape(x, dim):
|
||||
return x.shape[dim]
|
||||
|
||||
|
||||
def upsample_2d(x, k=None, factor=2, gain=1):
|
||||
r"""Upsample a batch of 2D images with the given filter.
|
||||
|
||||
Args:
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
||||
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
||||
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
|
||||
multiple of the upsampling factor.
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
C]`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
||||
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]`
|
||||
"""
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if k is None:
|
||||
k = [1] * factor
|
||||
k = _setup_kernel(k) * (gain * (factor**2))
|
||||
p = k.shape[0] - factor
|
||||
return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
|
||||
|
||||
|
||||
def downsample_2d(x, k=None, factor=2, gain=1):
|
||||
r"""Downsample a batch of 2D images with the given filter.
|
||||
|
||||
Args:
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
||||
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
||||
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
||||
shape is a multiple of the downsampling factor.
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
C]`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to average pooling.
|
||||
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]`
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if k is None:
|
||||
k = [1] * factor
|
||||
k = _setup_kernel(k) * gain
|
||||
p = k.shape[0] - factor
|
||||
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
|
||||
"""1x1 convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
|
||||
"""3x3 convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(
|
||||
in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
|
||||
)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def _einsum(a, b, c, x, y):
|
||||
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
|
||||
return torch.einsum(einsum_str, x, y)
|
||||
|
||||
|
||||
def contract_inner(x, y):
|
||||
"""tensordot(x, y, 1)."""
|
||||
x_chars = list(string.ascii_lowercase[: len(x.shape)])
|
||||
y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
|
||||
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
|
||||
out_chars = x_chars[:-1] + y_chars[1:]
|
||||
return _einsum(x_chars, y_chars, out_chars, x, y)
|
||||
|
||||
|
||||
class NIN(nn.Module):
|
||||
def __init__(self, in_dim, num_units, init_scale=0.1):
|
||||
super().__init__()
|
||||
self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
|
||||
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
y = contract_inner(x, self.W) + self.b
|
||||
return y.permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
def get_act(nonlinearity):
|
||||
"""Get activation functions from the config file."""
|
||||
|
||||
if nonlinearity.lower() == "elu":
|
||||
return nn.ELU()
|
||||
elif nonlinearity.lower() == "relu":
|
||||
return nn.ReLU()
|
||||
elif nonlinearity.lower() == "lrelu":
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif nonlinearity.lower() == "swish":
|
||||
return nn.SiLU()
|
||||
else:
|
||||
raise NotImplementedError("activation function does not exist!")
|
||||
|
||||
|
||||
def default_init(scale=1.0):
|
||||
"""The same initialization used in DDPM."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
return variance_scaling(scale, "fan_avg", "uniform")
|
||||
|
||||
|
||||
def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
"""Ported from JAX."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
|
||||
def _compute_fans(shape, in_axis=1, out_axis=0):
|
||||
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
|
||||
@@ -376,31 +142,35 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
|
||||
|
||||
def init(shape, dtype=dtype, device=device):
|
||||
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
|
||||
if mode == "fan_in":
|
||||
denominator = fan_in
|
||||
elif mode == "fan_out":
|
||||
denominator = fan_out
|
||||
elif mode == "fan_avg":
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
else:
|
||||
raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
variance = scale / denominator
|
||||
if distribution == "normal":
|
||||
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
|
||||
elif distribution == "uniform":
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
|
||||
else:
|
||||
raise ValueError("invalid distribution for variance scaling initializer")
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
|
||||
|
||||
return init
|
||||
|
||||
|
||||
def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
|
||||
"""nXn convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
|
||||
conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def Linear(dim_in, dim_out):
|
||||
linear = nn.Linear(dim_in, dim_out)
|
||||
linear.weight.data = _variance_scaling()(linear.weight.shape)
|
||||
nn.init.zeros_(linear.bias)
|
||||
return linear
|
||||
|
||||
|
||||
class Combine(nn.Module):
|
||||
"""Combine information from skip connections."""
|
||||
|
||||
def __init__(self, dim1, dim2, method="cat"):
|
||||
super().__init__()
|
||||
self.Conv_0 = conv1x1(dim1, dim2)
|
||||
# 1x1 convolution with DDPM initialization.
|
||||
self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0)
|
||||
self.method = method
|
||||
|
||||
def forward(self, x, y):
|
||||
@@ -413,80 +183,42 @@ class Combine(nn.Module):
|
||||
raise ValueError(f"Method {self.method} not recognized.")
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
||||
class FirUpsample(nn.Module):
|
||||
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
super().__init__()
|
||||
out_ch = out_ch if out_ch else in_ch
|
||||
if not fir:
|
||||
if with_conv:
|
||||
self.Conv_0 = conv3x3(in_ch, out_ch)
|
||||
else:
|
||||
if with_conv:
|
||||
self.Conv2d_0 = Conv2d(
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel=3,
|
||||
up=True,
|
||||
resample_kernel=fir_kernel,
|
||||
use_bias=True,
|
||||
kernel_init=default_init(),
|
||||
)
|
||||
self.fir = fir
|
||||
self.with_conv = with_conv
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.use_conv = use_conv
|
||||
self.fir_kernel = fir_kernel
|
||||
self.out_ch = out_ch
|
||||
self.out_channels = out_channels
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
if not self.fir:
|
||||
h = F.interpolate(x, (H * 2, W * 2), "nearest")
|
||||
if self.with_conv:
|
||||
h = self.Conv_0(h)
|
||||
if self.use_conv:
|
||||
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
|
||||
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
if not self.with_conv:
|
||||
h = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
h = self.Conv2d_0(x)
|
||||
h = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
||||
class FirDownsample(nn.Module):
|
||||
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
super().__init__()
|
||||
out_ch = out_ch if out_ch else in_ch
|
||||
if not fir:
|
||||
if with_conv:
|
||||
self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
|
||||
else:
|
||||
if with_conv:
|
||||
self.Conv2d_0 = Conv2d(
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel=3,
|
||||
down=True,
|
||||
resample_kernel=fir_kernel,
|
||||
use_bias=True,
|
||||
kernel_init=default_init(),
|
||||
)
|
||||
self.fir = fir
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.fir_kernel = fir_kernel
|
||||
self.with_conv = with_conv
|
||||
self.out_ch = out_ch
|
||||
self.use_conv = use_conv
|
||||
self.out_channels = out_channels
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
if not self.fir:
|
||||
if self.with_conv:
|
||||
x = F.pad(x, (0, 1, 0, 1))
|
||||
x = self.Conv_0(x)
|
||||
else:
|
||||
x = F.avg_pool2d(x, 2, stride=2)
|
||||
if self.use_conv:
|
||||
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
|
||||
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
if not self.with_conv:
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
x = self.Conv2d_0(x)
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
|
||||
return x
|
||||
|
||||
@@ -496,10 +228,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
centered=False,
|
||||
image_size=1024,
|
||||
num_channels=3,
|
||||
attention_type="ddpm",
|
||||
centered=False,
|
||||
attn_resolutions=(16,),
|
||||
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
|
||||
conditional=True,
|
||||
@@ -511,24 +242,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
fourier_scale=16,
|
||||
init_scale=0.0,
|
||||
nf=16,
|
||||
nonlinearity="swish",
|
||||
normalization="GroupNorm",
|
||||
num_res_blocks=1,
|
||||
progressive="output_skip",
|
||||
progressive_combine="sum",
|
||||
progressive_input="input_skip",
|
||||
resamp_with_conv=True,
|
||||
resblock_type="biggan",
|
||||
scale_by_sigma=True,
|
||||
skip_rescale=True,
|
||||
continuous=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_to_config(
|
||||
centered=centered,
|
||||
image_size=image_size,
|
||||
num_channels=num_channels,
|
||||
attention_type=attention_type,
|
||||
centered=centered,
|
||||
attn_resolutions=attn_resolutions,
|
||||
ch_mult=ch_mult,
|
||||
conditional=conditional,
|
||||
@@ -540,19 +267,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
fourier_scale=fourier_scale,
|
||||
init_scale=init_scale,
|
||||
nf=nf,
|
||||
nonlinearity=nonlinearity,
|
||||
normalization=normalization,
|
||||
num_res_blocks=num_res_blocks,
|
||||
progressive=progressive,
|
||||
progressive_combine=progressive_combine,
|
||||
progressive_input=progressive_input,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
resblock_type=resblock_type,
|
||||
scale_by_sigma=scale_by_sigma,
|
||||
skip_rescale=skip_rescale,
|
||||
continuous=continuous,
|
||||
)
|
||||
self.act = act = get_act(nonlinearity)
|
||||
self.act = act = nn.SiLU()
|
||||
|
||||
self.nf = nf
|
||||
self.num_res_blocks = num_res_blocks
|
||||
@@ -562,7 +286,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
|
||||
self.conditional = conditional
|
||||
self.skip_rescale = skip_rescale
|
||||
self.resblock_type = resblock_type
|
||||
self.progressive = progressive
|
||||
self.progressive_input = progressive_input
|
||||
self.embedding_type = embedding_type
|
||||
@@ -585,53 +308,41 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
else:
|
||||
raise ValueError(f"embedding type {embedding_type} unknown.")
|
||||
|
||||
if conditional:
|
||||
modules.append(nn.Linear(embed_dim, nf * 4))
|
||||
modules[-1].weight.data = default_init()(modules[-1].weight.shape)
|
||||
nn.init.zeros_(modules[-1].bias)
|
||||
modules.append(nn.Linear(nf * 4, nf * 4))
|
||||
modules[-1].weight.data = default_init()(modules[-1].weight.shape)
|
||||
nn.init.zeros_(modules[-1].bias)
|
||||
modules.append(Linear(embed_dim, nf * 4))
|
||||
modules.append(Linear(nf * 4, nf * 4))
|
||||
|
||||
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
|
||||
Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
|
||||
|
||||
if self.fir:
|
||||
Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
|
||||
else:
|
||||
Up_sample = functools.partial(Upsample, name="Conv2d_0")
|
||||
|
||||
if progressive == "output_skip":
|
||||
self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
|
||||
self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
|
||||
elif progressive == "residual":
|
||||
pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
|
||||
pyramid_upsample = functools.partial(Up_sample, use_conv=True)
|
||||
|
||||
Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
|
||||
if self.fir:
|
||||
Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
|
||||
else:
|
||||
Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
|
||||
|
||||
if progressive_input == "input_skip":
|
||||
self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
|
||||
self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
|
||||
elif progressive_input == "residual":
|
||||
pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
|
||||
pyramid_downsample = functools.partial(Down_sample, use_conv=True)
|
||||
|
||||
if resblock_type == "ddpm":
|
||||
ResnetBlock = functools.partial(
|
||||
ResnetBlockDDPMpp,
|
||||
act=act,
|
||||
dropout=dropout,
|
||||
init_scale=init_scale,
|
||||
skip_rescale=skip_rescale,
|
||||
temb_dim=nf * 4,
|
||||
)
|
||||
|
||||
elif resblock_type == "biggan":
|
||||
ResnetBlock = functools.partial(
|
||||
ResnetBlockBigGANpp,
|
||||
act=act,
|
||||
dropout=dropout,
|
||||
fir=fir,
|
||||
fir_kernel=fir_kernel,
|
||||
init_scale=init_scale,
|
||||
skip_rescale=skip_rescale,
|
||||
temb_dim=nf * 4,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"resblock type {resblock_type} unrecognized.")
|
||||
ResnetBlock = functools.partial(
|
||||
ResnetBlockBigGANpp,
|
||||
act=act,
|
||||
dropout=dropout,
|
||||
fir=fir,
|
||||
fir_kernel=fir_kernel,
|
||||
init_scale=init_scale,
|
||||
skip_rescale=skip_rescale,
|
||||
temb_dim=nf * 4,
|
||||
)
|
||||
|
||||
# Downsampling block
|
||||
|
||||
@@ -639,7 +350,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
if progressive_input != "none":
|
||||
input_pyramid_ch = channels
|
||||
|
||||
modules.append(conv3x3(channels, nf))
|
||||
modules.append(Conv2d(channels, nf, kernel_size=3, padding=1))
|
||||
hs_c = [nf]
|
||||
|
||||
in_ch = nf
|
||||
@@ -655,10 +366,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
hs_c.append(in_ch)
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
if resblock_type == "ddpm":
|
||||
modules.append(Downsample(in_ch=in_ch))
|
||||
else:
|
||||
modules.append(ResnetBlock(down=True, in_ch=in_ch))
|
||||
modules.append(ResnetBlock(down=True, in_ch=in_ch))
|
||||
|
||||
if progressive_input == "input_skip":
|
||||
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
|
||||
@@ -666,7 +374,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
in_ch *= 2
|
||||
|
||||
elif progressive_input == "residual":
|
||||
modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
|
||||
modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
|
||||
input_pyramid_ch = in_ch
|
||||
|
||||
hs_c.append(in_ch)
|
||||
@@ -691,36 +399,35 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
if i_level == self.num_resolutions - 1:
|
||||
if progressive == "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
||||
modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
|
||||
pyramid_ch = channels
|
||||
elif progressive == "residual":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, in_ch, bias=True))
|
||||
modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
|
||||
pyramid_ch = in_ch
|
||||
else:
|
||||
raise ValueError(f"{progressive} is not a valid name.")
|
||||
else:
|
||||
if progressive == "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
|
||||
modules.append(
|
||||
Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
|
||||
)
|
||||
pyramid_ch = channels
|
||||
elif progressive == "residual":
|
||||
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
|
||||
modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
|
||||
pyramid_ch = in_ch
|
||||
else:
|
||||
raise ValueError(f"{progressive} is not a valid name")
|
||||
|
||||
if i_level != 0:
|
||||
if resblock_type == "ddpm":
|
||||
modules.append(Upsample(in_ch=in_ch))
|
||||
else:
|
||||
modules.append(ResnetBlock(in_ch=in_ch, up=True))
|
||||
modules.append(ResnetBlock(in_ch=in_ch, up=True))
|
||||
|
||||
assert not hs_c
|
||||
|
||||
if progressive != "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
||||
modules.append(Conv2d(in_ch, channels, init_scale=init_scale))
|
||||
|
||||
self.all_modules = nn.ModuleList(modules)
|
||||
|
||||
@@ -751,8 +458,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
else:
|
||||
temb = None
|
||||
|
||||
# If input data is in [0, 1]
|
||||
if not self.config.centered:
|
||||
# If input data is in [0, 1]
|
||||
x = 2 * x - 1.0
|
||||
|
||||
# Downsampling block
|
||||
@@ -774,12 +481,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
hs.append(h)
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
if self.resblock_type == "ddpm":
|
||||
h = modules[m_idx](hs[-1])
|
||||
m_idx += 1
|
||||
else:
|
||||
h = modules[m_idx](hs[-1], temb)
|
||||
m_idx += 1
|
||||
h = modules[m_idx](hs[-1], temb)
|
||||
m_idx += 1
|
||||
|
||||
if self.progressive_input == "input_skip":
|
||||
input_pyramid = self.pyramid_downsample(input_pyramid)
|
||||
@@ -851,12 +554,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
raise ValueError(f"{self.progressive} is not a valid name")
|
||||
|
||||
if i_level != 0:
|
||||
if self.resblock_type == "ddpm":
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
else:
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
|
||||
assert not hs
|
||||
|
||||
|
||||
@@ -259,7 +259,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
|
||||
class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
|
||||
@@ -607,7 +607,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
|
||||
class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
@@ -678,7 +678,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
|
||||
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
@@ -742,18 +742,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor(batch_size * [10]).to(torch_device)
|
||||
noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step)
|
||||
|
||||
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([3.1909e-07, -8.5393e-08, 4.8460e-07, -4.5550e-07, -1.3205e-06, -6.3475e-07, 9.7837e-07, 2.9974e-07, 1.2345e-06])
|
||||
expected_output_slice = torch.tensor([0.1315, 0.0741, 0.0393, 0.0455, 0.0556, 0.0180, -0.0832, -0.0644, -0.0856])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
def test_output_pretrained_ve_large(self):
|
||||
model = NCSNpp.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy")
|
||||
@@ -768,21 +768,21 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor(batch_size * [10]).to(torch_device)
|
||||
noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step)
|
||||
|
||||
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-8.3299e-07, -9.0431e-07, 4.0585e-08, 9.7563e-07, 1.0280e-06, 1.0133e-06, 1.4979e-06, -2.9716e-07, -6.1817e-07])
|
||||
expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
def test_output_pretrained_vp(self):
|
||||
model = NCSNpp.from_pretrained("fusing/ddpm-cifar10-vp-dummy")
|
||||
model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
|
||||
model.eval()
|
||||
model.to(torch_device)
|
||||
|
||||
@@ -794,18 +794,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor(batch_size * [10]).to(torch_device)
|
||||
noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor(batch_size * [9.0]).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step)
|
||||
|
||||
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-3.9086e-07, -1.1001e-05, 1.8881e-06, 1.1106e-05, 1.6629e-06, 2.9820e-06, 8.4978e-06, 8.0253e-07, 1.5435e-06])
|
||||
expected_output_slice = torch.tensor([0.3303, -0.2275, -2.8872, -0.1309, -1.2861, 3.4567, -1.0083, 2.5325, -1.3866])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
|
||||
class VQModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
@@ -878,10 +878,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462,
|
||||
-0.4218])
|
||||
expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
|
||||
class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
|
||||
@@ -950,10 +949,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662,
|
||||
0.1750])
|
||||
expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
|
||||
class PipelineTesterMixin(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user