1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

attn refactoring

This commit is contained in:
Nouamane Tazi
2022-09-14 12:36:38 +00:00
parent d30f968222
commit 0c70c0e189

View File

@@ -1,4 +1,5 @@
import math
from typing import Optional
import torch
import torch.nn.functional as F
@@ -10,16 +11,24 @@ class AttentionBlock(nn.Module):
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention
Uses three q, k, v linear layers to compute attention.
Parameters:
channels (:obj:`int`): The number of channels in the input and output.
num_head_channels (:obj:`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""
def __init__(
self,
channels,
num_head_channels=None,
num_groups=32,
rescale_output_factor=1.0,
eps=1e-5,
channels: int,
num_head_channels: Optional[int] = None,
num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
super().__init__()
self.channels = channels
@@ -86,10 +95,26 @@ class AttentionBlock(nn.Module):
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image
standard transformer action. Finally, reshape to image.
Parameters:
in_channels (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
"""
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
def __init__(
self,
in_channels: int,
n_heads: int,
d_head: int,
depth: int = 1,
dropout: float = 0.0,
context_dim: Optional[int] = None,
):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
@@ -112,22 +137,44 @@ class SpatialTransformer(nn.Module):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)
def forward(self, x, context=None):
def forward(self, hidden_states, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
for block in self.transformer_blocks:
x = block(x, context=context)
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
x = self.proj_out(x)
return x + x_in
hidden_states = block(hidden_states, context=context)
hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
return hidden_states + residual
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
r"""
A basic Transformer block.
Parameters:
dim (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
"""
def __init__(
self,
dim: int,
n_heads: int,
d_head: int,
dropout=0.0,
context_dim: Optional[int] = None,
gated_ff: bool = True,
checkpoint: bool = True,
):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
@@ -145,15 +192,30 @@ class BasicTransformerBlock(nn.Module):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size
def forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
def forward(self, hidden_states, context=None):
hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
r"""
A cross attention layer.
Parameters:
query_dim (:obj:`int`): The number of channels in the query.
context_dim (:obj:`int`, *optional*):
The number of channels in the context. If not given, defaults to `query_dim`.
heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""
def __init__(
self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
):
super().__init__()
inner_dim = dim_head * heads
context_dim = context_dim if context_dim is not None else query_dim
@@ -174,52 +236,58 @@ class CrossAttention(nn.Module):
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
tensor2 = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor3
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
tensor2 = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor3
def forward(self, x, context=None, mask=None):
batch_size, sequence_length, dim = x.shape
def forward(self, hidden_states, context=None, mask=None):
batch_size, sequence_length, dim = hidden_states.shape
q = self.to_q(x)
context = context if context is not None else x
k = self.to_k(context)
v = self.to_v(context)
query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
key = self.to_k(context)
value = self.to_v(context)
q = self.reshape_heads_to_batch_dim(q)
k = self.reshape_heads_to_batch_dim(k)
v = self.reshape_heads_to_batch_dim(v)
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
# attention, what we cannot get enough of
hidden_states = self._attention(q, k, v, sequence_length, dim)
hidden_states = self._attention(query, key, value, sequence_length, dim)
return self.to_out(hidden_states)
def _attention(self, query, key, value, sequence_length, dim):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = (
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
)
attn_slice = attn_slice.softmax(dim=-1)
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
# hidden_states = torch.zeros(
# (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
# )
slice_size = self._slice_size if self._slice_size is not None else batch_size_attention
# for i in range(hidden_states.shape[0] // slice_size):
# start_idx = i * slice_size
# end_idx = (i + 1) * slice_size
# qslice = query[start_idx:end_idx]
qslice = query
# kslice = key[start_idx:end_idx].transpose(1, 2)
kslice = key.transpose(1, 2)
attn_slice = torch.matmul(qslice, kslice) * self.scale
attn_slice = attn_slice.softmax(dim=-1)
# vslice = value[start_idx:end_idx]
vslice = value
hidden_states = torch.matmul(attn_slice, vslice)
# hidden_states = torch.cat(attn_slices, dim=0)
hidden_states[start_idx:end_idx] = attn_slice
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
@@ -227,7 +295,20 @@ class CrossAttention(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
r"""
A feed-forward layer.
Parameters:
dim (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""
def __init__(
self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
@@ -235,16 +316,24 @@ class FeedForward(nn.Module):
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
def forward(self, hidden_states):
return self.net(hidden_states)
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`): The number of channels in the output.
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * F.gelu(gate)