mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
up
This commit is contained in:
@@ -207,9 +207,6 @@ class ResBlock(TimestepBlock):
|
||||
|
||||
self.updown = up or down
|
||||
|
||||
# if self.updown:
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
@@ -227,8 +224,10 @@ class ResBlock(TimestepBlock):
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
|
||||
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
|
||||
# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
|
||||
# nn.SiLU() if use_scale_shift_norm else nn.Identity(),
|
||||
normalization(self.out_channels, swish=0.0),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
|
||||
)
|
||||
@@ -322,6 +321,7 @@ class ResBlock(TimestepBlock):
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
|
||||
if self.use_scale_shift_norm:
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
||||
@@ -338,35 +338,31 @@ class ResBlock(TimestepBlock):
|
||||
|
||||
return result
|
||||
|
||||
def forward_2(self, x, temb, mask=1.0):
|
||||
def forward_2(self, x, temb):
|
||||
if self.overwrite and not self.is_overwritten:
|
||||
self.set_weights()
|
||||
self.is_overwritten = True
|
||||
|
||||
h = x
|
||||
if self.pre_norm:
|
||||
h = self.norm1(h)
|
||||
h = self.nonlinearity(h)
|
||||
h = self.norm1(h)
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
h = self.conv1(h)
|
||||
|
||||
if not self.pre_norm:
|
||||
h = self.norm1(h)
|
||||
h = self.nonlinearity(h)
|
||||
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||
|
||||
if self.pre_norm:
|
||||
h = self.norm2(h)
|
||||
h = self.nonlinearity(h)
|
||||
h = self.norm2(h)
|
||||
h = h * scale + shift
|
||||
|
||||
h = self.norm2(h)
|
||||
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if not self.pre_norm:
|
||||
h = self.norm2(h)
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
@@ -376,7 +372,7 @@ class ResBlock(TimestepBlock):
|
||||
return x + h
|
||||
|
||||
|
||||
# unet.py and unet_grad_tts.py
|
||||
# unet.py, unet_grad_tts.py, unet_ldm.py
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -410,6 +406,7 @@ class ResnetBlock(nn.Module):
|
||||
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = nonlinearity
|
||||
elif non_linearity == "mish":
|
||||
|
||||
Reference in New Issue
Block a user