1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

more fixes

This commit is contained in:
Patrick von Platen
2022-07-01 14:35:14 +00:00
parent f1aade0596
commit 61dc657461

View File

@@ -1,5 +1,6 @@
from abc import abstractmethod
import functools
import numpy as np
import torch
import torch.nn as nn
@@ -374,15 +375,20 @@ class ResnetBlock(nn.Module):
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
fir_kernel=(1, 3, 3, 1),
output_scale_factor=1.0,
use_nin_shortcut=None,
up=False,
down=False,
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
overwrite_for_glide=False,
overwrite_for_score_vde=False,
):
super().__init__()
self.pre_norm = pre_norm
@@ -393,6 +399,13 @@ class ResnetBlock(nn.Module):
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
if use_nin_shortcut is None:
use_nin_shortcut = self.in_channels != self.out_channels
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
@@ -406,7 +419,7 @@ class ResnetBlock(nn.Module):
elif time_embedding_norm == "scale_shift":
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.norm2 = Normalize(out_channels, num_groups=groups_out, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
@@ -417,14 +430,17 @@ class ResnetBlock(nn.Module):
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if up:
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
elif down:
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
# if up:
# self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
# self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
# elif down:
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.upsample = Upsample(in_channels, use_conv=False, dims=2) if self.up else None
self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.down else None
if self.in_channels != self.out_channels:
self.nin_shortcut = None
if use_nin_shortcut:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
@@ -432,6 +448,7 @@ class ResnetBlock(nn.Module):
self.overwrite_for_glide = overwrite_for_glide
self.overwrite_for_grad_tts = overwrite_for_grad_tts
self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
self.overwrite_for_score_vde = overwrite_for_score_vde
if self.overwrite_for_grad_tts:
dim = in_channels
dim_out = out_channels
@@ -450,6 +467,7 @@ class ResnetBlock(nn.Module):
channels = in_channels
emb_channels = temb_channels
use_scale_shift_norm = False
non_linearity = "silu"
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
@@ -473,6 +491,45 @@ class ResnetBlock(nn.Module):
self.skip_connection = nn.Identity()
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
elif self.overwrite_for_score_vde:
in_ch = in_channels
out_ch = out_channels
eps = 1e-6
num_groups = min(in_ch // 4, 32)
num_groups_out = min(out_ch // 4, 32)
temb_dim = temb_channels
# output_scale_factor = np.sqrt(2.0)
# non_linearity = "silu"
# use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
self.up = up
self.down = down
self.fir_kernel = fir_kernel
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=0.0, kernel_size=3, padding=1)
if in_ch != out_ch or up or down:
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
# self.skip_rescale = skip_rescale
self.in_ch = in_ch
self.out_ch = out_ch
# TODO(Patrick) - move to main init
self.upsample = functools.partial(upsample_2d, k=self.fir_kernel)
self.downsample = functools.partial(downsample_2d, k=self.fir_kernel)
self.is_overwritten = False
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
@@ -512,6 +569,24 @@ class ResnetBlock(nn.Module):
self.nin_shortcut.weight.data = self.skip_connection.weight.data
self.nin_shortcut.bias.data = self.skip_connection.bias.data
def set_weights_score_vde(self):
self.conv1.weight.data = self.Conv_0.weight.data
self.conv1.bias.data = self.Conv_0.bias.data
self.norm1.weight.data = self.GroupNorm_0.weight.data
self.norm1.bias.data = self.GroupNorm_0.bias.data
self.conv2.weight.data = self.Conv_1.weight.data
self.conv2.bias.data = self.Conv_1.bias.data
self.norm2.weight.data = self.GroupNorm_1.weight.data
self.norm2.bias.data = self.GroupNorm_1.bias.data
self.temb_proj.weight.data = self.Dense_0.weight.data
self.temb_proj.bias.data = self.Dense_0.bias.data
if self.in_channels != self.out_channels or self.up or self.down:
self.nin_shortcut.weight.data = self.Conv_2.weight.data
self.nin_shortcut.bias.data = self.Conv_2.bias.data
def forward(self, x, temb, mask=1.0):
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
@@ -521,6 +596,9 @@ class ResnetBlock(nn.Module):
elif self.overwrite_for_ldm and not self.is_overwritten:
self.set_weights_ldm()
self.is_overwritten = True
elif self.overwrite_for_score_vde and not self.is_overwritten:
self.set_weights_score_vde()
self.is_overwritten = True
h = x
h = h * mask
@@ -528,10 +606,17 @@ class ResnetBlock(nn.Module):
h = self.norm1(h)
h = self.nonlinearity(h)
if self.up or self.down:
x = self.x_upd(x)
h = self.h_upd(h)
if self.upsample is not None:
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
x = self.downsample(x)
h = self.downsample(h)
# if self.up: or self.down:
# x = self.x_upd(x)
# h = self.h_upd(h)
#
h = self.conv1(h)
if not self.pre_norm:
@@ -563,7 +648,7 @@ class ResnetBlock(nn.Module):
h = h * mask
x = x * mask
if self.in_channels != self.out_channels:
if self.nin_shortcut is not None:
x = self.nin_shortcut(x)
return x + h