diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
new file mode 100644
index 0000000000..1d7e85e3f2
--- /dev/null
+++ b/src/diffusers/models/attention.py
@@ -0,0 +1,289 @@
+import math
+
+import torch
+from torch import nn
+
+
+# unet_grad_tts.py
+# TODO(Patrick) - weird linear attention layer. Check with: https://github.com/huawei-noah/Speech-Backbones/issues/15
+class LinearAttention(torch.nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super(LinearAttention, self).__init__()
+        self.heads = heads
+        self.dim_head = dim_head
+        hidden_dim = dim_head * heads
+        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        qkv = self.to_qkv(x)
+        q, k, v = (
+            qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
+            .permute(1, 0, 2, 3, 4, 5)
+            .reshape(3, b, self.heads, self.dim_head, -1)
+        )
+        k = k.softmax(dim=-1)
+        context = torch.einsum("bhdn,bhen->bhde", k, v)
+        out = torch.einsum("bhde,bhdn->bhen", context, q)
+        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
+        return self.to_out(out)
+
+
+# the main attention block that is used for all models
+class AttentionBlock(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other.
+
+    Originally ported from here, but adapted to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+        num_groups=32,
+        use_checkpoint=False,
+        encoder_channels=None,
+        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
+        overwrite_qkv=False,
+        overwrite_linear=False,
+        rescale_output_factor=1.0,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+
+        self.use_checkpoint = use_checkpoint
+        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+        self.qkv = nn.Conv1d(channels, channels * 3, 1)
+        self.n_heads = self.num_heads
+        self.rescale_output_factor = rescale_output_factor
+
+        if encoder_channels is not None:
+            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
+
+        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+
+        self.overwrite_qkv = overwrite_qkv
+        if overwrite_qkv:
+            in_channels = channels
+            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
+            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+        self.overwrite_linear = overwrite_linear
+        if self.overwrite_linear:
+            num_groups = min(channels // 4, 32)
+            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
+            self.NIN_0 = NIN(channels, channels)
+            self.NIN_1 = NIN(channels, channels)
+            self.NIN_2 = NIN(channels, channels)
+            self.NIN_3 = NIN(channels, channels)
+
+            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
+
+        self.is_overwritten = False
+
+    def set_weights(self, module):
+        if self.overwrite_qkv:
+            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
+                :, :, :, 0
+            ]
+            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
+
+            self.qkv.weight.data = qkv_weight
+            self.qkv.bias.data = qkv_bias
+
+            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
+            proj_out.bias.data = module.proj_out.bias.data
+
+            self.proj_out = proj_out
+        elif self.overwrite_linear:
+            self.qkv.weight.data = torch.concat(
+                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
+            )[:, :, None]
+            self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
+
+            self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
+            self.proj_out.bias.data = self.NIN_3.b.data
+
+            self.norm.weight.data = self.GroupNorm_0.weight.data
+            self.norm.bias.data = self.GroupNorm_0.bias.data
+
+    def forward(self, x, encoder_out=None):
+        if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten:
+            self.set_weights(self)
+            self.is_overwritten = True
+
+        b, c, *spatial = x.shape
+        hid_states = self.norm(x).view(b, c, -1)
+
+        qkv = self.qkv(hid_states)
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+
+        if encoder_out is not None:
+            encoder_kv = self.encoder_kv(encoder_out)
+            assert encoder_kv.shape[1] == self.n_heads * ch * 2
+            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
+            k = torch.cat([ek, k], dim=-1)
+            v = torch.cat([ev, v], dim=-1)
+
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
+        a = torch.einsum("bts,bcs->bct", weight, v)
+        h = a.reshape(bs, -1, length)
+
+        h = self.proj_out(h)
+        h = h.reshape(b, c, *spatial)
+
+        result = x + h
+
+        result = result / self.rescale_output_factor
+
+        return result
+
+
+# unet_score_estimation.py
+# class AttnBlockpp(nn.Module):
+#     """Channel-wise self-attention block. Modified from DDPM."""
+#
+#     def __init__(
+#         self,
+#         channels,
+#         skip_rescale=False,
+#         init_scale=0.0,
+#         num_heads=1,
+#         num_head_channels=-1,
+#         use_checkpoint=False,
+#         encoder_channels=None,
+#         use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
+#         overwrite_qkv=False,
+#         overwrite_from_grad_tts=False,
+#     ):
+#         super().__init__()
+#         num_groups = min(channels // 4, 32)
+#         self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
+#         self.NIN_0 = NIN(channels, channels)
+#         self.NIN_1 = NIN(channels, channels)
+#         self.NIN_2 = NIN(channels, channels)
+#         self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
+#         self.skip_rescale = skip_rescale
+#
+#         self.channels = channels
+#         if num_head_channels == -1:
+#             self.num_heads = num_heads
+#         else:
+#             assert (
+#                 channels % num_head_channels == 0
+#             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+#             self.num_heads = channels // num_head_channels
+#
+#         self.use_checkpoint = use_checkpoint
+#         self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
+#         self.qkv = nn.Conv1d(channels, channels * 3, 1)
+#         self.n_heads = self.num_heads
+#
+#         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+#
+#         self.is_weight_set = False
+#
+#     def set_weights(self):
+#         self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
+#         self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
+#
+#         self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
+#         self.proj_out.bias.data = self.NIN_3.b.data
+#
+#         self.norm.weight.data = self.GroupNorm_0.weight.data
+#         self.norm.bias.data = self.GroupNorm_0.bias.data
+#
+#     def forward(self, x):
+#         if not self.is_weight_set:
+#             self.set_weights()
+#             self.is_weight_set = True
+#
+#         B, C, H, W = x.shape
+#         h = self.GroupNorm_0(x)
+#         q = self.NIN_0(h)
+#         k = self.NIN_1(h)
+#         v = self.NIN_2(h)
+#
+#         w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
+#         w = torch.reshape(w, (B, H, W, H * W))
+#         w = F.softmax(w, dim=-1)
+#         w = torch.reshape(w, (B, H, W, H, W))
+#         h = torch.einsum("bhwij,bcij->bchw", w, v)
+#         h = self.NIN_3(h)
+#
+#         if not self.skip_rescale:
+#             result = x + h
+#         else:
+#             result = (x + h) / np.sqrt(2.0)
+#
+#         result = self.forward_2(x)
+#
+#         return result
+#
+#     def forward_2(self, x, encoder_out=None):
+#         b, c, *spatial = x.shape
+#         hid_states = self.norm(x).view(b, c, -1)
+#
+#         qkv = self.qkv(hid_states)
+#         bs, width, length = qkv.shape
+#         assert width % (3 * self.n_heads) == 0
+#         ch = width // (3 * self.n_heads)
+#         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+#
+#         if encoder_out is not None:
+#             encoder_kv = self.encoder_kv(encoder_out)
+#             assert encoder_kv.shape[1] == self.n_heads * ch * 2
+#             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
+#             k = torch.cat([ek, k], dim=-1)
+#             v = torch.cat([ev, v], dim=-1)
+#
+#         scale = 1 / math.sqrt(math.sqrt(ch))
+#         weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+#         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+#
+#         a = torch.einsum("bts,bcs->bct", weight, v)
+#         h = a.reshape(bs, -1, length)
+#
+#         h = self.proj_out(h)
+#         h = h.reshape(b, c, *spatial)
+#
+#         return (x + h) / np.sqrt(2.0)
+
+
+# TODO(Patrick) - this can and should be removed
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
+class NIN(nn.Module):
+    def __init__(self, in_dim, num_units, init_scale=0.1):
+        super().__init__()
+        self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
+        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
diff --git a/src/diffusers/models/attention2d.py b/src/diffusers/models/attention2d.py
deleted file mode 100644
index e7fe805814..0000000000
--- a/src/diffusers/models/attention2d.py
+++ /dev/null
@@ -1,205 +0,0 @@
-import math
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
-# unet_grad_tts.py
-class LinearAttention(torch.nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super(LinearAttention, self).__init__()
-        self.heads = heads
-        self.dim_head = dim_head
-        hidden_dim = dim_head * heads
-        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x)
-        # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
-        q, k, v = (
-            qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
-            .permute(1, 0, 2, 3, 4, 5)
-            .reshape(3, b, self.heads, self.dim_head, -1)
-        )
-        k = k.softmax(dim=-1)
-        context = torch.einsum("bhdn,bhen->bhde", k, v)
-        out = torch.einsum("bhde,bhdn->bhen", context, q)
-        # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
-        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
-        return self.to_out(out)
-
-
-# unet_glide.py & unet_ldm.py
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=-1,
-        use_checkpoint=False,
-        encoder_channels=None,
-        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
-        overwrite_qkv=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels == -1:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-
-        self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels, swish=0.0)
-        self.qkv = conv_nd(1, channels, channels * 3, 1)
-        self.n_heads = self.num_heads
-
-        if encoder_channels is not None:
-            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
-
-        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
-        self.overwrite_qkv = overwrite_qkv
-        if overwrite_qkv:
-            in_channels = channels
-            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-
-        self.is_overwritten = False
-
-    def set_weights(self, module):
-        if self.overwrite_qkv:
-            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0]
-            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
-
-            self.qkv.weight.data = qkv_weight
-            self.qkv.bias.data = qkv_bias
-
-            proj_out = zero_module(conv_nd(1, self.channels, self.channels, 1))
-            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
-            proj_out.bias.data = module.proj_out.bias.data
-
-            self.proj_out = proj_out
-
-    def forward(self, x, encoder_out=None):
-        if self.overwrite_qkv and not self.is_overwritten:
-            self.set_weights(self)
-            self.is_overwritten = True
-
-        b, c, *spatial = x.shape
-        hid_states = self.norm(x).view(b, c, -1)
-
-        qkv = self.qkv(hid_states)
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-
-        if encoder_out is not None:
-            encoder_kv = self.encoder_kv(encoder_out)
-            assert encoder_kv.shape[1] == self.n_heads * ch * 2
-            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = torch.cat([ek, k], dim=-1)
-            v = torch.cat([ev, v], dim=-1)
-
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-
-        a = torch.einsum("bts,bcs->bct", weight, v)
-        h = a.reshape(bs, -1, length)
-
-        h = self.proj_out(h)
-
-        return x + h.reshape(b, c, *spatial)
-
-
-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
- """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -class GroupNorm32(nn.GroupNorm): - def __init__(self, num_groups, num_channels, swish, eps=1e-5, affine=True): - super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine) - self.swish = swish - - def forward(self, x): - y = super().forward(x.float()).to(x.dtype) - if self.swish == 1.0: - y = F.silu(y) - elif self.swish: - y = y * F.sigmoid(y * float(self.swish)) - return y - - -def normalization(channels, swish=0.0, eps=1e-5): - """ - Make a standard normalization layer, with an optional swish activation. - - :param channels: number of input channels. :return: an nn.Module for normalization. - """ - return GroupNorm32(num_channels=channels, num_groups=32, swish=swish, eps=eps, affine=True) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -# unet_score_estimation.py -# class AttnBlockpp(nn.Module): -# """Channel-wise self-attention block. Modified from DDPM.""" -# -# def __init__(self, channels, skip_rescale=False, init_scale=0.0): -# super().__init__() -# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6) -# self.NIN_0 = NIN(channels, channels) -# self.NIN_1 = NIN(channels, channels) -# self.NIN_2 = NIN(channels, channels) -# self.NIN_3 = NIN(channels, channels, init_scale=init_scale) -# self.skip_rescale = skip_rescale -# -# def forward(self, x): -# B, C, H, W = x.shape -# h = self.GroupNorm_0(x) -# q = self.NIN_0(h) -# k = self.NIN_1(h) -# v = self.NIN_2(h) -# -# w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5)) -# w = torch.reshape(w, (B, H, W, H * W)) -# w = F.softmax(w, dim=-1) -# w = torch.reshape(w, (B, H, W, H, W)) -# h = torch.einsum("bhwij,bcij->bchw", w, v) -# h = self.NIN_3(h) -# if not self.skip_rescale: -# return x + h -# else: -# return (x + h) / np.sqrt(2.0) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index e70f39319e..cb4ad0f5e5 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -65,19 +65,3 @@ class GaussianFourierProjection(nn.Module): def forward(self, x): x_proj = x[:, None] * self.W[None, :] * 2 * np.pi return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) - - -# unet_rl.py - TODO(need test) -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index af43389527..13765e1f8b 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -15,24 +15,14 @@ # helpers functions -import copy -import math -from pathlib import Path - import torch from torch import nn -from torch.cuda.amp import GradScaler, autocast -from torch.optim import Adam -from torch.utils import data - -from PIL import Image -from tqdm import tqdm from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .attention import AttentionBlock from 
.embeddings import get_timestep_embedding from .resnet import Downsample, Upsample -from .attention2d import AttentionBlock def nonlinearity(x): @@ -219,11 +209,7 @@ class UNetModel(ModelMixin, ConfigMixin): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: -# self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block]) -# h = self.down[i_level].attn_2[i_block](h) - h = self.down[i_level].attn[i_block](h) -# print("Result", (h - h_2).abs().sum()) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 0ffd20b35f..53763ddaa0 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention2d import AttentionBlock +from .attention import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 4cccc8375c..bc0c1e7a22 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -1,8 +1,8 @@ import torch -from numpy import pad from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .attention import LinearAttention from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample @@ -54,32 +54,6 @@ class ResnetBlock(torch.nn.Module): return output -class LinearAttention(torch.nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super(LinearAttention, self).__init__() - self.heads = heads - self.dim_head = dim_head - hidden_dim = dim_head * heads - self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) - self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) - q, k, v = ( - qkv.reshape(b, 3, self.heads, self.dim_head, h, w) - .permute(1, 0, 2, 3, 4, 5) - .reshape(3, b, self.heads, self.dim_head, -1) - ) - k = k.softmax(dim=-1) - context = torch.einsum("bhdn,bhen->bhde", k, v) - out = torch.einsum("bhde,bhdn->bhen", context, q) - # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) - out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w) - return self.to_out(out) - - class Residual(torch.nn.Module): def __init__(self, fn): super(Residual, self).__init__() diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 5315285afa..0012886a5e 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention2d import AttentionBlock +from .attention import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 3f58a55cbe..17218a7a7e 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ 
b/src/diffusers/models/unet_sde_score_estimation.py @@ -16,6 +16,7 @@ # helpers functions import functools +import math import string import numpy as np @@ -25,6 +26,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .attention import AttentionBlock from .embeddings import GaussianFourierProjection, get_timestep_embedding @@ -414,37 +416,6 @@ class Combine(nn.Module): raise ValueError(f"Method {self.method} not recognized.") -class AttnBlockpp(nn.Module): - """Channel-wise self-attention block. Modified from DDPM.""" - - def __init__(self, channels, skip_rescale=False, init_scale=0.0): - super().__init__() - self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6) - self.NIN_0 = NIN(channels, channels) - self.NIN_1 = NIN(channels, channels) - self.NIN_2 = NIN(channels, channels) - self.NIN_3 = NIN(channels, channels, init_scale=init_scale) - self.skip_rescale = skip_rescale - - def forward(self, x): - B, C, H, W = x.shape - h = self.GroupNorm_0(x) - q = self.NIN_0(h) - k = self.NIN_1(h) - v = self.NIN_2(h) - - w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5)) - w = torch.reshape(w, (B, H, W, H * W)) - w = F.softmax(w, dim=-1) - w = torch.reshape(w, (B, H, W, H, W)) - h = torch.einsum("bhwij,bcij->bchw", w, v) - h = self.NIN_3(h) - if not self.skip_rescale: - return x + h - else: - return (x + h) / np.sqrt(2.0) - - class Upsample(nn.Module): def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): super().__init__() @@ -756,8 +727,7 @@ class NCSNpp(ModelMixin, ConfigMixin): modules[-1].weight.data = default_init()(modules[-1].weight.shape) nn.init.zeros_(modules[-1].bias) - AttnBlock = functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale) - + AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) if progressive == "output_skip": diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index aa4c67a3e2..2e0be1d334 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -859,7 +859,9 @@ class PipelineTesterMixin(unittest.TestCase): image_slice = image[0, -1, -3:, -3:].cpu() assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor([-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105]) + expected_slice = torch.tensor( + [-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105] + ) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 @slow
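
The consolidated block can be smoke-tested in isolation with a few lines like the following. This is only a sketch: the import path follows the new file added above, while the batch size, channel count, and spatial size are arbitrary. Because proj_out is wrapped in zero_module, the residual branch contributes nothing at initialization, so with the default rescale_output_factor=1.0 the block starts out as an identity map.

import torch

from diffusers.models.attention import AttentionBlock, LinearAttention

# Arbitrary sizes, chosen only for the smoke test.
x = torch.randn(1, 32, 16, 16)

# Spatial self-attention: shape-preserving, and an identity at init (proj_out is zeroed).
attn = AttentionBlock(channels=32, num_heads=1, num_groups=32)
out = attn(x)
assert out.shape == x.shape
assert torch.allclose(out, x)

# Linear attention from the grad-tts UNet: also shape-preserving.
lin_attn = LinearAttention(dim=32, heads=4, dim_head=32)
assert lin_attn(x).shape == x.shape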
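
In AttentionBlock.forward, the factor scale = 1 / math.sqrt(math.sqrt(ch)) is applied to both q and k before the einsum, which is the usual 1 / sqrt(ch) attention scaling split across the two operands so that intermediate products stay smaller in fp16, as the inline comment notes. A quick numerical check of that identity, with made-up tensor sizes:

import math

import torch

ch, tokens = 64, 10  # hypothetical head size and sequence length
q = torch.randn(1, ch, tokens)
k = torch.randn(1, ch, tokens)

scale = 1 / math.sqrt(math.sqrt(ch))
# Scaling q and k by ch ** -0.25 each is the same as scaling the product by ch ** -0.5.
split_scaled = torch.einsum("bct,bcs->bts", q * scale, k * scale)
scaled_after = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)
assert torch.allclose(split_scaled, scaled_after, atol=1e-5)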
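
The overwrite_linear branch of set_weights folds the score-SDE NIN layers into the 1x1 Conv1d projections by transposing W and adding a trailing kernel dimension. The sketch below shows why that mapping is weight-for-weight equivalent; it assumes, as in the upstream score-SDE code rather than anything shown in this patch, that NIN acts per position as y[d] = sum_c x[c] * W[c, d] + b[d].

import torch

channels, length = 8, 5  # hypothetical sizes
W = torch.randn(channels, channels)  # NIN weight, shaped (in_dim, num_units)
b = torch.randn(channels)
x = torch.randn(1, channels, length)

# NIN as a per-position linear map over the channel dimension (assumed semantics).
nin_out = torch.einsum("bcn,cd->bdn", x, W) + b[None, :, None]

# The same map as a 1x1 Conv1d, built the way AttentionBlock.set_weights builds qkv/proj_out.
conv = torch.nn.Conv1d(channels, channels, 1)
conv.weight.data = W.T[:, :, None]
conv.bias.data = b
assert torch.allclose(conv(x), nin_out, atol=1e-5)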
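
For the score-estimation UNet, the old AttnBlockpp returned (x + h) / np.sqrt(2.0) when skip_rescale was set; the functools.partial above reproduces that behaviour through rescale_output_factor=math.sqrt(2.0), since AttentionBlock.forward divides the residual sum by that factor. A sketch of what NCSNpp now instantiates; the channel count and input size are arbitrary, and the closed-form check only holds because the ported NIN weights are zero-initialized.

import math

import torch

from diffusers.models.attention import AttentionBlock

x = torch.randn(1, 32, 8, 8)
attn = AttentionBlock(32, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
out = attn(x)
assert out.shape == x.shape
# qkv and proj_out are overwritten from zero-initialized NIN weights, so h == 0 and
# the block currently returns (x + 0) / sqrt(2) at initialization.
assert torch.allclose(out, x / math.sqrt(2.0))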