diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index dfd26ec19b..46824a85f1 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -404,9 +404,6 @@ class ResnetBlock(nn.Module): if groups_out is None: groups_out = groups - if use_nin_shortcut is None: - use_nin_shortcut = self.in_channels != self.out_channels - if self.pre_norm: self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps) else: @@ -439,8 +436,11 @@ class ResnetBlock(nn.Module): self.upsample = Upsample(in_channels, use_conv=False, dims=2) if self.up else None self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.down else None - self.nin_shortcut = None - if use_nin_shortcut: + self.nin_shortcut = use_nin_shortcut + if self.use_nin_shortcut is None: + self.use_nin_shortcut = self.in_channels != self.out_channels + + if self.use_nin_shortcut: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED @@ -613,10 +613,6 @@ class ResnetBlock(nn.Module): x = self.downsample(x) h = self.downsample(h) -# if self.up: or self.down: -# x = self.x_upd(x) -# h = self.h_upd(h) -# h = self.conv1(h) if not self.pre_norm: diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index af19ed68a2..c900ae7810 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -29,6 +29,7 @@ from .attention import AttentionBlock from .embeddings import GaussianFourierProjection, get_timestep_embedding from .resnet import downsample_2d, upfirdn2d, upsample_2d from .resnet import ResnetBlockBigGANppNew as ResnetBlockBigGANpp +from .resnet import ResnetBlock as ResnetNew def _setup_kernel(k): @@ -346,7 +347,19 @@ class NCSNpp(ModelMixin, ConfigMixin): # Residual blocks for this resolution for i_block in range(num_res_blocks): out_ch = nf * ch_mult[i_level] - modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) +# modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) + modules.append( + ResnetNew( + in_channels=in_ch, + out_channels=out_ch, + temb_channels=4 * nf, + output_scale_factor=np.sqrt(2.0), + non_linearity="silu", + groups=min(in_ch // 4, 32), + groups_out=min(out_ch // 4, 32), + overwrite_for_score_vde=True, + ) + ) in_ch = out_ch if all_resolutions[i_level] in attn_resolutions: