1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

unify ldm and glide attention

This commit is contained in:
Patrick von Platen
2022-06-28 11:29:16 +00:00
parent fff981df2f
commit 52b3ff5eb9
3 changed files with 80 additions and 246 deletions

View File

@@ -1,3 +1,14 @@
import math
import torch
import torch.nn.functional as F
from torch import nn
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
# unet_grad_tts.py
class LinearAttention(torch.nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
@@ -24,6 +35,7 @@ class LinearAttention(torch.nn.Module):
out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
return self.to_out(out)
# unet.py
class AttnBlock(nn.Module):
def __init__(self, in_channels):
@@ -62,7 +74,8 @@ class AttnBlock(nn.Module):
return x + h_
# unet_glide.py
# unet_glide.py & unet_ldm.py
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
@@ -78,6 +91,7 @@ class AttentionBlock(nn.Module):
num_head_channels=-1,
use_checkpoint=False,
encoder_channels=None,
use_new_attention_order=False, # TODO(Patrick) -> is never used, maybe delete?
):
super().__init__()
self.channels = channels
@@ -108,6 +122,7 @@ class AttentionBlock(nn.Module):
h = self.proj_out(h)
return x + h.reshape(b, c, *spatial)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
@@ -140,106 +155,78 @@ class QKVAttention(nn.Module):
return a.reshape(bs, -1, length)
# unet_ldm.py
class AttentionBlock(nn.Module):
def conv_nd(dims, *args, **kwargs):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
self.swish = swish
def forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
y = super().forward(x.float()).to(x.dtype)
if self.swish == 1.0:
y = F.silu(y)
elif self.swish:
y = y * F.sigmoid(y * float(self.swish))
return y
class QKVAttention(nn.Module):
def normalization(channels, swish=0.0):
"""
A module which performs QKV attention and splits in a different order.
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels. :return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
# unet_score_estimation.py
class AttnBlockpp(nn.Module):
"""Channel-wise self-attention block. Modified from DDPM."""
def __init__(self, channels, skip_rescale=False, init_scale=0.0):
super().__init__()
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
self.skip_rescale = skip_rescale
def forward(self, x):
B, C, H, W = x.shape
h = self.GroupNorm_0(x)
q = self.NIN_0(h)
k = self.NIN_1(h)
v = self.NIN_2(h)
w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
w = torch.reshape(w, (B, H, W, H * W))
w = F.softmax(w, dim=-1)
w = torch.reshape(w, (B, H, W, H, W))
h = torch.einsum("bhwij,bcij->bchw", w, v)
h = self.NIN_3(h)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
# class AttnBlockpp(nn.Module):
# """Channel-wise self-attention block. Modified from DDPM."""
#
# def __init__(self, channels, skip_rescale=False, init_scale=0.0):
# super().__init__()
# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
# self.NIN_0 = NIN(channels, channels)
# self.NIN_1 = NIN(channels, channels)
# self.NIN_2 = NIN(channels, channels)
# self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
# self.skip_rescale = skip_rescale
#
# def forward(self, x):
# B, C, H, W = x.shape
# h = self.GroupNorm_0(x)
# q = self.NIN_0(h)
# k = self.NIN_1(h)
# v = self.NIN_2(h)
#
# w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
# w = torch.reshape(w, (B, H, W, H * W))
# w = F.softmax(w, dim=-1)
# w = torch.reshape(w, (B, H, W, H, W))
# h = torch.einsum("bhwij,bcij->bchw", w, v)
# h = self.NIN_3(h)
# if not self.skip_rescale:
# return x + h
# else:
# return (x + h) / np.sqrt(2.0)

View File

@@ -1,4 +1,3 @@
import math
from abc import abstractmethod
import torch
@@ -7,6 +6,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention2d import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
@@ -226,84 +226,6 @@ class ResBlock(TimestepBlock):
return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
encoder_channels=None,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels, swish=0.0)
self.qkv = conv_nd(1, channels, channels * 3, 1)
self.attention = QKVAttention(self.num_heads)
if encoder_channels is not None:
self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x, encoder_out=None):
b, c, *spatial = x.shape
qkv = self.qkv(self.norm(x).view(b, c, -1))
if encoder_out is not None:
encoder_out = self.encoder_kv(encoder_out)
h = self.attention(qkv, encoder_out)
else:
h = self.attention(qkv)
h = self.proj_out(h)
return x + h.reshape(b, c, *spatial)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, encoder_kv=None):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after
attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_kv is not None:
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
class GlideUNetModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding.

View File

@@ -9,6 +9,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention2d import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
@@ -172,8 +173,6 @@ class CrossAttention(nn.Module):
k = self.to_k(context)
v = self.to_v(context)
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
q = self.reshape_heads_to_batch_dim(q)
k = self.reshape_heads_to_batch_dim(k)
v = self.reshape_heads_to_batch_dim(v)
@@ -181,12 +180,9 @@ class CrossAttention(nn.Module):
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
# mask = rearrange(mask, "b ... -> b (...)")
maks = mask.reshape(batch_size, -1)
mask = mask.reshape(batch_size, -1)
max_neg_value = -torch.finfo(sim.dtype).max
# mask = repeat(mask, "b j -> (b h) () j", h=h)
mask = mask[:, None, :].repeat(h, 1, 1)
# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
@@ -194,7 +190,6 @@ class CrossAttention(nn.Module):
out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = self.reshape_batch_dim_to_heads(out)
# out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
@@ -487,47 +482,6 @@ class ResBlock(TimestepBlock):
return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
@@ -577,35 +531,6 @@ def count_flops_attn(model, _x, y):
model.total_ops += torch.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class UNetLDMModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param