mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
some clean up
This commit is contained in:
@@ -175,6 +175,7 @@ class Downsample(nn.Module):
|
||||
|
||||
|
||||
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
|
||||
# => All 2D-Resnets are included here now!
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -317,9 +318,6 @@ class ResnetBlock(nn.Module):
|
||||
num_groups = min(in_ch // 4, 32)
|
||||
num_groups_out = min(out_ch // 4, 32)
|
||||
temb_dim = temb_channels
|
||||
# output_scale_factor = np.sqrt(2.0)
|
||||
# non_linearity = "silu"
|
||||
# use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
|
||||
|
||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
|
||||
self.up = up
|
||||
@@ -337,13 +335,9 @@ class ResnetBlock(nn.Module):
|
||||
# 1x1 convolution with DDPM initialization.
|
||||
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
|
||||
|
||||
# self.skip_rescale = skip_rescale
|
||||
self.in_ch = in_ch
|
||||
self.out_ch = out_ch
|
||||
|
||||
# TODO(Patrick) - move to main init
|
||||
self.is_overwritten = False
|
||||
|
||||
def set_weights_grad_tts(self):
|
||||
self.conv1.weight.data = self.block1.block[0].weight.data
|
||||
self.conv1.bias.data = self.block1.block[0].bias.data
|
||||
|
||||
Reference in New Issue
Block a user