From 13ac40ed8ed33d75910b14575788c0eab0cbbe75 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Thu, 30 Jun 2022 12:21:04 +0200
Subject: [PATCH] style

---
 src/diffusers/models/resnet.py                    |  6 ++----
 src/diffusers/models/unet_sde_score_estimation.py | 10 +++++-----
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index c206859b70..5cc5530625 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -603,7 +603,7 @@ class ResnetBlockBigGANpp(nn.Module):
         self.Dropout_0 = nn.Dropout(dropout)
         self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
         if in_ch != out_ch or up or down:
-            #1x1 convolution with DDPM initialization.
+            # 1x1 convolution with DDPM initialization.
             self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
 
         self.skip_rescale = skip_rescale
@@ -757,9 +757,7 @@ class RearrangeDim(nn.Module):
 
 def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
     """nXn convolution with DDPM initialization."""
-    conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
-    )
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
     conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv
diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py
index 30db349395..6f909dcf3b 100644
--- a/src/diffusers/models/unet_sde_score_estimation.py
+++ b/src/diffusers/models/unet_sde_score_estimation.py
@@ -289,9 +289,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
 
 def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
     """nXn convolution with DDPM initialization."""
-    conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
-    )
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
     conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv
@@ -336,7 +334,7 @@ class Combine(nn.Module):
 
     def __init__(self, dim1, dim2, method="cat"):
         super().__init__()
-        #1x1 convolution with DDPM initialization.
+        # 1x1 convolution with DDPM initialization.
         self.Conv_0 = conv2d(dim1, dim2, kernel_size=1, padding=0)
         self.method = method
 
@@ -602,7 +600,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 else:
                     if progressive == "output_skip":
                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                        modules.append(conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1))
+                        modules.append(
+                            conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
+                        )
                         pyramid_ch = channels
                     elif progressive == "residual":
                         modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
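
For context on the hunks above: both files define the same conv2d helper, which builds an nxn nn.Conv2d and then overwrites the default weights with DDPM-style variance scaling (the patch only reflows its call onto one line). Below is a minimal runnable sketch of that pattern, assuming variance_scaling follows the fan_avg/uniform convention of the original score-SDE/DDPM code; the variance_scaling body here is an illustrative assumption, not the verbatim diffusers implementation.

    # Sketch of the DDPM-style init used by the conv2d helper in this patch.
    # ASSUMPTION: variance_scaling mirrors the fan_avg/uniform scheme from the
    # original score-SDE/DDPM code; the real diffusers helper may differ.
    import torch
    import torch.nn as nn

    def variance_scaling(scale=1.0):
        def init(shape, dtype=torch.float32):
            # Conv weight shape is (out_ch, in_ch, kh, kw).
            out_ch, in_ch = shape[0], shape[1]
            receptive = 1
            for s in shape[2:]:
                receptive *= s
            fan_avg = (in_ch + out_ch) * receptive / 2.0
            variance = max(scale, 1e-10) / fan_avg
            # Uniform distribution with the requested variance.
            bound = (3.0 * variance) ** 0.5
            return torch.empty(shape, dtype=dtype).uniform_(-bound, bound)
        return init

    def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
        """nXn convolution with DDPM initialization."""
        conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
        nn.init.zeros_(conv.bias)
        return conv

    layer = conv2d(64, 128, kernel_size=1, padding=0)  # 1x1 conv, as in Conv_2 / Conv_0 above

Resampling the weights and zeroing the bias after construction keeps the parameters registered with PyTorch while matching the TF DDPM initialization; a small init_scale (often 0, clamped to a tiny value) is commonly passed for a block's last convolution so it starts near zero.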