diff --git a/src/diffusers/models/attention2d.py b/src/diffusers/models/attention2d.py index 6839867fb5..1d7e85e3f2 100644 --- a/src/diffusers/models/attention2d.py +++ b/src/diffusers/models/attention2d.py @@ -91,11 +91,15 @@ class AttentionBlock(nn.Module): self.NIN_2 = NIN(channels, channels) self.NIN_3 = NIN(channels, channels) + self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6) + self.is_overwritten = False def set_weights(self, module): if self.overwrite_qkv: - qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0] + qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[ + :, :, :, 0 + ] qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0) self.qkv.weight.data = qkv_weight @@ -107,14 +111,19 @@ class AttentionBlock(nn.Module): self.proj_out = proj_out elif self.overwrite_linear: - self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None] + self.qkv.weight.data = torch.concat( + [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0 + )[:, :, None] self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0) self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None] self.proj_out.bias.data = self.NIN_3.b.data + self.norm.weight.data = self.GroupNorm_0.weight.data + self.norm.bias.data = self.GroupNorm_0.bias.data + def forward(self, x, encoder_out=None): - if self.overwrite_qkv and not self.is_overwritten: + if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten: self.set_weights(self) self.is_overwritten = True @@ -152,7 +161,7 @@ class AttentionBlock(nn.Module): # unet_score_estimation.py -#class AttnBlockpp(nn.Module): +# class AttnBlockpp(nn.Module): # """Channel-wise self-attention block. Modified from DDPM.""" # # def __init__( @@ -187,14 +196,11 @@ class AttentionBlock(nn.Module): # self.num_heads = channels // num_head_channels # # self.use_checkpoint = use_checkpoint -# self.norm = normalization(channels, num_groups=num_groups, eps=1e-6, swish=None) -# self.qkv = conv_nd(1, channels, channels * 3, 1) +# self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6) +# self.qkv = nn.Conv1d(channels, channels * 3, 1) # self.n_heads = self.num_heads # -# if encoder_channels is not None: -# self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) -# -# self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) +# self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) # # self.is_weight_set = False # @@ -205,6 +211,9 @@ class AttentionBlock(nn.Module): # self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None] # self.proj_out.bias.data = self.NIN_3.b.data # +# self.norm.weight.data = self.GroupNorm_0.weight.data +# self.norm.bias.data = self.GroupNorm_0.bias.data +# # def forward(self, x): # if not self.is_weight_set: # self.set_weights() @@ -261,6 +270,7 @@ class AttentionBlock(nn.Module): # # return (x + h) / np.sqrt(2.0) + # TODO(Patrick) - this can and should be removed def zero_module(module): """ diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index af43389527..dc3201f18b 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -30,9 +30,9 @@ from tqdm import tqdm from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .attention2d import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample -from .attention2d import AttentionBlock def nonlinearity(x): @@ -219,11 +219,11 @@ class UNetModel(ModelMixin, ConfigMixin): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: -# self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block]) -# h = self.down[i_level].attn_2[i_block](h) + # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block]) + # h = self.down[i_level].attn_2[i_block](h) h = self.down[i_level].attn[i_block](h) -# print("Result", (h - h_2).abs().sum()) + # print("Result", (h - h_2).abs().sum()) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 880a20cefe..4854442ce1 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -3,9 +3,9 @@ from numpy import pad from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .attention2d import LinearAttention from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample -from .attention2d import LinearAttention class Mish(torch.nn.Module): diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 20a2ad3169..692ef99576 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -16,18 +16,18 @@ # helpers functions import functools +import math import string import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import math from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .embeddings import GaussianFourierProjection, get_timestep_embedding from .attention2d import AttentionBlock +from .embeddings import GaussianFourierProjection, get_timestep_embedding def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): @@ -728,7 +728,6 @@ class NCSNpp(ModelMixin, ConfigMixin): nn.init.zeros_(modules[-1].bias) AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) - Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) if progressive == "output_skip": diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index aa4c67a3e2..2e0be1d334 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -859,7 +859,9 @@ class PipelineTesterMixin(unittest.TestCase): image_slice = image[0, -1, -3:, -3:].cpu() assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor([-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105]) + expected_slice = torch.tensor( + [-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105] + ) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 @slow