From c45fd7498cfebb8f7fff2b2081ac90ee0e2393a0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Jun 2022 17:31:44 +0000 Subject: [PATCH] merge unet attention into glide attention --- src/diffusers/models/attention2d.py | 56 ----------------------------- src/diffusers/models/unet.py | 43 +--------------------- tests/test_modeling_utils.py | 17 +++++---- 3 files changed, 9 insertions(+), 107 deletions(-) diff --git a/src/diffusers/models/attention2d.py b/src/diffusers/models/attention2d.py index 0a85b2d3be..e7fe805814 100644 --- a/src/diffusers/models/attention2d.py +++ b/src/diffusers/models/attention2d.py @@ -32,62 +32,6 @@ class LinearAttention(torch.nn.Module): return self.to_out(out) -# unet.py -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = normalization(in_channels, swish=None, eps=1e-6) - self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x): - print("x", x.abs().sum()) - h_ = x - h_ = self.norm(h_) - - print("hid_states shape", h_.shape) - print("hid_states", h_.abs().sum()) - print("hid_states - 3 - 3", h_.view(h_.shape[0], h_.shape[1], -1)[:, :3, -3:]) - - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - print(self.q) - print("q_shape", q.shape) - print("q", q.abs().sum()) -# print("k_shape", k.shape) -# print("k", k.abs().sum()) -# print("v_shape", v.shape) -# print("v", v.abs().sum()) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - - print("weight", w_.abs().sum()) - - # attend to values - v = v.reshape(b, c, h * w) - h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x + h_ - - # unet_glide.py & unet_ldm.py class AttentionBlock(nn.Module): """ diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index 1a40a6139a..af43389527 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -32,7 +32,7 @@ from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample -from .attention2d import AttnBlock, AttentionBlock +from .attention2d import AttentionBlock def nonlinearity(x): @@ -86,44 +86,6 @@ class ResnetBlock(nn.Module): return x + h -#class AttnBlock(nn.Module): -# def __init__(self, in_channels): -# super().__init__() -# self.in_channels = in_channels -# -# self.norm = Normalize(in_channels) -# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) -# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) -# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) -# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) -# -# def forward(self, x): -# h_ = x -# h_ = 
self.norm(h_) -# q = self.q(h_) -# k = self.k(h_) -# v = self.v(h_) -# - # compute attention -# b, c, h, w = q.shape -# q = q.reshape(b, c, h * w) -# q = q.permute(0, 2, 1) # b,hw,c -# k = k.reshape(b, c, h * w) # b,c,hw -# w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] -# w_ = w_ * (int(c) ** (-0.5)) -# w_ = torch.nn.functional.softmax(w_, dim=2) -# - # attend to values -# v = v.reshape(b, c, h * w) -# w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) -# h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] -# h_ = h_.reshape(b, c, h, w) -# -# h_ = self.proj_out(h_) -# -# return x + h_ - - class UNetModel(ModelMixin, ConfigMixin): def __init__( self, @@ -186,7 +148,6 @@ class UNetModel(ModelMixin, ConfigMixin): ) block_in = block_out if curr_res in attn_resolutions: -# attn.append(AttnBlock(block_in)) attn.append(AttentionBlock(block_in, overwrite_qkv=True)) down = nn.Module() down.block = block @@ -202,7 +163,6 @@ class UNetModel(ModelMixin, ConfigMixin): self.mid.block_1 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout ) -# self.mid.attn_1 = AttnBlock(block_in) self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) self.mid.block_2 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout @@ -228,7 +188,6 @@ class UNetModel(ModelMixin, ConfigMixin): ) block_in = block_out if curr_res in attn_resolutions: -# attn.append(AttnBlock(block_in)) attn.append(AttentionBlock(block_in, overwrite_qkv=True)) up = nn.Module() up.block = block diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index d659d7ca72..f57797346c 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -858,25 +858,26 @@ class PipelineTesterMixin(unittest.TestCase): image_slice = image[0, -1, -3:, -3:].cpu() assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761]) + expected_slice = torch.tensor([0.2249, 0.3375, 0.2359, 0.0929, 0.3439, 0.3156, 0.1937, 0.3585, 0.1761]) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 @slow def test_ddim_cifar10(self): - generator = torch.manual_seed(0) model_id = "fusing/ddpm-cifar10" unet = UNetModel.from_pretrained(model_id) noise_scheduler = DDIMScheduler(tensor_format="pt") ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler) + + generator = torch.manual_seed(0) image = ddim(generator=generator, eta=0.0) image_slice = image[0, -1, -3:, -3:].cpu() assert image.shape == (1, 3, 32, 32) expected_slice = torch.tensor( - [-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068] + [-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094] ) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 @@ -895,7 +896,7 @@ class PipelineTesterMixin(unittest.TestCase): assert image.shape == (1, 3, 32, 32) expected_slice = torch.tensor( - [-0.7888, -0.7870, -0.7759, -0.7823, -0.8014, -0.7608, -0.6818, -0.7130, -0.7471] + [-0.7925, -0.7902, -0.7789, -0.7796, -0.8000, -0.7596, -0.6852, -0.7125, -0.7494] ) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 @@ -966,24 +967,22 @@ class PipelineTesterMixin(unittest.TestCase): @slow def test_score_sde_ve_pipeline(self): - torch.manual_seed(0) - model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp") scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp") sde_ve = 
ScoreSdeVePipeline(model=model, scheduler=scheduler) + torch.manual_seed(0) image = sde_ve(num_inference_steps=2) - expected_image_sum = 3382810112.0 - expected_image_mean = 1075.366455078125 + expected_image_sum = 3382849024.0 + expected_image_mean = 1075.3788 assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 @slow def test_score_sde_vp_pipeline(self): - model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp") scheduler = ScoreSdeVpScheduler.from_config("fusing/cifar10-ddpmpp-vp")
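
Reviewer note on the merge: the AttnBlock deleted from attention2d.py and the AttentionBlock that now replaces it at every call site compute the same single-head self-attention over the h*w spatial positions of a feature map, so the patch removes a duplicate implementation rather than changing the math. The sketch below restates that shared computation as a self-contained module for reference. It is an illustration, not library code: the 32-group GroupNorm is an assumption standing in for the deleted normalization(in_channels, swish=None, eps=1e-6) helper, and overwrite_qkv=True is presumed (from how the call sites change in this patch) to remap the old block's separate q/k/v Conv2d weights onto AttentionBlock's own projections so existing checkpoints keep loading.

import torch
import torch.nn as nn


class SpatialSelfAttentionSketch(nn.Module):
    """Single-head self-attention over the h*w positions of a (b, c, h, w) map.

    Mirrors the math of the deleted AttnBlock / retained AttentionBlock;
    the GroupNorm settings here are assumed, not taken from the library.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        b, c, h, w = x.shape
        h_ = self.norm(x)
        # 1x1 convs project to q/k/v; flatten spatial dims into a sequence of hw tokens
        q = self.q(h_).reshape(b, c, h * w).permute(0, 2, 1)  # b, hw, c
        k = self.k(h_).reshape(b, c, h * w)                   # b, c, hw
        v = self.v(h_).reshape(b, c, h * w)                   # b, c, hw
        # scaled dot-product attention: weights[b, i, j] attends query i to key j
        weights = torch.softmax(torch.bmm(q, k) * c ** -0.5, dim=2)  # b, hw, hw
        # weighted sum of values, then back to the b, c, h, w layout
        h_ = torch.bmm(v, weights.permute(0, 2, 1)).reshape(b, c, h, w)
        # residual connection, as in both the deleted and the retained block
        return x + self.proj_out(h_)


# Quick shape check (channels divisible by 32 for the assumed GroupNorm):
# x = torch.randn(1, 64, 8, 8)
# assert SpatialSelfAttentionSketch(64)(x).shape == x.shape

On the test side, the retuned expected values appear to come from two sources: small numerical drift from the swapped-in attention module (the ddpm-cifar10 slice moves only in the fourth decimal), and the relocated torch.manual_seed(0) calls in test_ddim_cifar10 and test_score_sde_ve_pipeline, which change the RNG state captured at sampling time relative to model and pipeline construction and hence shift those outputs more visibly.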