diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 414e4672fd..82079d4b65 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -237,12 +237,12 @@ class ResnetBlock(nn.Module): elif non_linearity == "silu": self.nonlinearity = nn.SiLU() -# if up: -# self.h_upd = Upsample(in_channels, use_conv=False, dims=2) -# self.x_upd = Upsample(in_channels, use_conv=False, dims=2) -# elif down: -# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") -# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") + # if up: + # self.h_upd = Upsample(in_channels, use_conv=False, dims=2) + # self.x_upd = Upsample(in_channels, use_conv=False, dims=2) + # elif down: + # self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") + # self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") self.upsample = self.downsample = None if self.up and kernel == "fir": @@ -318,9 +318,9 @@ class ResnetBlock(nn.Module): num_groups = min(in_ch // 4, 32) num_groups_out = min(out_ch // 4, 32) temb_dim = temb_channels -# output_scale_factor = np.sqrt(2.0) -# non_linearity = "silu" -# use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True + # output_scale_factor = np.sqrt(2.0) + # non_linearity = "silu" + # use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps) self.up = up @@ -338,7 +338,7 @@ class ResnetBlock(nn.Module): # 1x1 convolution with DDPM initialization. self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0) -# self.skip_rescale = skip_rescale + # self.skip_rescale = skip_rescale self.in_ch = in_ch self.out_ch = out_ch diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 9345d5ba5b..a4cdc8bd09 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -27,8 +27,7 @@ from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import GaussianFourierProjection, get_timestep_embedding -from .resnet import downsample_2d, upfirdn2d, upsample_2d, Downsample, Upsample -from .resnet import ResnetBlock +from .resnet import Downsample, ResnetBlock, Upsample, downsample_2d, upfirdn2d, upsample_2d def _setup_kernel(k): @@ -277,8 +276,6 @@ class NCSNpp(ModelMixin, ConfigMixin): skip_rescale=skip_rescale, continuous=continuous, ) - self.act = act = nn.SiLU() - self.nf = nf self.num_res_blocks = num_res_blocks self.attn_resolutions = attn_resolutions @@ -421,9 +418,10 @@ class NCSNpp(ModelMixin, ConfigMixin): for i_level in reversed(range(self.num_resolutions)): for i_block in range(num_res_blocks + 1): out_ch = nf * ch_mult[i_level] + in_ch = in_ch + hs_c.pop() modules.append( ResnetBlock( - in_channels=in_ch + hs_c.pop(), + in_channels=in_ch, out_channels=out_ch, temb_channels=4 * nf, output_scale_factor=np.sqrt(2.0),