1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix more tests

This commit is contained in:
Patrick von Platen
2022-06-30 21:47:40 +00:00
parent 185347e411
commit db934c6750
2 changed files with 69 additions and 21 deletions

View File

@@ -1,4 +1,3 @@
import string
from abc import abstractmethod
import numpy as np
@@ -188,7 +187,7 @@ class ResBlock(TimestepBlock):
use_checkpoint=False,
up=False,
down=False,
overwrite=False, # TODO(Patrick) - use for glide at later stage
overwrite=True, # TODO(Patrick) - use for glide at later stage
):
super().__init__()
self.channels = channels
@@ -220,12 +219,10 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
2 * self.out_channels,
),
)
self.out_layers = nn.Sequential(
# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
# nn.SiLU() if use_scale_shift_norm else nn.Identity(),
normalization(self.out_channels, swish=0.0),
nn.SiLU(),
nn.Dropout(p=dropout),
@@ -257,13 +254,16 @@ class ResBlock(TimestepBlock):
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
# Add to init
self.time_embedding_norm = "scale_shift"
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
@@ -277,6 +277,14 @@ class ResBlock(TimestepBlock):
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.up, self.down = up, down
# if self.up:
# self.h_upd = Upsample(in_channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(in_channels, use_conv=False, dims=dims)
# elif self.down:
# self.h_upd = Downsample(in_channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(in_channels, use_conv=False, dims=dims, padding=1, name="op")
def set_weights(self):
# TODO(Patrick): use for glide at later stage
self.norm1.weight.data = self.in_layers[0].weight.data
@@ -309,6 +317,7 @@ class ResBlock(TimestepBlock):
# TODO(Patrick): use for glide at later stage
self.set_weights()
orig_x = x
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
@@ -334,8 +343,7 @@ class ResBlock(TimestepBlock):
result = self.skip_connection(x) + h
# TODO(Patrick) Use for glide at later stage
# result = self.forward_2(x, emb)
result = self.forward_2(orig_x, emb)
return result
def forward_2(self, x, temb):
@@ -347,18 +355,24 @@ class ResBlock(TimestepBlock):
h = self.norm1(h)
h = self.nonlinearity(h)
if self.up or self.down:
x = self.x_upd(x)
h = self.h_upd(h)
h = self.conv1(h)
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
scale, shift = torch.chunk(temb, 2, dim=1)
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h * scale + shift
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
else:
h = h + temb
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
@@ -386,8 +400,12 @@ class ResnetBlock(nn.Module):
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
up=False,
down=False,
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
overwrite_for_glide=False,
):
super().__init__()
self.pre_norm = pre_norm
@@ -395,6 +413,9 @@ class ResnetBlock(nn.Module):
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
@@ -402,7 +423,12 @@ class ResnetBlock(nn.Module):
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
if time_embedding_norm == "default":
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
if time_embedding_norm == "scale_shift":
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
@@ -414,6 +440,13 @@ class ResnetBlock(nn.Module):
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if up:
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
elif down:
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
# TODO(Patrick) - this branch is never used I think => can be deleted!
@@ -422,8 +455,9 @@ class ResnetBlock(nn.Module):
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.is_overwritten = False
self.overwrite_for_glide = overwrite_for_glide
self.overwrite_for_grad_tts = overwrite_for_grad_tts
self.overwrite_for_ldm = overwrite_for_ldm
self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
if self.overwrite_for_grad_tts:
dim = in_channels
dim_out = out_channels
@@ -517,12 +551,18 @@ class ResnetBlock(nn.Module):
self.set_weights_ldm()
self.is_overwritten = True
if self.up or self.down:
x = self.x_upd(x)
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
if self.up or self.down:
h = self.h_upd(h)
h = self.conv1(h)
if not self.pre_norm:
@@ -530,12 +570,20 @@ class ResnetBlock(nn.Module):
h = self.nonlinearity(h)
h = h * mask
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)

View File

@@ -259,7 +259,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# fmt: off
expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):