mirror of https://github.com/huggingface/diffusers.git
more fixes
@@ -1,5 +1,6 @@
 from abc import abstractmethod
 
+import functools
 import numpy as np
 import torch
 import torch.nn as nn
@@ -374,15 +375,20 @@ class ResnetBlock(nn.Module):
         dropout=0.0,
         temb_channels=512,
         groups=32,
         groups_out=None,
         pre_norm=True,
         eps=1e-6,
         non_linearity="swish",
         time_embedding_norm="default",
+        fir_kernel=(1, 3, 3, 1),
         output_scale_factor=1.0,
+        use_nin_shortcut=None,
+        up=False,
+        down=False,
         overwrite_for_grad_tts=False,
         overwrite_for_ldm=False,
         overwrite_for_glide=False,
+        overwrite_for_score_vde=False,
     ):
         super().__init__()
         self.pre_norm = pre_norm
@@ -393,6 +399,13 @@ class ResnetBlock(nn.Module):
         self.time_embedding_norm = time_embedding_norm
+        self.up = up
+        self.down = down
         self.output_scale_factor = output_scale_factor
 
+        if groups_out is None:
+            groups_out = groups
+
+        if use_nin_shortcut is None:
+            use_nin_shortcut = self.in_channels != self.out_channels
 
         if self.pre_norm:
             self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
@@ -406,7 +419,7 @@ class ResnetBlock(nn.Module):
             elif time_embedding_norm == "scale_shift":
                 self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
 
-        self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
+        self.norm2 = Normalize(out_channels, num_groups=groups_out, eps=eps)
         self.dropout = torch.nn.Dropout(dropout)
         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
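Note: switching norm2 from groups to groups_out is a real fix, not a rename. nn.GroupNorm requires num_groups to divide num_channels, so normalizing out_channels with an input-derived group count can fail as soon as in_channels != out_channels. A minimal sketch, with made-up channel sizes and the min(ch // 4, 32) rule from the score_vde branch below:

    import torch
    import torch.nn as nn

    in_channels, out_channels = 64, 24
    groups = min(in_channels // 4, 32)       # 16, valid for 64 input channels
    groups_out = min(out_channels // 4, 32)  # 6, valid for 24 output channels

    # nn.GroupNorm(16, 24) would raise ValueError: 24 is not divisible by 16.
    norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels)
    h = torch.randn(2, out_channels, 8, 8)
    print(norm2(h).shape)  # torch.Size([2, 24, 8, 8])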
@@ -417,14 +430,17 @@ class ResnetBlock(nn.Module):
         elif non_linearity == "silu":
             self.nonlinearity = nn.SiLU()
 
-        if up:
-            self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
-            self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
-        elif down:
-            self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
-            self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+        # if up:
+        #     self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
+        #     self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
+        # elif down:
+        #     self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+        #     self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+        self.upsample = Upsample(in_channels, use_conv=False, dims=2) if self.up else None
+        self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.down else None
 
-        if self.in_channels != self.out_channels:
+        self.nin_shortcut = None
+        if use_nin_shortcut:
             self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
 
         # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
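This hunk also changes the attribute style: instead of h_upd/x_upd existing only on some instances, upsample/downsample/nin_shortcut now always exist but may be None, so the forward pass can test them uniformly. A minimal sketch of the pattern (class and attribute names here are illustrative, not the diffusers API):

    import torch
    import torch.nn as nn

    class Block(nn.Module):
        def __init__(self, in_ch, out_ch, up=False):
            super().__init__()
            # Always define the attribute; use None when the path is disabled.
            self.upsample = nn.Upsample(scale_factor=2) if up else None
            self.nin_shortcut = (
                nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else None
            )
            self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

        def forward(self, x):
            h = x
            if self.upsample is not None:  # no hasattr() checks needed
                x = self.upsample(x)
                h = self.upsample(h)
            h = self.conv(h)
            if self.nin_shortcut is not None:
                x = self.nin_shortcut(x)
            return x + h

    print(Block(8, 16, up=True)(torch.randn(1, 8, 4, 4)).shape)  # [1, 16, 8, 8]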
@@ -432,6 +448,7 @@ class ResnetBlock(nn.Module):
         self.overwrite_for_glide = overwrite_for_glide
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
         self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
+        self.overwrite_for_score_vde = overwrite_for_score_vde
         if self.overwrite_for_grad_tts:
             dim = in_channels
             dim_out = out_channels
@@ -450,6 +467,7 @@ class ResnetBlock(nn.Module):
             channels = in_channels
             emb_channels = temb_channels
             use_scale_shift_norm = False
+            non_linearity = "silu"
 
             self.in_layers = nn.Sequential(
                 normalization(channels, swish=1.0),
@@ -473,6 +491,45 @@ class ResnetBlock(nn.Module):
                 self.skip_connection = nn.Identity()
             else:
                 self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+        elif self.overwrite_for_score_vde:
+            in_ch = in_channels
+            out_ch = out_channels
+
+            eps = 1e-6
+            num_groups = min(in_ch // 4, 32)
+            num_groups_out = min(out_ch // 4, 32)
+            temb_dim = temb_channels
+            # output_scale_factor = np.sqrt(2.0)
+            # non_linearity = "silu"
+            # use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
+
+            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
+            self.up = up
+            self.down = down
+            self.fir_kernel = fir_kernel
+
+            self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
+            if temb_dim is not None:
+                self.Dense_0 = nn.Linear(temb_dim, out_ch)
+                self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
+                nn.init.zeros_(self.Dense_0.bias)
+
+            self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps)
+            self.Dropout_0 = nn.Dropout(dropout)
+            self.Conv_1 = conv2d(out_ch, out_ch, init_scale=0.0, kernel_size=3, padding=1)
+            if in_ch != out_ch or up or down:
+                # 1x1 convolution with DDPM initialization.
+                self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
+
+            # self.skip_rescale = skip_rescale
+            self.in_ch = in_ch
+            self.out_ch = out_ch
+
+            # TODO(Patrick) - move to main init
+            self.upsample = functools.partial(upsample_2d, k=self.fir_kernel)
+            self.downsample = functools.partial(downsample_2d, k=self.fir_kernel)
+
+        self.is_overwritten = False
+
     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
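The variance_scaling initializer used for Dense_0 is ported from the score_sde codebase and is not shown in this diff. A rough sketch of what such an initializer does, with argument names, defaults, and the Linear-style fan computation assumed rather than taken from the actual helper:

    import torch

    def variance_scaling(scale=1.0, mode="fan_avg", distribution="uniform"):
        """Return an init function producing tensors with variance scale / fan,
        in the spirit of the score_sde default initializer (assumed API)."""
        def init(shape, dtype=torch.float32):
            fan_in, fan_out = shape[1], shape[0]  # nn.Linear weights are (out, in)
            fan = {"fan_in": fan_in, "fan_out": fan_out,
                   "fan_avg": (fan_in + fan_out) / 2}[mode]
            var = scale / fan
            if distribution == "normal":
                return torch.randn(*shape, dtype=dtype) * var ** 0.5
            # Uniform on [-a, a] has variance a^2 / 3, so a = sqrt(3 * var).
            return (torch.rand(*shape, dtype=dtype) * 2 - 1) * (3 * var) ** 0.5

        return init

    w = variance_scaling()(torch.empty(128, 512).shape)
    print(w.shape, w.var().item())  # variance ~ 1 / ((128 + 512) / 2)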
@@ -512,6 +569,24 @@ class ResnetBlock(nn.Module):
             self.nin_shortcut.weight.data = self.skip_connection.weight.data
             self.nin_shortcut.bias.data = self.skip_connection.bias.data
 
+    def set_weights_score_vde(self):
+        self.conv1.weight.data = self.Conv_0.weight.data
+        self.conv1.bias.data = self.Conv_0.bias.data
+        self.norm1.weight.data = self.GroupNorm_0.weight.data
+        self.norm1.bias.data = self.GroupNorm_0.bias.data
+
+        self.conv2.weight.data = self.Conv_1.weight.data
+        self.conv2.bias.data = self.Conv_1.bias.data
+        self.norm2.weight.data = self.GroupNorm_1.weight.data
+        self.norm2.bias.data = self.GroupNorm_1.bias.data
+
+        self.temb_proj.weight.data = self.Dense_0.weight.data
+        self.temb_proj.bias.data = self.Dense_0.bias.data
+
+        if self.in_channels != self.out_channels or self.up or self.down:
+            self.nin_shortcut.weight.data = self.Conv_2.weight.data
+            self.nin_shortcut.bias.data = self.Conv_2.bias.data
+
     def forward(self, x, temb, mask=1.0):
        # TODO(Patrick) eventually this class should be split into multiple classes
        # too many if else statements
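These .data assignments rebind the parameter tensors without any shape check, so they only work because each legacy module was built with exactly the shapes of its counterpart. A small illustration of the two copy styles, not specific to diffusers:

    import torch
    import torch.nn as nn

    src = nn.Conv2d(8, 16, kernel_size=3, padding=1)
    dst = nn.Conv2d(8, 16, kernel_size=3, padding=1)

    # Rebinding .data, as in set_weights_score_vde, performs no shape check:
    dst.weight.data = src.weight.data

    # copy_ is the stricter alternative: it validates shapes and keeps the
    # destination tensor's storage.
    with torch.no_grad():
        dst.weight.copy_(src.weight)
        dst.bias.copy_(src.bias)

    print(torch.equal(dst.weight, src.weight))  # True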
@@ -521,6 +596,9 @@ class ResnetBlock(nn.Module):
         elif self.overwrite_for_ldm and not self.is_overwritten:
             self.set_weights_ldm()
             self.is_overwritten = True
+        elif self.overwrite_for_score_vde and not self.is_overwritten:
+            self.set_weights_score_vde()
+            self.is_overwritten = True
 
         h = x
         h = h * mask
@@ -528,10 +606,17 @@ class ResnetBlock(nn.Module):
             h = self.norm1(h)
             h = self.nonlinearity(h)
 
-        if self.up or self.down:
-            x = self.x_upd(x)
-            h = self.h_upd(h)
+        if self.upsample is not None:
+            x = self.upsample(x)
+            h = self.upsample(h)
+        elif self.downsample is not None:
+            x = self.downsample(x)
+            h = self.downsample(h)
 
+        # if self.up: or self.down:
+        #     x = self.x_upd(x)
+        #     h = self.h_upd(h)
+        #
         h = self.conv1(h)
 
         if not self.pre_norm:
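The rewritten forward applies the same resampling module to both the hidden state h and the residual input x. That symmetry is what keeps the final x + h valid, as a quick shape sketch shows (module choices here are illustrative):

    import torch
    import torch.nn as nn

    # Whatever resampling h goes through, the residual x must go through too,
    # or the final x + h has mismatched spatial sizes.
    upsample = nn.Upsample(scale_factor=2)
    conv1 = nn.Conv2d(8, 8, kernel_size=3, padding=1)

    x = torch.randn(1, 8, 4, 4)
    h = upsample(x)         # [1, 8, 8, 8]
    x = upsample(x)         # [1, 8, 8, 8]
    h = conv1(h)
    print((x + h).shape)    # torch.Size([1, 8, 8, 8])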
@@ -563,7 +648,7 @@ class ResnetBlock(nn.Module):
             h = h * mask
 
         x = x * mask
-        if self.in_channels != self.out_channels:
+        if self.nin_shortcut is not None:
             x = self.nin_shortcut(x)
 
         return x + h
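The stored output_scale_factor (note the "# output_scale_factor = np.sqrt(2.0)" comment in the score_vde branch) exists because score_sde-style blocks rescale the residual sum, although the forward shown here does not apply it yet. A quick numerical illustration of why sqrt(2) is the natural constant:

    import torch

    # Two independent, roughly unit-variance branches sum to variance ~2, so
    # dividing by sqrt(2) keeps activations stable across a deep stack.
    x = torch.randn(100_000)
    h = torch.randn(100_000)
    print((x + h).var().item())               # ~2.0
    print(((x + h) / 2 ** 0.5).var().item())  # ~1.0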