From 1468f754e0da36b4345fca558ad199aa9cba31d0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 1 Jul 2022 15:40:54 +0000 Subject: [PATCH] finish resnet --- .../models/unet_sde_score_estimation.py | 31 +++++-------------- 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 727bec30d9..c5a1579470 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -28,8 +28,7 @@ from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import GaussianFourierProjection, get_timestep_embedding from .resnet import downsample_2d, upfirdn2d, upsample_2d -from .resnet import ResnetBlockBigGANppNew as ResnetBlockBigGANpp -from .resnet import ResnetBlock as ResnetNew +from .resnet import ResnetBlock def _setup_kernel(k): @@ -323,16 +322,6 @@ class NCSNpp(ModelMixin, ConfigMixin): elif progressive_input == "residual": pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True) - ResnetBlock = functools.partial( - ResnetBlockBigGANpp, - act=act, - dropout=dropout, - fir_kernel=fir_kernel, - init_scale=init_scale, - skip_rescale=skip_rescale, - temb_dim=nf * 4, - ) - # Downsampling block channels = num_channels @@ -347,9 +336,8 @@ class NCSNpp(ModelMixin, ConfigMixin): # Residual blocks for this resolution for i_block in range(num_res_blocks): out_ch = nf * ch_mult[i_level] -# modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) modules.append( - ResnetNew( + ResnetBlock( in_channels=in_ch, out_channels=out_ch, temb_channels=4 * nf, @@ -367,9 +355,8 @@ class NCSNpp(ModelMixin, ConfigMixin): hs_c.append(in_ch) if i_level != self.num_resolutions - 1: -# modules.append(ResnetBlock(down=True, in_ch=in_ch)) modules.append( - ResnetNew( + ResnetBlock( in_channels=in_ch, temb_channels=4 * nf, output_scale_factor=np.sqrt(2.0), @@ -395,9 +382,8 @@ class NCSNpp(ModelMixin, ConfigMixin): hs_c.append(in_ch) in_ch = hs_c[-1] -# modules.append(ResnetBlock(in_ch=in_ch)) modules.append( - ResnetNew( + ResnetBlock( in_channels=in_ch, temb_channels=4 * nf, output_scale_factor=np.sqrt(2.0), @@ -408,9 +394,8 @@ class NCSNpp(ModelMixin, ConfigMixin): ) ) modules.append(AttnBlock(channels=in_ch)) -# modules.append(ResnetBlock(in_ch=in_ch)) modules.append( - ResnetNew( + ResnetBlock( in_channels=in_ch, temb_channels=4 * nf, output_scale_factor=np.sqrt(2.0), @@ -426,9 +411,8 @@ class NCSNpp(ModelMixin, ConfigMixin): for i_level in reversed(range(self.num_resolutions)): for i_block in range(num_res_blocks + 1): out_ch = nf * ch_mult[i_level] -# modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) modules.append( - ResnetNew( + ResnetBlock( in_channels=in_ch + hs_c.pop(), out_channels=out_ch, temb_channels=4 * nf, @@ -470,9 +454,8 @@ class NCSNpp(ModelMixin, ConfigMixin): raise ValueError(f"{progressive} is not a valid name") if i_level != 0: -# modules.append(ResnetBlock(in_ch=in_ch, up=True)) modules.append( - ResnetNew( + ResnetBlock( in_channels=in_ch, temb_channels=4 * nf, output_scale_factor=np.sqrt(2.0),