Mirror of https://github.com/huggingface/diffusers.git
Commit: unify ldm and glide attention
@@ -1,3 +1,14 @@
import math

import torch
import torch.nn.functional as F
from torch import nn


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


# unet_grad_tts.py
class LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
@@ -24,6 +35,7 @@ class LinearAttention(torch.nn.Module):
        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
        return self.to_out(out)


# unet.py
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
@@ -62,7 +74,8 @@ class AttnBlock(nn.Module):

        return x + h_


# unet_glide.py
# unet_glide.py & unet_ldm.py
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
@@ -78,6 +91,7 @@ class AttentionBlock(nn.Module):
        num_head_channels=-1,
        use_checkpoint=False,
        encoder_channels=None,
        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
    ):
        super().__init__()
        self.channels = channels
@@ -108,6 +122,7 @@ class AttentionBlock(nn.Module):
        h = self.proj_out(h)
        return x + h.reshape(b, c, *spatial)
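The hunks above turn the former GLIDE/LDM copies into a single AttentionBlock with an optional encoder_channels argument for GLIDE-style conditioning. The following is a minimal usage sketch, not part of the commit: it assumes the unified class is in scope and that its forward signature matches the GLIDE variant shown later in this diff (forward(x, encoder_out=None)).

# Hypothetical usage sketch; class and forward signature assumed from the hunks in this diff.
import torch

block = AttentionBlock(channels=64, num_heads=4, encoder_channels=512)

x = torch.randn(2, 64, 16, 16)            # [N, C, H, W] feature map
tokens = torch.randn(2, 512, 77)          # [N, encoder_channels, T] conditioning sequence

y_self = block(x)                         # plain spatial self-attention
y_cross = block(x, encoder_out=tokens)    # additionally attends to the conditioning tokens
assert y_self.shape == x.shape and y_cross.shape == x.shape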


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
@@ -140,106 +155,78 @@ class QKVAttention(nn.Module):
        return a.reshape(bs, -1, length)


# unet_ldm.py
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
        T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class GroupNorm32(nn.GroupNorm):
    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
        self.swish = swish

    def forward(self, x):
        y = super().forward(x.float()).to(x.dtype)
        if self.swish == 1.0:
            y = F.silu(y)
        elif self.swish:
            y = y * F.sigmoid(y * float(self.swish))
        return y


def normalization(channels, swish=0.0):
    """
    Make a standard normalization layer, with an optional swish activation.

    :param channels: number of input channels. :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
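The helpers above are what the unified block is built from: normalization wraps GroupNorm32 so a swish activation can be folded in, conv_nd(1, ...) acts as a per-position linear layer over the flattened spatial axis, and zero_module zero-initializes the output projection so the residual branch starts as an identity. A small standalone sketch of how they compose (illustrative only, not from the diff; channel counts are made up and torch is assumed imported as at the top of the file):

# Illustrative composition of the helpers defined above.
import torch

channels = 64
norm = normalization(channels, swish=1.0)                   # GroupNorm32 + SiLU
qkv_proj = conv_nd(1, channels, channels * 3, 1)            # 1x1 conv == per-position linear
out_proj = zero_module(conv_nd(1, channels, channels, 1))   # starts as all zeros

x = torch.randn(2, channels, 16 * 16)                       # [N, C, H*W]
qkv = qkv_proj(norm(x))                                     # [N, 3*C, H*W]
assert torch.allclose(out_proj(qkv[:, :channels]), torch.zeros_like(x))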
# unet_score_estimation.py
class AttnBlockpp(nn.Module):
    """Channel-wise self-attention block. Modified from DDPM."""

    def __init__(self, channels, skip_rescale=False, init_scale=0.0):
        super().__init__()
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
        self.NIN_0 = NIN(channels, channels)
        self.NIN_1 = NIN(channels, channels)
        self.NIN_2 = NIN(channels, channels)
        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
        self.skip_rescale = skip_rescale

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.GroupNorm_0(x)
        q = self.NIN_0(h)
        k = self.NIN_1(h)
        v = self.NIN_2(h)

        w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
        w = torch.reshape(w, (B, H, W, H * W))
        w = F.softmax(w, dim=-1)
        w = torch.reshape(w, (B, H, W, H, W))
        h = torch.einsum("bhwij,bcij->bchw", w, v)
        h = self.NIN_3(h)
        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)


# class AttnBlockpp(nn.Module):
#     """Channel-wise self-attention block. Modified from DDPM."""
#
#     def __init__(self, channels, skip_rescale=False, init_scale=0.0):
#         super().__init__()
#         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
#         self.NIN_0 = NIN(channels, channels)
#         self.NIN_1 = NIN(channels, channels)
#         self.NIN_2 = NIN(channels, channels)
#         self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
#         self.skip_rescale = skip_rescale
#
#     def forward(self, x):
#         B, C, H, W = x.shape
#         h = self.GroupNorm_0(x)
#         q = self.NIN_0(h)
#         k = self.NIN_1(h)
#         v = self.NIN_2(h)
#
#         w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
#         w = torch.reshape(w, (B, H, W, H * W))
#         w = F.softmax(w, dim=-1)
#         w = torch.reshape(w, (B, H, W, H, W))
#         h = torch.einsum("bhwij,bcij->bchw", w, v)
#         h = self.NIN_3(h)
#         if not self.skip_rescale:
#             return x + h
#         else:
#             return (x + h) / np.sqrt(2.0)
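To make the einsum bookkeeping in AttnBlockpp concrete, here is a standalone shape walkthrough (illustrative only; the NIN 1x1 projections are skipped and random tensors stand in for q, k, v):

# Standalone sketch of the spatial attention math used in AttnBlockpp above.
import torch
import torch.nn.functional as F

B, C, H, W = 2, 8, 4, 4
q = torch.randn(B, C, H, W)
k = torch.randn(B, C, H, W)
v = torch.randn(B, C, H, W)

w = torch.einsum("bchw,bcij->bhwij", q, k) * (C ** -0.5)            # [B, H, W, H, W]: one score per pair of positions
w = F.softmax(w.reshape(B, H, W, H * W), dim=-1).reshape(B, H, W, H, W)
out = torch.einsum("bhwij,bcij->bchw", w, v)                        # back to [B, C, H, W]
assert out.shape == (B, C, H, W)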
@@ -1,4 +1,3 @@
import math
from abc import abstractmethod

import torch
@@ -7,6 +6,7 @@ import torch.nn.functional as F

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention2d import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
@@ -226,84 +226,6 @@ class ResBlock(TimestepBlock):
        return self.skip_connection(x) + h


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        encoder_channels=None,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels, swish=0.0)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention(self.num_heads)

        if encoder_channels is not None:
            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x, encoder_out=None):
        b, c, *spatial = x.shape
        qkv = self.qkv(self.norm(x).view(b, c, -1))
        if encoder_out is not None:
            encoder_out = self.encoder_kv(encoder_out)
            h = self.attention(qkv, encoder_out)
        else:
            h = self.attention(qkv)
        h = self.proj_out(h)
        return x + h.reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, encoder_kv=None):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after
        attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        if encoder_kv is not None:
            assert encoder_kv.shape[1] == self.n_heads * ch * 2
            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
            k = torch.cat([ek, k], dim=-1)
            v = torch.cat([ev, v], dim=-1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)
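The per-head splitting and the token-axis concatenation in this forward are easiest to see with concrete shapes. A standalone sketch (illustrative values, assuming the QKVAttention class removed above is in scope):

# N=2 samples, H=4 heads, C=8 channels per head, T=16 spatial tokens, S=7 encoder tokens.
import torch

attn = QKVAttention(n_heads=4)
qkv = torch.randn(2, 3 * 4 * 8, 16)         # [N, 3*H*C, T]
encoder_kv = torch.randn(2, 2 * 4 * 8, 7)   # [N, 2*H*C, S]

out = attn(qkv, encoder_kv)                 # keys/values become length T + S = 23 internally
assert out.shape == (2, 4 * 8, 16)          # [N, H*C, T]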


class GlideUNetModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding.
@@ -9,6 +9,7 @@ import torch.nn.functional as F

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention2d import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
@@ -172,8 +173,6 @@ class CrossAttention(nn.Module):
        k = self.to_k(context)
        v = self.to_v(context)

        # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        q = self.reshape_heads_to_batch_dim(q)
        k = self.reshape_heads_to_batch_dim(k)
        v = self.reshape_heads_to_batch_dim(v)
@@ -181,12 +180,9 @@ class CrossAttention(nn.Module):
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            # mask = rearrange(mask, "b ... -> b (...)")
            maks = mask.reshape(batch_size, -1)
            mask = mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            # mask = repeat(mask, "b j -> (b h) () j", h=h)
            mask = mask[:, None, :].repeat(h, 1, 1)
            # x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
@@ -194,7 +190,6 @@ class CrossAttention(nn.Module):

        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        out = self.reshape_batch_dim_to_heads(out)
        # out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
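The commented-out rearrange calls dropped in this hunk document what the reshape-based helpers replace. For reference, a pure-PyTorch equivalent of the pattern "b n (h d) -> (b h) n d" looks like this (a hypothetical helper, not part of the diff; reshape_heads_to_batch_dim in the model is assumed to do something along these lines):

def heads_to_batch_dim(t, h):
    # "b n (h d) -> (b h) n d"
    b, n, hd = t.shape
    d = hd // h
    return t.reshape(b, n, h, d).permute(0, 2, 1, 3).reshape(b * h, n, d)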
@@ -487,47 +482,6 @@ class ResBlock(TimestepBlock):
        return self.skip_connection(x) + h


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
@@ -577,35 +531,6 @@ def count_flops_attn(model, _x, y):
    model.total_ops += torch.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
        T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
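The "splits in a different order" wording in these docstrings refers to the channel layout each module expects from the 1x1 qkv projection. QKVAttentionLegacy treats the channel axis as [heads, 3, ch] (each head's q, k, v stored contiguously), while the non-legacy QKVAttention treats it as [3, heads, ch] (all queries first, then all keys, then all values). A small sketch of the two interpretations (illustrative only):

# Same 18-channel tensor read under the two conventions: n_heads=2, ch=3, length=5.
import torch

qkv = torch.randn(1, 2 * 3 * 3, 5)

# Legacy layout [heads, 3, ch]: fold heads into the batch first, then split q/k/v.
q_legacy, k_legacy, v_legacy = qkv.reshape(1 * 2, 3 * 3, 5).split(3, dim=1)

# Non-legacy layout [3, heads, ch]: split q/k/v first, then fold heads into the batch.
q_new, k_new, v_new = [t.reshape(1 * 2, 3, 5) for t in qkv.chunk(3, dim=1)]

# Both end up shaped [bs*heads, ch, length], but they read different channel slices as "q".
assert q_legacy.shape == q_new.shape == (2, 3, 5)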


class UNetLDMModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param