From 635da723742ffc31e80e5bab3bde2c743617ac79 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Jun 2022 22:41:39 +0000 Subject: [PATCH 1/3] one attention module only --- src/diffusers/models/attention2d.py | 190 ++++++++++++------ src/diffusers/models/embeddings.py | 16 -- src/diffusers/models/unet_grad_tts.py | 3 +- .../models/unet_sde_score_estimation.py | 35 +--- 4 files changed, 137 insertions(+), 107 deletions(-) diff --git a/src/diffusers/models/attention2d.py b/src/diffusers/models/attention2d.py index e7fe805814..6839867fb5 100644 --- a/src/diffusers/models/attention2d.py +++ b/src/diffusers/models/attention2d.py @@ -1,11 +1,11 @@ import math import torch -import torch.nn.functional as F from torch import nn # unet_grad_tts.py +# TODO(Patrick) - weird linear attention layer. Check with: https://github.com/huawei-noah/Speech-Backbones/issues/15 class LinearAttention(torch.nn.Module): def __init__(self, dim, heads=4, dim_head=32): super(LinearAttention, self).__init__() @@ -18,7 +18,6 @@ class LinearAttention(torch.nn.Module): def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) - # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) q, k, v = ( qkv.reshape(b, 3, self.heads, self.dim_head, h, w) .permute(1, 0, 2, 3, 4, 5) @@ -27,12 +26,11 @@ class LinearAttention(torch.nn.Module): k = k.softmax(dim=-1) context = torch.einsum("bhdn,bhen->bhde", k, v) out = torch.einsum("bhde,bhdn->bhen", context, q) - # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w) return self.to_out(out) -# unet_glide.py & unet_ldm.py +# the main attention block that is used for all models class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. @@ -46,10 +44,13 @@ class AttentionBlock(nn.Module): channels, num_heads=1, num_head_channels=-1, + num_groups=32, use_checkpoint=False, encoder_channels=None, use_new_attention_order=False, # TODO(Patrick) -> is never used, maybe delete? 
overwrite_qkv=False, + overwrite_linear=False, + rescale_output_factor=1.0, ): super().__init__() self.channels = channels @@ -62,23 +63,34 @@ class AttentionBlock(nn.Module): self.num_heads = channels // num_head_channels self.use_checkpoint = use_checkpoint - self.norm = normalization(channels, swish=0.0) - self.qkv = conv_nd(1, channels, channels * 3, 1) + self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True) + self.qkv = nn.Conv1d(channels, channels * 3, 1) self.n_heads = self.num_heads + self.rescale_output_factor = rescale_output_factor if encoder_channels is not None: - self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) + self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1) - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) self.overwrite_qkv = overwrite_qkv if overwrite_qkv: in_channels = channels + self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.overwrite_linear = overwrite_linear + if self.overwrite_linear: + num_groups = min(channels // 4, 32) + self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6) + self.NIN_0 = NIN(channels, channels) + self.NIN_1 = NIN(channels, channels) + self.NIN_2 = NIN(channels, channels) + self.NIN_3 = NIN(channels, channels) + self.is_overwritten = False def set_weights(self, module): @@ -89,11 +101,17 @@ class AttentionBlock(nn.Module): self.qkv.weight.data = qkv_weight self.qkv.bias.data = qkv_bias - proj_out = zero_module(conv_nd(1, self.channels, self.channels, 1)) + proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1)) proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0] proj_out.bias.data = module.proj_out.bias.data self.proj_out = proj_out + elif self.overwrite_linear: + self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None] + self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0) + + self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None] + self.proj_out.bias.data = self.NIN_3.b.data def forward(self, x, encoder_out=None): if self.overwrite_qkv and not self.is_overwritten: @@ -124,69 +142,74 @@ class AttentionBlock(nn.Module): h = a.reshape(bs, -1, length) h = self.proj_out(h) + h = h.reshape(b, c, *spatial) - return x + h.reshape(b, c, *spatial) + result = x + h + result = result / self.rescale_output_factor -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. 
- """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -class GroupNorm32(nn.GroupNorm): - def __init__(self, num_groups, num_channels, swish, eps=1e-5, affine=True): - super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine) - self.swish = swish - - def forward(self, x): - y = super().forward(x.float()).to(x.dtype) - if self.swish == 1.0: - y = F.silu(y) - elif self.swish: - y = y * F.sigmoid(y * float(self.swish)) - return y - - -def normalization(channels, swish=0.0, eps=1e-5): - """ - Make a standard normalization layer, with an optional swish activation. - - :param channels: number of input channels. :return: an nn.Module for normalization. - """ - return GroupNorm32(num_channels=channels, num_groups=32, swish=swish, eps=eps, affine=True) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module + return result # unet_score_estimation.py -# class AttnBlockpp(nn.Module): +#class AttnBlockpp(nn.Module): # """Channel-wise self-attention block. Modified from DDPM.""" # -# def __init__(self, channels, skip_rescale=False, init_scale=0.0): +# def __init__( +# self, +# channels, +# skip_rescale=False, +# init_scale=0.0, +# num_heads=1, +# num_head_channels=-1, +# use_checkpoint=False, +# encoder_channels=None, +# use_new_attention_order=False, # TODO(Patrick) -> is never used, maybe delete? +# overwrite_qkv=False, +# overwrite_from_grad_tts=False, +# ): # super().__init__() -# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6) +# num_groups = min(channels // 4, 32) +# self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6) # self.NIN_0 = NIN(channels, channels) # self.NIN_1 = NIN(channels, channels) # self.NIN_2 = NIN(channels, channels) # self.NIN_3 = NIN(channels, channels, init_scale=init_scale) # self.skip_rescale = skip_rescale # +# self.channels = channels +# if num_head_channels == -1: +# self.num_heads = num_heads +# else: +# assert ( +# channels % num_head_channels == 0 +# ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" +# self.num_heads = channels // num_head_channels +# +# self.use_checkpoint = use_checkpoint +# self.norm = normalization(channels, num_groups=num_groups, eps=1e-6, swish=None) +# self.qkv = conv_nd(1, channels, channels * 3, 1) +# self.n_heads = self.num_heads +# +# if encoder_channels is not None: +# self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) +# +# self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) +# +# self.is_weight_set = False +# +# def set_weights(self): +# self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None] +# self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0) +# +# self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None] +# self.proj_out.bias.data = self.NIN_3.b.data +# # def forward(self, x): +# if not self.is_weight_set: +# self.set_weights() +# self.is_weight_set = True +# # B, C, H, W = x.shape # h = self.GroupNorm_0(x) # q = self.NIN_0(h) @@ -199,7 +222,58 @@ def zero_module(module): # w = torch.reshape(w, (B, H, W, H, W)) # h = torch.einsum("bhwij,bcij->bchw", w, v) # h = 
self.NIN_3(h) +# # if not self.skip_rescale: -# return x + h +# result = x + h # else: -# return (x + h) / np.sqrt(2.0) +# result = (x + h) / np.sqrt(2.0) +# +# result = self.forward_2(x) +# +# return result +# +# def forward_2(self, x, encoder_out=None): +# b, c, *spatial = x.shape +# hid_states = self.norm(x).view(b, c, -1) +# +# qkv = self.qkv(hid_states) +# bs, width, length = qkv.shape +# assert width % (3 * self.n_heads) == 0 +# ch = width // (3 * self.n_heads) +# q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) +# +# if encoder_out is not None: +# encoder_kv = self.encoder_kv(encoder_out) +# assert encoder_kv.shape[1] == self.n_heads * ch * 2 +# ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) +# k = torch.cat([ek, k], dim=-1) +# v = torch.cat([ev, v], dim=-1) +# +# scale = 1 / math.sqrt(math.sqrt(ch)) +# weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards +# weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) +# +# a = torch.einsum("bts,bcs->bct", weight, v) +# h = a.reshape(bs, -1, length) +# +# h = self.proj_out(h) +# h = h.reshape(b, c, *spatial) +# +# return (x + h) / np.sqrt(2.0) + +# TODO(Patrick) - this can and should be removed +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then +class NIN(nn.Module): + def __init__(self, in_dim, num_units, init_scale=0.1): + super().__init__() + self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True) + self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index e70f39319e..cb4ad0f5e5 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -65,19 +65,3 @@ class GaussianFourierProjection(nn.Module): def forward(self, x): x_proj = x[:, None] * self.W[None, :] * 2 * np.pi return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) - - -# unet_rl.py - TODO(need test) -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 4cccc8375c..880a20cefe 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -5,6 +5,7 @@ from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample +from .attention2d import LinearAttention class Mish(torch.nn.Module): @@ -54,7 +55,7 @@ class ResnetBlock(torch.nn.Module): return output -class LinearAttention(torch.nn.Module): +class old_LinearAttention(torch.nn.Module): def __init__(self, dim, heads=4, dim_head=32): super(LinearAttention, self).__init__() self.heads = heads diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 3f58a55cbe..20a2ad3169 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ 
b/src/diffusers/models/unet_sde_score_estimation.py @@ -22,10 +22,12 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +import math from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import GaussianFourierProjection, get_timestep_embedding +from .attention2d import AttentionBlock def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): @@ -414,37 +416,6 @@ class Combine(nn.Module): raise ValueError(f"Method {self.method} not recognized.") -class AttnBlockpp(nn.Module): - """Channel-wise self-attention block. Modified from DDPM.""" - - def __init__(self, channels, skip_rescale=False, init_scale=0.0): - super().__init__() - self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6) - self.NIN_0 = NIN(channels, channels) - self.NIN_1 = NIN(channels, channels) - self.NIN_2 = NIN(channels, channels) - self.NIN_3 = NIN(channels, channels, init_scale=init_scale) - self.skip_rescale = skip_rescale - - def forward(self, x): - B, C, H, W = x.shape - h = self.GroupNorm_0(x) - q = self.NIN_0(h) - k = self.NIN_1(h) - v = self.NIN_2(h) - - w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5)) - w = torch.reshape(w, (B, H, W, H * W)) - w = F.softmax(w, dim=-1) - w = torch.reshape(w, (B, H, W, H, W)) - h = torch.einsum("bhwij,bcij->bchw", w, v) - h = self.NIN_3(h) - if not self.skip_rescale: - return x + h - else: - return (x + h) / np.sqrt(2.0) - - class Upsample(nn.Module): def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): super().__init__() @@ -756,7 +727,7 @@ class NCSNpp(ModelMixin, ConfigMixin): modules[-1].weight.data = default_init()(modules[-1].weight.shape) nn.init.zeros_(modules[-1].bias) - AttnBlock = functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale) + AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) From 31d1f3c8c0c296bbdef9fa1651cfa7995cbed4b1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Jun 2022 22:59:21 +0000 Subject: [PATCH 2/3] final fix --- src/diffusers/models/attention2d.py | 30 ++++++++++++------- src/diffusers/models/unet.py | 8 ++--- src/diffusers/models/unet_grad_tts.py | 2 +- .../models/unet_sde_score_estimation.py | 5 ++-- tests/test_modeling_utils.py | 4 ++- 5 files changed, 30 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/attention2d.py b/src/diffusers/models/attention2d.py index 6839867fb5..1d7e85e3f2 100644 --- a/src/diffusers/models/attention2d.py +++ b/src/diffusers/models/attention2d.py @@ -91,11 +91,15 @@ class AttentionBlock(nn.Module): self.NIN_2 = NIN(channels, channels) self.NIN_3 = NIN(channels, channels) + self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6) + self.is_overwritten = False def set_weights(self, module): if self.overwrite_qkv: - qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0] + qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[ + :, :, :, 0 + ] qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0) self.qkv.weight.data = qkv_weight @@ -107,14 +111,19 @@ class AttentionBlock(nn.Module): self.proj_out = proj_out elif self.overwrite_linear: - self.qkv.weight.data = 
torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None] + self.qkv.weight.data = torch.concat( + [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0 + )[:, :, None] self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0) self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None] self.proj_out.bias.data = self.NIN_3.b.data + self.norm.weight.data = self.GroupNorm_0.weight.data + self.norm.bias.data = self.GroupNorm_0.bias.data + def forward(self, x, encoder_out=None): - if self.overwrite_qkv and not self.is_overwritten: + if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten: self.set_weights(self) self.is_overwritten = True @@ -152,7 +161,7 @@ class AttentionBlock(nn.Module): # unet_score_estimation.py -#class AttnBlockpp(nn.Module): +# class AttnBlockpp(nn.Module): # """Channel-wise self-attention block. Modified from DDPM.""" # # def __init__( @@ -187,14 +196,11 @@ class AttentionBlock(nn.Module): # self.num_heads = channels // num_head_channels # # self.use_checkpoint = use_checkpoint -# self.norm = normalization(channels, num_groups=num_groups, eps=1e-6, swish=None) -# self.qkv = conv_nd(1, channels, channels * 3, 1) +# self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6) +# self.qkv = nn.Conv1d(channels, channels * 3, 1) # self.n_heads = self.num_heads # -# if encoder_channels is not None: -# self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) -# -# self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) +# self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) # # self.is_weight_set = False # @@ -205,6 +211,9 @@ class AttentionBlock(nn.Module): # self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None] # self.proj_out.bias.data = self.NIN_3.b.data # +# self.norm.weight.data = self.GroupNorm_0.weight.data +# self.norm.bias.data = self.GroupNorm_0.bias.data +# # def forward(self, x): # if not self.is_weight_set: # self.set_weights() @@ -261,6 +270,7 @@ class AttentionBlock(nn.Module): # # return (x + h) / np.sqrt(2.0) + # TODO(Patrick) - this can and should be removed def zero_module(module): """ diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index af43389527..dc3201f18b 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -30,9 +30,9 @@ from tqdm import tqdm from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .attention2d import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample -from .attention2d import AttentionBlock def nonlinearity(x): @@ -219,11 +219,11 @@ class UNetModel(ModelMixin, ConfigMixin): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: -# self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block]) -# h = self.down[i_level].attn_2[i_block](h) + # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block]) + # h = self.down[i_level].attn_2[i_block](h) h = self.down[i_level].attn[i_block](h) -# print("Result", (h - h_2).abs().sum()) + # print("Result", (h - h_2).abs().sum()) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 880a20cefe..4854442ce1 100644 --- 
a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -3,9 +3,9 @@ from numpy import pad from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .attention2d import LinearAttention from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample -from .attention2d import LinearAttention class Mish(torch.nn.Module): diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 20a2ad3169..692ef99576 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -16,18 +16,18 @@ # helpers functions import functools +import math import string import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import math from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .embeddings import GaussianFourierProjection, get_timestep_embedding from .attention2d import AttentionBlock +from .embeddings import GaussianFourierProjection, get_timestep_embedding def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): @@ -728,7 +728,6 @@ class NCSNpp(ModelMixin, ConfigMixin): nn.init.zeros_(modules[-1].bias) AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) - Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) if progressive == "output_skip": diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index aa4c67a3e2..2e0be1d334 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -859,7 +859,9 @@ class PipelineTesterMixin(unittest.TestCase): image_slice = image[0, -1, -3:, -3:].cpu() assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor([-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105]) + expected_slice = torch.tensor( + [-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105] + ) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 @slow From c482d7bd4fd694c2ccefe44b3b6e27c600801c31 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Jun 2022 23:09:50 +0000 Subject: [PATCH 3/3] some clean up --- .../models/{attention2d.py => attention.py} | 0 src/diffusers/models/unet.py | 16 +--------- src/diffusers/models/unet_glide.py | 2 +- src/diffusers/models/unet_grad_tts.py | 29 +------------------ src/diffusers/models/unet_ldm.py | 2 +- .../models/unet_sde_score_estimation.py | 2 +- 6 files changed, 5 insertions(+), 46 deletions(-) rename src/diffusers/models/{attention2d.py => attention.py} (100%) diff --git a/src/diffusers/models/attention2d.py b/src/diffusers/models/attention.py similarity index 100% rename from src/diffusers/models/attention2d.py rename to src/diffusers/models/attention.py diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index dc3201f18b..13765e1f8b 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -15,22 +15,12 @@ # helpers functions -import copy -import math -from pathlib import Path - import torch from torch import nn -from torch.cuda.amp import GradScaler, autocast -from torch.optim import Adam -from torch.utils import data - -from PIL import Image -from tqdm import tqdm from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention2d import AttentionBlock +from .attention import 
AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample @@ -219,11 +209,7 @@ class UNetModel(ModelMixin, ConfigMixin): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: - # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block]) - # h = self.down[i_level].attn_2[i_block](h) - h = self.down[i_level].attn[i_block](h) - # print("Result", (h - h_2).abs().sum()) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 0ffd20b35f..53763ddaa0 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention2d import AttentionBlock +from .attention import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 4854442ce1..bc0c1e7a22 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -1,9 +1,8 @@ import torch -from numpy import pad from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention2d import LinearAttention +from .attention import LinearAttention from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample @@ -55,32 +54,6 @@ class ResnetBlock(torch.nn.Module): return output -class old_LinearAttention(torch.nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super(LinearAttention, self).__init__() - self.heads = heads - self.dim_head = dim_head - hidden_dim = dim_head * heads - self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) - self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) - q, k, v = ( - qkv.reshape(b, 3, self.heads, self.dim_head, h, w) - .permute(1, 0, 2, 3, 4, 5) - .reshape(3, b, self.heads, self.dim_head, -1) - ) - k = k.softmax(dim=-1) - context = torch.einsum("bhdn,bhen->bhde", k, v) - out = torch.einsum("bhde,bhdn->bhen", context, q) - # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) - out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w) - return self.to_out(out) - - class Residual(torch.nn.Module): def __init__(self, fn): super(Residual, self).__init__() diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 5315285afa..0012886a5e 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention2d import AttentionBlock +from .attention import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 692ef99576..17218a7a7e 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ 
b/src/diffusers/models/unet_sde_score_estimation.py @@ -26,7 +26,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention2d import AttentionBlock +from .attention import AttentionBlock from .embeddings import GaussianFourierProjection, get_timestep_embedding
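
For reference, a minimal usage sketch of the unified AttentionBlock introduced in this series. This is a sketch only: the import path assumes the final rename to src/diffusers/models/attention.py from PATCH 3/3, the constructor arguments are the ones shown in the diff above, and the channel/head counts below are illustrative rather than taken from any model config.

    import functools
    import math

    import torch

    # assumed import path after the attention2d.py -> attention.py rename
    from diffusers.models.attention import AttentionBlock

    # Plain self-attention path (no overwrite_qkv / overwrite_linear weight copying).
    # channels must be divisible by num_groups (default 32) and by num_heads.
    block = AttentionBlock(channels=64, num_heads=4)
    x = torch.randn(1, 64, 16, 16)  # (batch, channels, height, width)
    out = block(x)                  # residual output, divided by rescale_output_factor
    assert out.shape == x.shape

    # NCSN++-style wiring, mirroring unet_sde_score_estimation.py in this series:
    AttnBlock = functools.partial(
        AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)
    )

The overwrite_qkv / overwrite_linear flags only copy weights from the legacy q/k/v convolutions or NIN layers into the shared qkv projection on the first forward pass, so checkpoints trained against the old per-model attention blocks keep loading unchanged.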