1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
Patrick von Platen
2022-06-30 17:01:06 +00:00
parent c1c4dea98d
commit 185347e411

View File

@@ -207,9 +207,6 @@ class ResBlock(TimestepBlock):
self.updown = up or down
# if self.updown:
# import ipdb; ipdb.set_trace()
if up:
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
@@ -227,8 +224,10 @@ class ResBlock(TimestepBlock):
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
# nn.SiLU() if use_scale_shift_norm else nn.Identity(),
normalization(self.out_channels, swish=0.0),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
@@ -322,6 +321,7 @@ class ResBlock(TimestepBlock):
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
@@ -338,35 +338,31 @@ class ResBlock(TimestepBlock):
return result
def forward_2(self, x, temb, mask=1.0):
def forward_2(self, x, temb):
if self.overwrite and not self.is_overwritten:
self.set_weights()
self.is_overwritten = True
h = x
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = self.norm1(h)
h = self.nonlinearity(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
scale, shift = torch.chunk(temb, 2, dim=1)
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.norm2(h)
h = h * scale + shift
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
@@ -376,7 +372,7 @@ class ResBlock(TimestepBlock):
return x + h
# unet.py and unet_grad_tts.py
# unet.py, unet_grad_tts.py, unet_ldm.py
class ResnetBlock(nn.Module):
def __init__(
self,
@@ -410,6 +406,7 @@ class ResnetBlock(nn.Module):
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":