diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 83e7cfd979..f95d9198b3 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -378,10 +378,7 @@ class ResBlock(TimestepBlock): h = self.conv2(h) if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) + x = self.nin_shortcut(x) return x + h @@ -426,7 +423,7 @@ class ResnetBlock(nn.Module): if time_embedding_norm == "default": self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - if time_embedding_norm == "scale_shift": + elif time_embedding_norm == "scale_shift": self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps) @@ -489,7 +486,7 @@ class ResnetBlock(nn.Module): nn.SiLU(), linear( emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + 2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels, ), ) self.out_layers = nn.Sequential( @@ -551,9 +548,6 @@ class ResnetBlock(nn.Module): self.set_weights_ldm() self.is_overwritten = True - if self.up or self.down: - x = self.x_upd(x) - h = x h = h * mask if self.pre_norm: @@ -561,6 +555,7 @@ class ResnetBlock(nn.Module): h = self.nonlinearity(h) if self.up or self.down: + x = self.x_upd(x) h = self.h_upd(h) h = self.conv1(h) @@ -571,7 +566,6 @@ class ResnetBlock(nn.Module): h = h * mask temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None] - if self.time_embedding_norm == "scale_shift": scale, shift = torch.chunk(temb, 2, dim=1) @@ -595,10 +589,10 @@ class ResnetBlock(nn.Module): x = x * mask if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) +# if self.use_conv_shortcut: +# x = self.conv_shortcut(x) +# else: + x = self.nin_shortcut(x) return x + h diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 477c1768ae..a0af4b9f48 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -7,6 +7,7 @@ from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, ResBlock, TimestepBlock, Upsample +from .resnet import ResnetBlock def convert_module_to_f16(l): @@ -101,7 +102,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): def forward(self, x, emb, encoder_out=None): for layer in self: - if isinstance(layer, TimestepBlock): + if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock): x = layer(x, emb) elif isinstance(layer, AttentionBlock): x = layer(x, encoder_out) @@ -190,14 +191,24 @@ class GlideUNetModel(ModelMixin, ConfigMixin): for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=int(mult * model_channels), - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, +# ResBlock( +# ch, +# time_embed_dim, +# dropout, +# out_channels=int(mult * model_channels), +# dims=dims, +# use_checkpoint=use_checkpoint, +# use_scale_shift_norm=use_scale_shift_norm, +# ) + ResnetBlock( + in_channels=ch, + out_channels=mult * model_channels, + dropout=dropout, + temb_channels=time_embed_dim, + eps=1e-5, + non_linearity="silu", + time_embedding_norm="scale_shift", + overwrite_for_glide=True, ) ] ch = int(mult * model_channels) @@ -218,15 +229,26 @@ class GlideUNetModel(ModelMixin, ConfigMixin): out_ch = ch self.input_blocks.append( TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, +# ResBlock( +# ch, +# time_embed_dim, +# dropout, +# out_channels=out_ch, +# dims=dims, +# use_checkpoint=use_checkpoint, +# use_scale_shift_norm=use_scale_shift_norm, +# down=True, +# ) + ResnetBlock( + in_channels=ch, out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, + dropout=dropout, + temb_channels=time_embed_dim, + eps=1e-5, + non_linearity="silu", + time_embedding_norm="scale_shift", + overwrite_for_glide=True, + down=True ) if resblock_updown else Downsample( @@ -240,13 +262,22 @@ class GlideUNetModel(ModelMixin, ConfigMixin): self._feature_size += ch self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, +# ResBlock( +# ch, +# time_embed_dim, +# dropout, +# dims=dims, +# use_checkpoint=use_checkpoint, +# use_scale_shift_norm=use_scale_shift_norm, +# ), + ResnetBlock( + in_channels=ch, + dropout=dropout, + temb_channels=time_embed_dim, + eps=1e-5, + non_linearity="silu", + time_embedding_norm="scale_shift", + overwrite_for_glide=True, ), AttentionBlock( ch, @@ -255,14 +286,23 @@ class GlideUNetModel(ModelMixin, ConfigMixin): num_head_channels=num_head_channels, encoder_channels=transformer_dim, ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), +# ResBlock( +# ch, +# time_embed_dim, +# dropout, +# dims=dims, +# use_checkpoint=use_checkpoint, +# use_scale_shift_norm=use_scale_shift_norm, +# ), + ResnetBlock( + in_channels=ch, + dropout=dropout, + temb_channels=time_embed_dim, + eps=1e-5, + non_linearity="silu", + time_embedding_norm="scale_shift", + overwrite_for_glide=True, + ) ) self._feature_size += ch @@ -271,15 +311,25 @@ class GlideUNetModel(ModelMixin, ConfigMixin): for i in range(num_res_blocks + 1): ich = input_block_chans.pop() layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=int(model_channels * mult), - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) +# ResBlock( +# ch + ich, +# time_embed_dim, +# dropout, +# out_channels=int(model_channels * mult), +# dims=dims, +# use_checkpoint=use_checkpoint, +# use_scale_shift_norm=use_scale_shift_norm, +# ) + ResnetBlock( + in_channels=ch + ich, + out_channels=model_channels * mult, + dropout=dropout, + temb_channels=time_embed_dim, + eps=1e-5, + non_linearity="silu", + time_embedding_norm="scale_shift", + overwrite_for_glide=True, + ), ] ch = int(model_channels * mult) if ds in attention_resolutions: @@ -295,14 +345,25 @@ class GlideUNetModel(ModelMixin, ConfigMixin): if level and i == num_res_blocks: out_ch = ch layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, +# ResBlock( +# ch, +# time_embed_dim, +# dropout, +# out_channels=out_ch, +# dims=dims, +# use_checkpoint=use_checkpoint, +# use_scale_shift_norm=use_scale_shift_norm, +# up=True, +# ) + ResnetBlock( + in_channels=ch, out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, + dropout=dropout, + temb_channels=time_embed_dim, + eps=1e-5, + non_linearity="silu", + time_embedding_norm="scale_shift", + overwrite_for_glide=True, up=True, ) if resblock_updown