Mirror of https://github.com/huggingface/diffusers.git
better names and more cleanup
@@ -40,12 +40,8 @@ def _setup_kernel(k):
     return k


-def _shape(x, dim):
-    return x.shape[dim]
-
-
-def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
+def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `upsample_2d()` followed by `Conv2d()`.

     Args:
     Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
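For orientation alongside the hunk above: the `fir_kernel=(1, 3, 3, 1)` used throughout this file is a separable smoothing filter. A minimal sketch of the usual construction, which is what a helper like `_setup_kernel(k)` is assumed to do here (editor's sketch, not code from this commit):

import torch

# Assumed behavior of a FIR kernel setup helper: expand a 1-D kernel to 2-D by
# an outer product, then normalize it to sum to 1.
def setup_kernel_sketch(k):
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = torch.outer(k, k)   # 1-D -> 2-D separable kernel
    return k / k.sum()          # normalize so the filter preserves overall magnitude

print(setup_kernel_sketch((1, 3, 3, 1)).shape)   # torch.Size([4, 4])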
@@ -84,13 +80,13 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):

     # Determine data dimensions.
     stride = [1, 1, factor, factor]
-    output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
+    output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
     output_padding = (
-        output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
-        output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW,
+        output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+        output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
     )
     assert output_padding[0] >= 0 and output_padding[1] >= 0
-    num_groups = _shape(x, 1) // inC
+    num_groups = x.shape[1] // inC

     # Transpose weights.
     w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
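The `output_shape` / `output_padding` arithmetic rewritten in this hunk follows the standard transposed-convolution size formula. A quick stand-alone check (editor's sketch, tensor sizes are made up):

import torch
import torch.nn.functional as F

# With stride=factor and padding=0, conv_transpose2d produces
# (H_in - 1) * factor + kH, which is exactly `output_shape` above.
x = torch.randn(1, 8, 16, 16)
w = torch.randn(8, 4, 3, 3)        # conv_transpose2d weight: (in_channels, out_channels, kH, kW) for groups=1
factor, convH = 2, 3
y = F.conv_transpose2d(x, w, stride=factor, padding=0)
assert y.shape[2] == (x.shape[2] - 1) * factor + convH   # 33 for this example
print(y.shape)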
@@ -98,21 +94,12 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
     w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

     x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
-    # Original TF code.
-    # x = tf.nn.conv2d_transpose(
-    #     x,
-    #     w,
-    #     output_shape=output_shape,
-    #     strides=stride,
-    #     padding='VALID',
-    #     data_format=data_format)
-    # JAX equivalent

     return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))


-def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
+def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `Conv2d()` followed by `downsample_2d()`.

     Args:
     Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
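The two `torch.reshape` calls around this hunk regroup the convolution weight so that `F.conv_transpose2d` sees an `(in_channels, out_channels_per_group, kH, kW)` layout. A shape trace with made-up sizes; the flip/permute in the middle is an assumed intermediate step, not shown in the diff:

import torch

# Editor's shape trace only; sizes are illustrative.
outC, inC, convH, convW = 128, 64, 3, 3
num_groups = 1                                               # x.shape[1] // inC in the real code
w = torch.randn(outC, inC, convH, convW)                     # ordinary Conv2d weight layout
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))    # (1, 128, 64, 3, 3)
w = torch.flip(w, dims=[3, 4]).permute(0, 2, 1, 3, 4)        # swap in/out per group (assumed step)
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))   # (64, 128, 3, 3), ready for conv_transpose2d
print(w.shape)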
@@ -143,15 +130,7 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
     return F.conv2d(x, w, stride=s, padding=0)


-def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
-    """nXn convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    """Ported from JAX."""
    scale = 1e-10 if scale == 0 else scale

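For the strided convolution that `_conv_downsample_2d` ends with, the output side is `(H - kH) // stride + 1` when `padding=0` (any padding is assumed to be handled earlier in the full function). A stand-alone check with made-up sizes (editor's sketch):

import torch
import torch.nn.functional as F

# Size arithmetic of the final strided conv shown in the hunk.
x = torch.randn(1, 8, 16, 16)
w = torch.randn(8, 8, 3, 3)
factor = 2
y = F.conv2d(x, w, stride=factor, padding=0)
print(y.shape)   # torch.Size([1, 8, 7, 7]); (16 - 3) // 2 + 1 = 7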
@@ -170,13 +149,21 @@ def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, devi
     return init


+def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+    conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape)
+    nn.init.zeros_(conv.bias)
+    return conv
+
+
 class Combine(nn.Module):
     """Combine information from skip connections."""

     def __init__(self, dim1, dim2, method="cat"):
         super().__init__()
         # 1x1 convolution with DDPM initialization.
-        self.Conv_0 = conv2d(dim1, dim2, kernel_size=1, padding=0)
+        self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0)
         self.method = method

     def forward(self, x, y):
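The relocated `Conv2d` helper initializes its weights with `_variance_scaling(init_scale)`, whose docstring says it is ported from JAX. A hedged sketch of a fan-average, uniform variance-scaling initializer in that spirit; the repo's argument handling and distribution may differ:

import torch

# Editor's sketch of a jax.nn.initializers.variance_scaling-style initializer
# (mode "fan_avg", uniform distribution); not the repo's exact implementation.
def variance_scaling_sketch(scale=1.0, in_axis=1, out_axis=0):
    def init(shape, dtype=torch.float32, device="cpu"):
        receptive = 1
        for s in shape[2:]:
            receptive *= s                              # kernel height * width for conv weights
        fan_in = shape[in_axis] * receptive
        fan_out = shape[out_axis] * receptive
        variance = scale / ((fan_in + fan_out) / 2)     # "fan_avg" mode
        limit = (3 * variance) ** 0.5                   # uniform(-limit, limit) has this variance
        return (torch.rand(shape, dtype=dtype, device=device) * 2 - 1) * limit
    return init

w = variance_scaling_sketch(1.0)((64, 32, 3, 3))
print(w.std())   # roughly sqrt(scale / fan_avg)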
@@ -189,38 +176,38 @@ class Combine(nn.Module):
             raise ValueError(f"Method {self.method} not recognized.")


-class Upsample(nn.Module):
+class FirUpsample(nn.Module):
     def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
         out_ch = out_ch if out_ch else in_ch
         if with_conv:
-            self.Conv2d_0 = conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
+            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
         self.with_conv = with_conv
         self.fir_kernel = fir_kernel
         self.out_ch = out_ch

     def forward(self, x):
         if self.with_conv:
-            h = upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
         else:
             h = upsample_2d(x, self.fir_kernel, factor=2)

         return h


-class Downsample(nn.Module):
+class FirDownsample(nn.Module):
     def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
         out_ch = out_ch if out_ch else in_ch
         if with_conv:
-            self.Conv2d_0 = self.Conv2d_0 = conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
+            self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
         self.fir_kernel = fir_kernel
         self.with_conv = with_conv
         self.out_ch = out_ch

     def forward(self, x):
         if self.with_conv:
-            x = conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
         else:
             x = downsample_2d(x, self.fir_kernel, factor=2)

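In the `with_conv=False` branch, `upsample_2d(x, self.fir_kernel, factor=2)` is the plain FIR upsample. A conceptual stand-in (zero-stuffing followed by a depthwise filter); the repo's `upfirdn2d` padding and gain handling may differ, so treat this only as intuition:

import torch
import torch.nn.functional as F

# Editor's sketch of an upfirdn-style upsample: insert zeros between pixels,
# then smooth every channel with the gain-scaled FIR kernel.
def naive_fir_upsample(x, fir_kernel=(1, 3, 3, 1), factor=2):
    n, c, h, w = x.shape
    up = torch.zeros(n, c, h * factor, w * factor, dtype=x.dtype, device=x.device)
    up[:, :, ::factor, ::factor] = x                       # zero-stuffing
    k1 = torch.tensor(fir_kernel, dtype=x.dtype, device=x.device)
    k = torch.outer(k1, k1)
    k = k / k.sum() * (factor ** 2)                        # compensate for the inserted zeros
    weight = k.expand(c, 1, *k.shape).contiguous()         # one filter copy per channel (depthwise)
    kw = k.shape[-1]
    pad = (kw // 2, (kw - 1) // 2, kw // 2, (kw - 1) // 2)
    return F.conv2d(F.pad(up, pad), weight, groups=c)

print(naive_fir_upsample(torch.randn(1, 3, 8, 8)).shape)   # torch.Size([1, 3, 16, 16])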
@@ -311,21 +298,21 @@ class NCSNpp(ModelMixin, ConfigMixin):

         if conditional:
             modules.append(nn.Linear(embed_dim, nf * 4))
-            modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape)
+            modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape)
             nn.init.zeros_(modules[-1].bias)
             modules.append(nn.Linear(nf * 4, nf * 4))
-            modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape)
+            modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape)
             nn.init.zeros_(modules[-1].bias)

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
-        Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)

         if progressive == "output_skip":
             self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
         elif progressive == "residual":
             pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)

-        Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)

         if progressive_input == "input_skip":
             self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
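`Up_sample` and `Down_sample` are built with `functools.partial`, which binds default keyword arguments once and lets individual call sites (for example `Up_sample(fir_kernel=fir_kernel, with_conv=False)`) override them. A tiny illustration with a throwaway function (editor's sketch, not repo code):

import functools

# Bind default kwargs once, override per call site.
def make_block(in_ch, out_ch, with_conv=False, fir_kernel=(1, 3, 3, 1)):
    return {"in_ch": in_ch, "out_ch": out_ch, "with_conv": with_conv, "fir_kernel": fir_kernel}

Block = functools.partial(make_block, with_conv=True, fir_kernel=(1, 3, 3, 1))
print(Block(64, 128))                      # uses the bound defaults
print(Block(64, 128, with_conv=False))     # a call-time keyword overrides the bound value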
@@ -348,7 +335,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
         if progressive_input != "none":
             input_pyramid_ch = channels

-        modules.append(conv2d(channels, nf, kernel_size=3, padding=1))
+        modules.append(Conv2d(channels, nf, kernel_size=3, padding=1))
         hs_c = [nf]

         in_ch = nf
@@ -397,11 +384,11 @@ class NCSNpp(ModelMixin, ConfigMixin):
             if i_level == self.num_resolutions - 1:
                 if progressive == "output_skip":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
+                    modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
                     pyramid_ch = channels
                 elif progressive == "residual":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
+                    modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
                     pyramid_ch = in_ch
                 else:
                     raise ValueError(f"{progressive} is not a valid name.")
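The `num_groups=min(in_ch // 4, 32)` pattern caps the group count at 32 while keeping groups small for narrow layers; `nn.GroupNorm` requires `num_channels` to be divisible by `num_groups`, which holds for the usual power-of-two widths. A quick check (editor's sketch):

import torch.nn as nn

# Group counts produced by the min(in_ch // 4, 32) rule for a few widths.
for in_ch in (32, 128, 256):
    gn = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
    print(in_ch, gn.num_groups)   # 32 -> 8, 128 -> 32, 256 -> 32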
@@ -409,7 +396,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 if progressive == "output_skip":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
                     modules.append(
-                        conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
+                        Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
                     )
                     pyramid_ch = channels
                 elif progressive == "residual":
@@ -425,7 +412,7 @@ class NCSNpp(ModelMixin, ConfigMixin):

         if progressive != "output_skip":
             modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-            modules.append(conv2d(in_ch, channels, init_scale=init_scale))
+            modules.append(Conv2d(in_ch, channels, init_scale=init_scale))

         self.all_modules = nn.ModuleList(modules)

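The constructor finishes by wrapping the accumulated list in `nn.ModuleList`, which registers every submodule's parameters with the parent model; a plain Python list would not. A minimal illustration (editor's sketch with made-up layer sizes):

import torch.nn as nn

# Parameters of modules inside an nn.ModuleList are visible to the parent.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        modules = [nn.Linear(4, 4), nn.Linear(4, 2)]
        self.all_modules = nn.ModuleList(modules)

print(sum(p.numel() for p in Toy().parameters()))   # 30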