From 3e2cff4da25642e964c48fa44d7c00d3314b1ce8 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 30 Jun 2022 13:26:05 +0200 Subject: [PATCH] better names and more cleanup --- .../models/unet_sde_score_estimation.py | 79 ++++++++----------- 1 file changed, 33 insertions(+), 46 deletions(-) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 5a67b1ffb1..48e25bea7d 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -40,12 +40,8 @@ def _setup_kernel(k): return k -def _shape(x, dim): - return x.shape[dim] - - -def upsample_conv_2d(x, w, k=None, factor=2, gain=1): - """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. +def _upsample_conv_2d(x, w, k=None, factor=2, gain=1): + """Fused `upsample_2d()` followed by `Conv2d()`. Args: Padding is performed only once at the beginning, not between the operations. The fused op is considerably more @@ -84,13 +80,13 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1): # Determine data dimensions. stride = [1, 1, factor, factor] - output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) + output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW) output_padding = ( - output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, - output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW, + output_shape[0] - (x.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (x.shape[3] - 1) * stride[1] - convW, ) assert output_padding[0] >= 0 and output_padding[1] >= 0 - num_groups = _shape(x, 1) // inC + num_groups = x.shape[1] // inC # Transpose weights. w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) @@ -98,21 +94,12 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1): w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) - # Original TF code. - # x = tf.nn.conv2d_transpose( - # x, - # w, - # output_shape=output_shape, - # strides=stride, - # padding='VALID', - # data_format=data_format) - # JAX equivalent return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) -def conv_downsample_2d(x, w, k=None, factor=2, gain=1): - """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. +def _conv_downsample_2d(x, w, k=None, factor=2, gain=1): + """Fused `Conv2d()` followed by `downsample_2d()`. Args: Padding is performed only once at the beginning, not between the operations. The fused op is considerably more @@ -143,15 +130,7 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1): return F.conv2d(x, w, stride=s, padding=0) -def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1): - """nXn convolution with DDPM initialization.""" - conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) - conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape) - nn.init.zeros_(conv.bias) - return conv - - -def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"): +def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"): """Ported from JAX.""" scale = 1e-10 if scale == 0 else scale @@ -170,13 +149,21 @@ def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, devi return init +def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1): + """nXn convolution with DDPM initialization.""" + conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) + conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + class Combine(nn.Module): """Combine information from skip connections.""" def __init__(self, dim1, dim2, method="cat"): super().__init__() # 1x1 convolution with DDPM initialization. - self.Conv_0 = conv2d(dim1, dim2, kernel_size=1, padding=0) + self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0) self.method = method def forward(self, x, y): @@ -189,38 +176,38 @@ class Combine(nn.Module): raise ValueError(f"Method {self.method} not recognized.") -class Upsample(nn.Module): +class FirUpsample(nn.Module): def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)): super().__init__() out_ch = out_ch if out_ch else in_ch if with_conv: - self.Conv2d_0 = conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1) + self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1) self.with_conv = with_conv self.fir_kernel = fir_kernel self.out_ch = out_ch def forward(self, x): if self.with_conv: - h = upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) + h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) else: h = upsample_2d(x, self.fir_kernel, factor=2) return h -class Downsample(nn.Module): +class FirDownsample(nn.Module): def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)): super().__init__() out_ch = out_ch if out_ch else in_ch if with_conv: - self.Conv2d_0 = self.Conv2d_0 = conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1) + self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1) self.fir_kernel = fir_kernel self.with_conv = with_conv self.out_ch = out_ch def forward(self, x): if self.with_conv: - x = conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) + x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) else: x = downsample_2d(x, self.fir_kernel, factor=2) @@ -311,21 +298,21 @@ class NCSNpp(ModelMixin, ConfigMixin): if conditional: modules.append(nn.Linear(embed_dim, nf * 4)) - modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape) + modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape) nn.init.zeros_(modules[-1].bias) modules.append(nn.Linear(nf * 4, nf * 4)) - modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape) + modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape) nn.init.zeros_(modules[-1].bias) AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) - Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel) + Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel) if progressive == "output_skip": self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False) elif progressive == "residual": pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True) - Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel) + Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel) if progressive_input == "input_skip": self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False) @@ -348,7 +335,7 @@ class NCSNpp(ModelMixin, ConfigMixin): if progressive_input != "none": input_pyramid_ch = channels - modules.append(conv2d(channels, nf, kernel_size=3, padding=1)) + modules.append(Conv2d(channels, nf, kernel_size=3, padding=1)) hs_c = [nf] in_ch = nf @@ -397,11 +384,11 @@ class NCSNpp(ModelMixin, ConfigMixin): if i_level == self.num_resolutions - 1: if progressive == "output_skip": modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1)) + modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1)) pyramid_ch = channels elif progressive == "residual": modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1)) + modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1)) pyramid_ch = in_ch else: raise ValueError(f"{progressive} is not a valid name.") @@ -409,7 +396,7 @@ class NCSNpp(ModelMixin, ConfigMixin): if progressive == "output_skip": modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) modules.append( - conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1) + Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1) ) pyramid_ch = channels elif progressive == "residual": @@ -425,7 +412,7 @@ class NCSNpp(ModelMixin, ConfigMixin): if progressive != "output_skip": modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(conv2d(in_ch, channels, init_scale=init_scale)) + modules.append(Conv2d(in_ch, channels, init_scale=init_scale)) self.all_modules = nn.ModuleList(modules)