1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

some clean up

This commit is contained in:
Patrick von Platen
2022-07-01 18:14:46 +00:00
parent dcb9070bc2
commit a7b0047e0f

View File

@@ -175,6 +175,7 @@ class Downsample(nn.Module):
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
# => All 2D-Resnets are included here now!
class ResnetBlock(nn.Module):
def __init__(
self,
@@ -317,9 +318,6 @@ class ResnetBlock(nn.Module):
num_groups = min(in_ch // 4, 32)
num_groups_out = min(out_ch // 4, 32)
temb_dim = temb_channels
# output_scale_factor = np.sqrt(2.0)
# non_linearity = "silu"
# use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
self.up = up
@@ -337,13 +335,9 @@ class ResnetBlock(nn.Module):
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
# self.skip_rescale = skip_rescale
self.in_ch = in_ch
self.out_ch = out_ch
# TODO(Patrick) - move to main init
self.is_overwritten = False
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
self.conv1.bias.data = self.block1.block[0].bias.data