mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
finish resnet
This commit is contained in:
@@ -28,8 +28,7 @@ from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import GaussianFourierProjection, get_timestep_embedding
|
||||
from .resnet import downsample_2d, upfirdn2d, upsample_2d
|
||||
from .resnet import ResnetBlockBigGANppNew as ResnetBlockBigGANpp
|
||||
from .resnet import ResnetBlock as ResnetNew
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
|
||||
def _setup_kernel(k):
|
||||
@@ -323,16 +322,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
elif progressive_input == "residual":
|
||||
pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
|
||||
|
||||
ResnetBlock = functools.partial(
|
||||
ResnetBlockBigGANpp,
|
||||
act=act,
|
||||
dropout=dropout,
|
||||
fir_kernel=fir_kernel,
|
||||
init_scale=init_scale,
|
||||
skip_rescale=skip_rescale,
|
||||
temb_dim=nf * 4,
|
||||
)
|
||||
|
||||
# Downsampling block
|
||||
|
||||
channels = num_channels
|
||||
@@ -347,9 +336,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
# Residual blocks for this resolution
|
||||
for i_block in range(num_res_blocks):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
# modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
ResnetBlock(
|
||||
in_channels=in_ch,
|
||||
out_channels=out_ch,
|
||||
temb_channels=4 * nf,
|
||||
@@ -367,9 +355,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
hs_c.append(in_ch)
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
# modules.append(ResnetBlock(down=True, in_ch=in_ch))
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
ResnetBlock(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
@@ -395,9 +382,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
hs_c.append(in_ch)
|
||||
|
||||
in_ch = hs_c[-1]
|
||||
# modules.append(ResnetBlock(in_ch=in_ch))
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
ResnetBlock(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
@@ -408,9 +394,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
)
|
||||
)
|
||||
modules.append(AttnBlock(channels=in_ch))
|
||||
# modules.append(ResnetBlock(in_ch=in_ch))
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
ResnetBlock(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
@@ -426,9 +411,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(num_res_blocks + 1):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
# modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
ResnetBlock(
|
||||
in_channels=in_ch + hs_c.pop(),
|
||||
out_channels=out_ch,
|
||||
temb_channels=4 * nf,
|
||||
@@ -470,9 +454,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
raise ValueError(f"{progressive} is not a valid name")
|
||||
|
||||
if i_level != 0:
|
||||
# modules.append(ResnetBlock(in_ch=in_ch, up=True))
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
ResnetBlock(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
|
||||
Reference in New Issue
Block a user