From 61dc657461130997d3a45929953c122419ee892e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 1 Jul 2022 14:35:14 +0000 Subject: [PATCH] more fixes --- src/diffusers/models/resnet.py | 109 +++++++++++++++++++++++++++++---- 1 file changed, 97 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index d707457c35..dfd26ec19b 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -1,5 +1,6 @@ from abc import abstractmethod +import functools import numpy as np import torch import torch.nn as nn @@ -374,15 +375,20 @@ class ResnetBlock(nn.Module): dropout=0.0, temb_channels=512, groups=32, + groups_out=None, pre_norm=True, eps=1e-6, non_linearity="swish", time_embedding_norm="default", + fir_kernel=(1, 3, 3, 1), + output_scale_factor=1.0, + use_nin_shortcut=None, up=False, down=False, overwrite_for_grad_tts=False, overwrite_for_ldm=False, overwrite_for_glide=False, + overwrite_for_score_vde=False, ): super().__init__() self.pre_norm = pre_norm @@ -393,6 +399,13 @@ class ResnetBlock(nn.Module): self.time_embedding_norm = time_embedding_norm self.up = up self.down = down + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + if use_nin_shortcut is None: + use_nin_shortcut = self.in_channels != self.out_channels if self.pre_norm: self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps) @@ -406,7 +419,7 @@ class ResnetBlock(nn.Module): elif time_embedding_norm == "scale_shift": self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) - self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps) + self.norm2 = Normalize(out_channels, num_groups=groups_out, eps=eps) self.dropout = torch.nn.Dropout(dropout) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) @@ -417,14 +430,17 @@ class ResnetBlock(nn.Module): elif non_linearity == "silu": self.nonlinearity = nn.SiLU() - if up: - self.h_upd = Upsample(in_channels, use_conv=False, dims=2) - self.x_upd = Upsample(in_channels, use_conv=False, dims=2) - elif down: - self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") - self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") +# if up: +# self.h_upd = Upsample(in_channels, use_conv=False, dims=2) +# self.x_upd = Upsample(in_channels, use_conv=False, dims=2) +# elif down: +# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") +# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") + self.upsample = Upsample(in_channels, use_conv=False, dims=2) if self.up else None + self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.down else None - if self.in_channels != self.out_channels: + self.nin_shortcut = None + if use_nin_shortcut: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED @@ -432,6 +448,7 @@ class ResnetBlock(nn.Module): self.overwrite_for_glide = overwrite_for_glide self.overwrite_for_grad_tts = overwrite_for_grad_tts self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide + self.overwrite_for_score_vde = overwrite_for_score_vde if self.overwrite_for_grad_tts: dim = in_channels dim_out = out_channels @@ -450,6 +467,7 @@ class ResnetBlock(nn.Module): channels = in_channels emb_channels = temb_channels use_scale_shift_norm = False + non_linearity = "silu" self.in_layers = nn.Sequential( normalization(channels, swish=1.0), @@ -473,6 +491,45 @@ class ResnetBlock(nn.Module): self.skip_connection = nn.Identity() else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + elif self.overwrite_for_score_vde: + in_ch = in_channels + out_ch = out_channels + + eps = 1e-6 + num_groups = min(in_ch // 4, 32) + num_groups_out = min(out_ch // 4, 32) + temb_dim = temb_channels +# output_scale_factor = np.sqrt(2.0) +# non_linearity = "silu" +# use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True + + self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps) + self.up = up + self.down = down + self.fir_kernel = fir_kernel + + self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape) + nn.init.zeros_(self.Dense_0.bias) + + self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = conv2d(out_ch, out_ch, init_scale=0.0, kernel_size=3, padding=1) + if in_ch != out_ch or up or down: + # 1x1 convolution with DDPM initialization. + self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0) + +# self.skip_rescale = skip_rescale + self.in_ch = in_ch + self.out_ch = out_ch + + # TODO(Patrick) - move to main init + self.upsample = functools.partial(upsample_2d, k=self.fir_kernel) + self.downsample = functools.partial(downsample_2d, k=self.fir_kernel) + + self.is_overwritten = False def set_weights_grad_tts(self): self.conv1.weight.data = self.block1.block[0].weight.data @@ -512,6 +569,24 @@ class ResnetBlock(nn.Module): self.nin_shortcut.weight.data = self.skip_connection.weight.data self.nin_shortcut.bias.data = self.skip_connection.bias.data + def set_weights_score_vde(self): + self.conv1.weight.data = self.Conv_0.weight.data + self.conv1.bias.data = self.Conv_0.bias.data + self.norm1.weight.data = self.GroupNorm_0.weight.data + self.norm1.bias.data = self.GroupNorm_0.bias.data + + self.conv2.weight.data = self.Conv_1.weight.data + self.conv2.bias.data = self.Conv_1.bias.data + self.norm2.weight.data = self.GroupNorm_1.weight.data + self.norm2.bias.data = self.GroupNorm_1.bias.data + + self.temb_proj.weight.data = self.Dense_0.weight.data + self.temb_proj.bias.data = self.Dense_0.bias.data + + if self.in_channels != self.out_channels or self.up or self.down: + self.nin_shortcut.weight.data = self.Conv_2.weight.data + self.nin_shortcut.bias.data = self.Conv_2.bias.data + def forward(self, x, temb, mask=1.0): # TODO(Patrick) eventually this class should be split into multiple classes # too many if else statements @@ -521,6 +596,9 @@ class ResnetBlock(nn.Module): elif self.overwrite_for_ldm and not self.is_overwritten: self.set_weights_ldm() self.is_overwritten = True + elif self.overwrite_for_score_vde and not self.is_overwritten: + self.set_weights_score_vde() + self.is_overwritten = True h = x h = h * mask @@ -528,10 +606,17 @@ class ResnetBlock(nn.Module): h = self.norm1(h) h = self.nonlinearity(h) - if self.up or self.down: - x = self.x_upd(x) - h = self.h_upd(h) + if self.upsample is not None: + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) +# if self.up: or self.down: +# x = self.x_upd(x) +# h = self.h_upd(h) +# h = self.conv1(h) if not self.pre_norm: @@ -563,7 +648,7 @@ class ResnetBlock(nn.Module): h = h * mask x = x * mask - if self.in_channels != self.out_channels: + if self.nin_shortcut is not None: x = self.nin_shortcut(x) return x + h