From 52b3ff5eb91eb614e1c976b877c7f1e4e92aec51 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Jun 2022 11:29:16 +0000 Subject: [PATCH] unify ldm and glide attention --- src/diffusers/models/attention2d.py | 167 +++++++++++++--------------- src/diffusers/models/unet_glide.py | 80 +------------ src/diffusers/models/unet_ldm.py | 79 +------------ 3 files changed, 80 insertions(+), 246 deletions(-) diff --git a/src/diffusers/models/attention2d.py b/src/diffusers/models/attention2d.py index 2b19ccc49c..4a4f702d80 100644 --- a/src/diffusers/models/attention2d.py +++ b/src/diffusers/models/attention2d.py @@ -1,3 +1,14 @@ +import math + +import torch +import torch.nn.functional as F +from torch import nn + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + # unet_grad_tts.py class LinearAttention(torch.nn.Module): def __init__(self, dim, heads=4, dim_head=32): @@ -24,6 +35,7 @@ class LinearAttention(torch.nn.Module): out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w) return self.to_out(out) + # unet.py class AttnBlock(nn.Module): def __init__(self, in_channels): @@ -62,7 +74,8 @@ class AttnBlock(nn.Module): return x + h_ -# unet_glide.py + +# unet_glide.py & unet_ldm.py class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. @@ -78,6 +91,7 @@ class AttentionBlock(nn.Module): num_head_channels=-1, use_checkpoint=False, encoder_channels=None, + use_new_attention_order=False, # TODO(Patrick) -> is never used, maybe delete? ): super().__init__() self.channels = channels @@ -108,6 +122,7 @@ class AttentionBlock(nn.Module): h = self.proj_out(h) return x + h.reshape(b, c, *spatial) + class QKVAttention(nn.Module): """ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping @@ -140,106 +155,78 @@ class QKVAttention(nn.Module): return a.reshape(bs, -1, length) -# unet_ldm.py -class AttentionBlock(nn.Module): +def conv_nd(dims, *args, **kwargs): """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Create a 1D, 2D, or 3D convolution module. 
""" + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) - # split heads before split qkv - self.attention = QKVAttentionLegacy(self.num_heads) - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) +class GroupNorm32(nn.GroupNorm): + def __init__(self, num_groups, num_channels, swish, eps=1e-5): + super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) + self.swish = swish def forward(self, x): - b, c, *spatial = x.shape - x = x.reshape(b, c, -1) - qkv = self.qkv(self.norm(x)) - h = self.attention(qkv) - h = self.proj_out(h) - return (x + h).reshape(b, c, *spatial) + y = super().forward(x.float()).to(x.dtype) + if self.swish == 1.0: + y = F.silu(y) + elif self.swish: + y = y * F.sigmoid(y * float(self.swish)) + return y -class QKVAttention(nn.Module): + +def normalization(channels, swish=0.0): """ - A module which performs QKV attention and splits in a different order. + Make a standard normalization layer, with an optional swish activation. + + :param channels: number of input channels. :return: an nn.Module for normalization. """ + return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - def forward(self, qkv): - """ - Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x - T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.chunk(3, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = torch.einsum( - "bct,bcs->bts", - (q * scale).view(bs * self.n_heads, ch, length), - (k * scale).view(bs * self.n_heads, ch, length), - ) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) - return a.reshape(bs, -1, length) +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) # unet_score_estimation.py -class AttnBlockpp(nn.Module): - """Channel-wise self-attention block. 
Modified from DDPM.""" - - def __init__(self, channels, skip_rescale=False, init_scale=0.0): - super().__init__() - self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6) - self.NIN_0 = NIN(channels, channels) - self.NIN_1 = NIN(channels, channels) - self.NIN_2 = NIN(channels, channels) - self.NIN_3 = NIN(channels, channels, init_scale=init_scale) - self.skip_rescale = skip_rescale - - def forward(self, x): - B, C, H, W = x.shape - h = self.GroupNorm_0(x) - q = self.NIN_0(h) - k = self.NIN_1(h) - v = self.NIN_2(h) - - w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5)) - w = torch.reshape(w, (B, H, W, H * W)) - w = F.softmax(w, dim=-1) - w = torch.reshape(w, (B, H, W, H, W)) - h = torch.einsum("bhwij,bcij->bchw", w, v) - h = self.NIN_3(h) - if not self.skip_rescale: - return x + h - else: - return (x + h) / np.sqrt(2.0) +# class AttnBlockpp(nn.Module): +# """Channel-wise self-attention block. Modified from DDPM.""" +# +# def __init__(self, channels, skip_rescale=False, init_scale=0.0): +# super().__init__() +# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6) +# self.NIN_0 = NIN(channels, channels) +# self.NIN_1 = NIN(channels, channels) +# self.NIN_2 = NIN(channels, channels) +# self.NIN_3 = NIN(channels, channels, init_scale=init_scale) +# self.skip_rescale = skip_rescale +# +# def forward(self, x): +# B, C, H, W = x.shape +# h = self.GroupNorm_0(x) +# q = self.NIN_0(h) +# k = self.NIN_1(h) +# v = self.NIN_2(h) +# +# w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5)) +# w = torch.reshape(w, (B, H, W, H * W)) +# w = F.softmax(w, dim=-1) +# w = torch.reshape(w, (B, H, W, H, W)) +# h = torch.einsum("bhwij,bcij->bchw", w, v) +# h = self.NIN_3(h) +# if not self.skip_rescale: +# return x + h +# else: +# return (x + h) / np.sqrt(2.0) diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 6fe27959ff..0ffd20b35f 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -1,4 +1,3 @@ -import math from abc import abstractmethod import torch @@ -7,6 +6,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .attention2d import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample @@ -226,84 +226,6 @@ class ResBlock(TimestepBlock): return self.skip_connection(x) + h -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. - - Originally ported from here, but adapted to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
- """ - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - encoder_channels=None, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels, swish=0.0) - self.qkv = conv_nd(1, channels, channels * 3, 1) - self.attention = QKVAttention(self.num_heads) - - if encoder_channels is not None: - self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - - def forward(self, x, encoder_out=None): - b, c, *spatial = x.shape - qkv = self.qkv(self.norm(x).view(b, c, -1)) - if encoder_out is not None: - encoder_out = self.encoder_kv(encoder_out) - h = self.attention(qkv, encoder_out) - else: - h = self.attention(qkv) - h = self.proj_out(h) - return x + h.reshape(b, c, *spatial) - - -class QKVAttention(nn.Module): - """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv, encoder_kv=None): - """ - Apply QKV attention. - - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after - attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) - if encoder_kv is not None: - assert encoder_kv.shape[1] == self.n_heads * ch * 2 - ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) - k = torch.cat([ek, k], dim=-1) - v = torch.cat([ev, v], dim=-1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - a = torch.einsum("bts,bcs->bct", weight, v) - return a.reshape(bs, -1, length) - - class GlideUNetModel(ModelMixin, ConfigMixin): """ The full UNet model with attention and timestep embedding. diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 24aec1cf56..5315285afa 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .attention2d import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample @@ -172,8 +173,6 @@ class CrossAttention(nn.Module): k = self.to_k(context) v = self.to_v(context) - # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) - q = self.reshape_heads_to_batch_dim(q) k = self.reshape_heads_to_batch_dim(k) v = self.reshape_heads_to_batch_dim(v) @@ -181,12 +180,9 @@ class CrossAttention(nn.Module): sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale if exists(mask): - # mask = rearrange(mask, "b ... 
-> b (...)") - maks = mask.reshape(batch_size, -1) + mask = mask.reshape(batch_size, -1) max_neg_value = -torch.finfo(sim.dtype).max - # mask = repeat(mask, "b j -> (b h) () j", h=h) mask = mask[:, None, :].repeat(h, 1, 1) - # x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of @@ -194,7 +190,6 @@ class CrossAttention(nn.Module): out = torch.einsum("b i j, b j d -> b i d", attn, v) out = self.reshape_batch_dim_to_heads(out) - # out = rearrange(out, "(b h) n d -> b n (h d)", h=h) return self.to_out(out) @@ -487,47 +482,6 @@ class ResBlock(TimestepBlock): return self.skip_connection(x) + h -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - """ - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) - # split heads before split qkv - self.attention = QKVAttentionLegacy(self.num_heads) - - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - - def forward(self, x): - b, c, *spatial = x.shape - x = x.reshape(b, c, -1) - qkv = self.qkv(self.norm(x)) - h = self.attention(qkv) - h = self.proj_out(h) - return (x + h).reshape(b, c, *spatial) - - class QKVAttention(nn.Module): """ A module which performs QKV attention and splits in a different order. @@ -577,35 +531,6 @@ def count_flops_attn(model, _x, y): model.total_ops += torch.DoubleTensor([matmul_ops]) -class QKVAttentionLegacy(nn.Module): - """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x - T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - a = torch.einsum("bts,bcs->bct", weight, v) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - class UNetLDMModel(ModelMixin, ConfigMixin): """ The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param