mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
all attentions collected
This commit is contained in:
@@ -0,0 +1,245 @@
|
||||
# unet_grad_tts.py
|
||||
class LinearAttention(torch.nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super(LinearAttention, self).__init__()
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
||||
q, k, v = (
|
||||
qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
|
||||
.permute(1, 0, 2, 3, 4, 5)
|
||||
.reshape(3, b, self.heads, self.dim_head, -1)
|
||||
)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
||||
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
||||
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
|
||||
out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
|
||||
return self.to_out(out)
|
||||
|
||||
# unet.py
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
# unet_glide.py
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
|
||||
Originally ported from here, but adapted to the N-d case.
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
use_checkpoint=False,
|
||||
encoder_channels=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
if num_head_channels == -1:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.norm = normalization(channels, swish=0.0)
|
||||
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
|
||||
if encoder_channels is not None:
|
||||
self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
|
||||
def forward(self, x, encoder_out=None):
|
||||
b, c, *spatial = x.shape
|
||||
qkv = self.qkv(self.norm(x).view(b, c, -1))
|
||||
if encoder_out is not None:
|
||||
encoder_out = self.encoder_kv(encoder_out)
|
||||
h = self.attention(qkv, encoder_out)
|
||||
else:
|
||||
h = self.attention(qkv)
|
||||
h = self.proj_out(h)
|
||||
return x + h.reshape(b, c, *spatial)
|
||||
|
||||
class QKVAttention(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv, encoder_kv=None):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after
|
||||
attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
if encoder_kv is not None:
|
||||
assert encoder_kv.shape[1] == self.n_heads * ch * 2
|
||||
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
|
||||
k = torch.cat([ek, k], dim=-1)
|
||||
v = torch.cat([ev, v], dim=-1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = torch.einsum("bts,bcs->bct", weight, v)
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
|
||||
# unet_ldm.py
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
|
||||
to the N-d case.
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
use_checkpoint=False,
|
||||
use_new_attention_order=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
if num_head_channels == -1:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.norm = normalization(channels)
|
||||
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
||||
# split heads before split qkv
|
||||
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
b, c, *spatial = x.shape
|
||||
x = x.reshape(b, c, -1)
|
||||
qkv = self.qkv(self.norm(x))
|
||||
h = self.attention(qkv)
|
||||
h = self.proj_out(h)
|
||||
return (x + h).reshape(b, c, *spatial)
|
||||
|
||||
class QKVAttention(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention and splits in a different order.
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
|
||||
T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = torch.einsum(
|
||||
"bct,bcs->bts",
|
||||
(q * scale).view(bs * self.n_heads, ch, length),
|
||||
(k * scale).view(bs * self.n_heads, ch, length),
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
# unet_score_estimation.py
|
||||
class AttnBlockpp(nn.Module):
|
||||
"""Channel-wise self-attention block. Modified from DDPM."""
|
||||
|
||||
def __init__(self, channels, skip_rescale=False, init_scale=0.0):
|
||||
super().__init__()
|
||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
|
||||
self.NIN_0 = NIN(channels, channels)
|
||||
self.NIN_1 = NIN(channels, channels)
|
||||
self.NIN_2 = NIN(channels, channels)
|
||||
self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
|
||||
self.skip_rescale = skip_rescale
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
h = self.GroupNorm_0(x)
|
||||
q = self.NIN_0(h)
|
||||
k = self.NIN_1(h)
|
||||
v = self.NIN_2(h)
|
||||
|
||||
w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
|
||||
w = torch.reshape(w, (B, H, W, H * W))
|
||||
w = F.softmax(w, dim=-1)
|
||||
w = torch.reshape(w, (B, H, W, H, W))
|
||||
h = torch.einsum("bhwij,bcij->bchw", w, v)
|
||||
h = self.NIN_3(h)
|
||||
if not self.skip_rescale:
|
||||
return x + h
|
||||
else:
|
||||
return (x + h) / np.sqrt(2.0)
|
||||
|
||||
Reference in New Issue
Block a user