
add only cross attention to simple attention blocks (#3011)

* add only cross attention to simple attention blocks

* add test for only_cross_attention re: @patrickvonplaten

* mid_block_only_cross_attention better default

allow mid_block_only_cross_attention to default to
`only_cross_attention` when `only_cross_attention` is given
as a single boolean
Author: Will Berman (2023-04-11 14:38:50 -07:00)
Committer: Daniel Gu
parent d8eedb4787
commit 0c8fd45894
5 changed files with 148 additions and 17 deletions
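
For context, a minimal usage sketch of the new default (illustrative block sizes, not a tested configuration; the block type names are the real ones wired up in this commit):

    from diffusers import UNet2DConditionModel

    # `only_cross_attention` is given as a single bool and
    # `mid_block_only_cross_attention` is left as None, so the mid block
    # inherits True (see the constructor diff below).
    unet = UNet2DConditionModel(
        block_out_channels=(32, 64),
        down_block_types=("SimpleCrossAttnDownBlock2D", "SimpleCrossAttnDownBlock2D"),
        mid_block_type="UNetMidBlock2DSimpleCrossAttn",
        up_block_types=("SimpleCrossAttnUpBlock2D", "SimpleCrossAttnUpBlock2D"),
        cross_attention_dim=32,
        only_cross_attention=True,
    )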


@@ -61,6 +61,7 @@ class Attention(nn.Module):
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
         scale_qk: bool = True,
+        only_cross_attention: bool = False,
         processor: Optional["AttnProcessor"] = None,
     ):
         super().__init__()
@@ -79,6 +80,12 @@ class Attention(nn.Module):
         self.sliceable_head_dim = heads
 
         self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )
 
         if norm_num_groups is not None:
             self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
@@ -89,8 +96,14 @@ class Attention(nn.Module):
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
 
         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
-        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
-        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None
 
         if self.added_kv_proj_dim is not None:
             self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
@@ -408,18 +421,21 @@ class AttnAddedKVProcessor:
         query = attn.to_q(hidden_states)
         query = attn.head_to_batch_dim(query)
 
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
 
         attention_probs = attn.get_attention_scores(query, key, attention_mask)
         hidden_states = torch.bmm(attention_probs, value)
@@ -637,18 +653,22 @@ class SlicedAttnAddedKVProcessor:
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)
 
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
 
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
 
         batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
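
Both processor hunks implement the same branch. As a standalone sketch of the key/value assembly (illustrative shapes only; the real processors also split attention heads with `head_to_batch_dim`, omitted here):

    import torch
    import torch.nn as nn

    batch, image_tokens, text_tokens = 2, 12, 4
    inner_dim, added_kv_proj_dim = 8, 6

    to_k = nn.Linear(inner_dim, inner_dim)  # set to None when only_cross_attention=True
    to_v = nn.Linear(inner_dim, inner_dim)
    add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
    add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)

    hidden_states = torch.rand(batch, image_tokens, inner_dim)  # flattened image features
    encoder_hidden_states = torch.rand(batch, text_tokens, added_kv_proj_dim)

    enc_k = add_k_proj(encoder_hidden_states)
    enc_v = add_v_proj(encoder_hidden_states)

    only_cross_attention = False
    if not only_cross_attention:
        # self-attention K/V from the image features, concatenated after the text K/V
        key = torch.cat([enc_k, to_k(hidden_states)], dim=1)    # (2, 4 + 12, 8)
        value = torch.cat([enc_v, to_v(hidden_states)], dim=1)
    else:
        # queries attend only to the projected encoder tokens
        key, value = enc_k, enc_v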


@@ -125,6 +125,7 @@ def get_down_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
+            only_cross_attention=only_cross_attention,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -291,6 +292,7 @@ def get_up_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
+            only_cross_attention=only_cross_attention,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -575,6 +577,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         skip_time_act=False,
+        only_cross_attention=False,
     ):
         super().__init__()
 
@@ -614,6 +617,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1356,6 +1360,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         output_scale_factor=1.0,
         add_downsample=True,
         skip_time_act=False,
+        only_cross_attention=False,
     ):
         super().__init__()
 
@@ -1394,6 +1399,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -2354,6 +2360,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
         output_scale_factor=1.0,
         add_upsample=True,
         skip_time_act=False,
+        only_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -2393,6 +2400,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                     processor=AttnAddedKVProcessor(),
                 )
             )


@@ -110,7 +110,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
             using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
-            embeddings with the class embeddings.
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
+            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
+            default to `False`.
     """
 
     _supports_gradient_checkpointing = True
@@ -158,6 +163,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
     ):
         super().__init__()
 
@@ -265,8 +271,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         self.up_blocks = nn.ModuleList([])
 
         if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
             only_cross_attention = [only_cross_attention] * len(down_block_types)
 
+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)
 
@@ -342,6 +354,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
+                only_cross_attention=mid_block_only_cross_attention,
             )
         elif mid_block_type is None:
             self.mid_block = None
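
The defaulting rule added to the two conditionals above, restated as a standalone helper (a sketch mirroring the constructor logic, not library code):

    from typing import List, Optional, Tuple, Union

    def resolve_only_cross_attention(
        only_cross_attention: Union[bool, List[bool]],
        mid_block_only_cross_attention: Optional[bool],
        num_down_blocks: int,
    ) -> Tuple[List[bool], bool]:
        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                # a single bool propagates to the mid block
                mid_block_only_cross_attention = only_cross_attention
            only_cross_attention = [only_cross_attention] * num_down_blocks
        if mid_block_only_cross_attention is None:
            # a per-block list was given without an explicit mid-block setting
            mid_block_only_cross_attention = False
        return only_cross_attention, mid_block_only_cross_attention

    assert resolve_only_cross_attention(True, None, 2) == ([True, True], True)
    assert resolve_only_cross_attention([True, False], None, 2) == ([True, False], False)
    assert resolve_only_cross_attention(True, False, 2) == ([True, True], False)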


@@ -191,7 +191,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
             using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
-            embeddings with the class embeddings.
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
+            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
+            default to `False`.
     """
 
     _supports_gradient_checkpointing = True
@@ -244,6 +249,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
     ):
         super().__init__()
 
@@ -357,8 +363,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         self.up_blocks = nn.ModuleList([])
 
         if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
             only_cross_attention = [only_cross_attention] * len(down_block_types)
 
+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)
 
@@ -434,6 +446,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
+                only_cross_attention=mid_block_only_cross_attention,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -1476,6 +1489,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         skip_time_act=False,
+        only_cross_attention=False,
     ):
         super().__init__()
 
@@ -1515,6 +1529,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                     processor=AttnAddedKVProcessor(),
                 )
             )


@@ -0,0 +1,75 @@
+import unittest
+
+import torch
+
+from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
+
+
+class AttnAddedKVProcessorTests(unittest.TestCase):
+    def get_constructor_arguments(self, only_cross_attention: bool = False):
+        query_dim = 10
+
+        if only_cross_attention:
+            cross_attention_dim = 12
+        else:
+            # when only cross attention is not set, the cross attention dim must be the same as the query dim
+            cross_attention_dim = query_dim
+
+        return {
+            "query_dim": query_dim,
+            "cross_attention_dim": cross_attention_dim,
+            "heads": 2,
+            "dim_head": 4,
+            "added_kv_proj_dim": 6,
+            "norm_num_groups": 1,
+            "only_cross_attention": only_cross_attention,
+            "processor": AttnAddedKVProcessor(),
+        }
+
+    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
+        batch_size = 2
+
+        hidden_states = torch.rand(batch_size, query_dim, 3, 2)
+        encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim)
+        attention_mask = None
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "attention_mask": attention_mask,
+        }
+    def test_only_cross_attention(self):
+        # self and cross attention
+
+        torch.manual_seed(0)
+
+        constructor_args = self.get_constructor_arguments(only_cross_attention=False)
+        attn = Attention(**constructor_args)
+
+        self.assertTrue(attn.to_k is not None)
+        self.assertTrue(attn.to_v is not None)
+
+        forward_args = self.get_forward_arguments(
+            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
+        )
+
+        self_and_cross_attn_out = attn(**forward_args)
+
+        # only cross attention
+
+        torch.manual_seed(0)
+
+        constructor_args = self.get_constructor_arguments(only_cross_attention=True)
+        attn = Attention(**constructor_args)
+
+        self.assertTrue(attn.to_k is None)
+        self.assertTrue(attn.to_v is None)
+
+        forward_args = self.get_forward_arguments(
+            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
+        )
+
+        only_cross_attn_out = attn(**forward_args)
+
+        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
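
The test covers the two working paths; the constructor guard from the first hunk can be exercised directly as well (a quick sketch, reusing only arguments shown above):

    from diffusers.models.attention_processor import Attention

    # only_cross_attention=True without added_kv_proj_dim leaves no K/V source
    # at all, so construction fails with the new ValueError.
    try:
        Attention(query_dim=10, heads=2, dim_head=4, only_cross_attention=True)
    except ValueError as err:
        print(err)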