From 26ce60c46d128b820674e99c847304d1e424b661 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 29 Jun 2022 17:30:48 +0000 Subject: [PATCH] up --- src/diffusers/models/resnet.py | 14 +++--- src/diffusers/models/unet_ldm.py | 83 ++++++++++++++++---------------- 2 files changed, 49 insertions(+), 48 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 8972e58e5f..93c0cf1782 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -330,8 +330,8 @@ class ResBlock(TimestepBlock): result = self.skip_connection(x) + h -# TODO(Patrick) Use for glide at later stage -# result = self.forward_2(x, emb) + # TODO(Patrick) Use for glide at later stage + # result = self.forward_2(x, emb) return result @@ -439,9 +439,9 @@ class ResnetBlock(nn.Module): self.res_conv = torch.nn.Identity() elif self.overwrite_for_ldm: dims = 2 -# eps = 1e-5 -# non_linearity = "silu" -# overwrite_for_ldm + # eps = 1e-5 + # non_linearity = "silu" + # overwrite_for_ldm channels = in_channels emb_channels = temb_channels use_scale_shift_norm = False @@ -466,8 +466,8 @@ class ResnetBlock(nn.Module): ) if self.out_channels == in_channels: self.skip_connection = nn.Identity() -# elif use_conv: -# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) + # elif use_conv: + # self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index f78f3afd09..9c01f0d17e 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -10,9 +10,10 @@ from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import get_timestep_embedding -from .resnet import Downsample, TimestepBlock, Upsample -from .resnet import ResnetBlock -#from .resnet import ResBlock +from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample + + +# from .resnet import ResBlock def exists(val): @@ -561,14 +562,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin): for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ - ResnetBlock( - in_channels=ch, - out_channels=mult * model_channels, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, + ResnetBlock( + in_channels=ch, + out_channels=mult * model_channels, + dropout=dropout, + temb_channels=time_embed_dim, + eps=1e-5, + non_linearity="silu", + overwrite_for_ldm=True, ) ] ch = mult * model_channels @@ -601,16 +602,16 @@ class UNetLDMModel(ModelMixin, ConfigMixin): out_ch = ch self.input_blocks.append( TimestepEmbedSequential( -# ResBlock( -# ch, -# time_embed_dim, -# dropout, -# out_channels=out_ch, -# dims=dims, -# use_checkpoint=use_checkpoint, -# use_scale_shift_norm=use_scale_shift_norm, -# down=True, -# ) + # ResBlock( + # ch, + # time_embed_dim, + # dropout, + # out_channels=out_ch, + # dims=dims, + # use_checkpoint=use_checkpoint, + # use_scale_shift_norm=use_scale_shift_norm, + # down=True, + # ) None if resblock_updown else Downsample( @@ -703,16 +704,16 @@ class UNetLDMModel(ModelMixin, ConfigMixin): if level and i == num_res_blocks: out_ch = ch layers.append( -# ResBlock( -# ch, -# time_embed_dim, -# dropout, -# out_channels=out_ch, -# dims=dims, -# use_checkpoint=use_checkpoint, -# use_scale_shift_norm=use_scale_shift_norm, -# up=True, -# ) + # ResBlock( + # ch, + # time_embed_dim, + # dropout, + # out_channels=out_ch, + # dims=dims, + # use_checkpoint=use_checkpoint, + # use_scale_shift_norm=use_scale_shift_norm, + # up=True, + # ) None if resblock_updown else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch) @@ -876,16 +877,16 @@ class EncoderUNetModel(nn.Module): out_ch = ch self.input_blocks.append( TimestepEmbedSequential( -# ResBlock( -# ch, -# time_embed_dim, -# dropout, -# out_channels=out_ch, -# dims=dims, -# use_checkpoint=use_checkpoint, -# use_scale_shift_norm=use_scale_shift_norm, -# down=True, -# ) + # ResBlock( + # ch, + # time_embed_dim, + # dropout, + # out_channels=out_ch, + # dims=dims, + # use_checkpoint=use_checkpoint, + # use_scale_shift_norm=use_scale_shift_norm, + # down=True, + # ) None if resblock_updown else Downsample(