From 185347e411247ae9c6d8ace910dc3f876958bee1 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 30 Jun 2022 17:01:06 +0000
Subject: [PATCH] up

---
 src/diffusers/models/resnet.py | 39 ++++++++++++++++---------------------
 1 file changed, 18 insertions(+), 21 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 5875628352..29fc6a8f00 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -207,9 +207,6 @@ class ResBlock(TimestepBlock):
 
         self.updown = up or down
 
-#         if self.updown:
-#             import ipdb; ipdb.set_trace()
-
         if up:
             self.h_upd = Upsample(channels, use_conv=False, dims=dims)
             self.x_upd = Upsample(channels, use_conv=False, dims=dims)
@@ -227,8 +224,10 @@ class ResBlock(TimestepBlock):
             ),
         )
         self.out_layers = nn.Sequential(
-            normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
-            nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+#            normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+#            nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+            normalization(self.out_channels, swish=0.0),
+            nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
         )
@@ -322,6 +321,7 @@ class ResBlock(TimestepBlock):
         emb_out = self.emb_layers(emb).type(h.dtype)
         while len(emb_out.shape) < len(h.shape):
             emb_out = emb_out[..., None]
+
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
             scale, shift = torch.chunk(emb_out, 2, dim=1)
@@ -338,35 +338,31 @@ class ResBlock(TimestepBlock):
 
         return result
 
-    def forward_2(self, x, temb, mask=1.0):
+    def forward_2(self, x, temb):
         if self.overwrite and not self.is_overwritten:
             self.set_weights()
             self.is_overwritten = True
 
         h = x
-        if self.pre_norm:
-            h = self.norm1(h)
-            h = self.nonlinearity(h)
+        h = self.norm1(h)
+        h = self.nonlinearity(h)
 
         h = self.conv1(h)
 
-        if not self.pre_norm:
-            h = self.norm1(h)
-            h = self.nonlinearity(h)
+        temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
 
-        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+        scale, shift = torch.chunk(temb, 2, dim=1)
 
-        if self.pre_norm:
-            h = self.norm2(h)
-            h = self.nonlinearity(h)
+        h = self.norm2(h)
+        h = h * scale + shift
+
+        h = self.norm2(h)
+
+        h = self.nonlinearity(h)
 
         h = self.dropout(h)
         h = self.conv2(h)
 
-        if not self.pre_norm:
-            h = self.norm2(h)
-            h = self.nonlinearity(h)
-
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 x = self.conv_shortcut(x)
@@ -376,7 +372,7 @@ class ResBlock(TimestepBlock):
         return x + h
 
 
-# unet.py and unet_grad_tts.py
+# unet.py, unet_grad_tts.py, unet_ldm.py
 class ResnetBlock(nn.Module):
     def __init__(
         self,
@@ -410,6 +406,7 @@ class ResnetBlock(nn.Module):
         self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
         self.dropout = torch.nn.Dropout(dropout)
         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
         if non_linearity == "swish":
            self.nonlinearity = nonlinearity
         elif non_linearity == "mish":
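Note (not part of the patch): the diff above hardcodes a scale/shift ("FiLM"-style) timestep conditioning path in forward_2, where the projected embedding is chunked into a per-channel scale and shift that modulate the normalized hidden states. The snippet below is a minimal, self-contained sketch of that pattern, assuming illustrative names (ScaleShiftBlock, channels, temb_dim) that are not from the diffusers codebase; it uses the common norm(h) * (1 + scale) + shift formulation rather than the exact WIP ordering in forward_2.

import torch
import torch.nn as nn


class ScaleShiftBlock(nn.Module):
    """Illustrative sketch of scale/shift timestep conditioning (not diffusers code)."""

    def __init__(self, channels: int, temb_dim: int):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        # Project the timestep embedding to 2 * channels so it can be
        # split into a per-channel scale and a per-channel shift.
        self.temb_proj = nn.Linear(temb_dim, 2 * channels)

    def forward(self, h: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        # (B, temb_dim) -> (B, 2 * C, 1, 1) so it broadcasts over spatial dims.
        temb = self.temb_proj(nn.functional.silu(temb))[:, :, None, None]
        scale, shift = torch.chunk(temb, 2, dim=1)
        # Normalize first, then modulate with the embedding-derived scale and shift.
        h = self.norm(h) * (1 + scale) + shift
        return nn.functional.silu(h)


# Usage: a 4x64x32x32 feature map conditioned on a 4x128 timestep embedding.
block = ScaleShiftBlock(channels=64, temb_dim=128)
out = block(torch.randn(4, 64, 32, 32), torch.randn(4, 128))
print(out.shape)  # torch.Size([4, 64, 32, 32])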