Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
add only cross attention to simple attention blocks (#3011)
* add only cross attention to simple attention blocks
* add test for only_cross_attention re: @patrickvonplaten
* mid_block_only_cross_attention better default: allow mid_block_only_cross_attention to default to `only_cross_attention` when `only_cross_attention` is given as a single boolean
@@ -61,6 +61,7 @@ class Attention(nn.Module):
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
         scale_qk: bool = True,
+        only_cross_attention: bool = False,
         processor: Optional["AttnProcessor"] = None,
     ):
         super().__init__()
@@ -79,6 +80,12 @@ class Attention(nn.Module):
         self.sliceable_head_dim = heads

         self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )

         if norm_num_groups is not None:
             self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
@@ -89,8 +96,14 @@ class Attention(nn.Module):
             self.norm_cross = nn.LayerNorm(cross_attention_dim)

         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
-        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
-        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None

         if self.added_kv_proj_dim is not None:
             self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
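A minimal sketch of the constructor behavior introduced above, reusing the toy dimensions from the test file added later in this PR (all values illustrative):

# With only_cross_attention=True, the self-attention K/V projections are skipped.
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor

attn = Attention(
    query_dim=10,
    cross_attention_dim=12,
    heads=2,
    dim_head=4,
    added_kv_proj_dim=6,
    norm_num_groups=1,
    only_cross_attention=True,
    processor=AttnAddedKVProcessor(),
)
assert attn.to_k is None and attn.to_v is None  # no self-attention projections

# Without `added_kv_proj_dim`, the new validation path raises:
try:
    Attention(query_dim=10, only_cross_attention=True)
except ValueError:
    pass  # `only_cross_attention` requires `added_kv_proj_dim`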
@@ -408,18 +421,21 @@ class AttnAddedKVProcessor:
         query = attn.to_q(hidden_states)
         query = attn.head_to_batch_dim(query)

-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj

         attention_probs = attn.get_attention_scores(query, key, attention_mask)
         hidden_states = torch.bmm(attention_probs, value)
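For intuition, the branch above only changes where keys and values come from; a toy shape check (all dimensions illustrative, not taken from the diff):

# With only_cross_attention=True, K/V come solely from the added projections of
# `encoder_hidden_states`; otherwise the self-attention K/V are concatenated
# after them along the sequence axis (dim=1).
import torch

batch_heads, self_tokens, enc_tokens, dim_head = 4, 6, 4, 8
self_kv = torch.rand(batch_heads, self_tokens, dim_head)
added_kv = torch.rand(batch_heads, enc_tokens, dim_head)

key_full = torch.cat([added_kv, self_kv], dim=1)  # (4, 10, 8): cross + self tokens
key_cross_only = added_kv                         # (4, 4, 8): cross tokens only
assert key_full.shape == (4, 10, 8) and key_cross_only.shape == (4, 4, 8)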
@@ -637,18 +653,22 @@ class SlicedAttnAddedKVProcessor:
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)

-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj

         batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
@@ -125,6 +125,7 @@ def get_down_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
+            only_cross_attention=only_cross_attention,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -291,6 +292,7 @@ def get_up_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
+            only_cross_attention=only_cross_attention,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -575,6 +577,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         skip_time_act=False,
+        only_cross_attention=False,
     ):
         super().__init__()
@@ -614,6 +617,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1356,6 +1360,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         output_scale_factor=1.0,
         add_downsample=True,
         skip_time_act=False,
+        only_cross_attention=False,
     ):
         super().__init__()
@@ -1394,6 +1399,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -2354,6 +2360,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
         output_scale_factor=1.0,
         add_upsample=True,
         skip_time_act=False,
+        only_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -2393,6 +2400,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -110,7 +110,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
             using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
             embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
+            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
+            default to `False`.
     """

     _supports_gradient_checkpointing = True
@@ -158,6 +163,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
     ):
         super().__init__()
@@ -265,8 +271,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         self.up_blocks = nn.ModuleList([])

         if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
             only_cross_attention = [only_cross_attention] * len(down_block_types)

+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)

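The defaulting rule above is the behavior change called out in the commit message; a standalone sketch (`resolve_mid_block_flag` is a hypothetical helper name, the logic is copied from the diff):

from typing import List, Optional, Union

def resolve_mid_block_flag(
    only_cross_attention: Union[bool, List[bool]],
    mid_block_only_cross_attention: Optional[bool],
    num_down_blocks: int,
):
    if isinstance(only_cross_attention, bool):
        if mid_block_only_cross_attention is None:
            # a single bool propagates to the mid block as well
            mid_block_only_cross_attention = only_cross_attention
        only_cross_attention = [only_cross_attention] * num_down_blocks
    if mid_block_only_cross_attention is None:
        # a per-block list carries no signal for the mid block: default to False
        mid_block_only_cross_attention = False
    return only_cross_attention, mid_block_only_cross_attention

assert resolve_mid_block_flag(True, None, 2) == ([True, True], True)
assert resolve_mid_block_flag([True, False], None, 2) == ([True, False], False)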
@@ -342,6 +354,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
+                only_cross_attention=mid_block_only_cross_attention,
             )
         elif mid_block_type is None:
             self.mid_block = None
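End to end, the flag is surfaced on the model config; an untested sketch with a deliberately tiny, illustrative configuration (all sizes made up, only the two new keyword arguments come from this diff):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=8,
    block_out_channels=(32, 64),
    down_block_types=("SimpleCrossAttnDownBlock2D", "DownBlock2D"),
    mid_block_type="UNetMidBlock2DSimpleCrossAttn",
    up_block_types=("UpBlock2D", "SimpleCrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    only_cross_attention=True,            # applies to the simple attention blocks
    mid_block_only_cross_attention=None,  # None -> inherits True from the line above
)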
@@ -191,7 +191,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
             using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
             embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
+            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
+            default to `False`.
     """

     _supports_gradient_checkpointing = True
@@ -244,6 +249,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
     ):
         super().__init__()
@@ -357,8 +363,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         self.up_blocks = nn.ModuleList([])

         if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
             only_cross_attention = [only_cross_attention] * len(down_block_types)

+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)

@@ -434,6 +446,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
+                only_cross_attention=mid_block_only_cross_attention,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -1476,6 +1489,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         skip_time_act=False,
+        only_cross_attention=False,
     ):
         super().__init__()
@@ -1515,6 +1529,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                     processor=AttnAddedKVProcessor(),
                 )
             )
tests/models/test_attention_processor.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import unittest

import torch

from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor


class AttnAddedKVProcessorTests(unittest.TestCase):
    def get_constructor_arguments(self, only_cross_attention: bool = False):
        query_dim = 10

        if only_cross_attention:
            cross_attention_dim = 12
        else:
            # when only cross attention is not set, the cross attention dim must be the same as the query dim
            cross_attention_dim = query_dim

        return {
            "query_dim": query_dim,
            "cross_attention_dim": cross_attention_dim,
            "heads": 2,
            "dim_head": 4,
            "added_kv_proj_dim": 6,
            "norm_num_groups": 1,
            "only_cross_attention": only_cross_attention,
            "processor": AttnAddedKVProcessor(),
        }

    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
        batch_size = 2

        hidden_states = torch.rand(batch_size, query_dim, 3, 2)
        encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim)
        attention_mask = None

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
        }

    def test_only_cross_attention(self):
        # self and cross attention

        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=False)
        attn = Attention(**constructor_args)

        self.assertTrue(attn.to_k is not None)
        self.assertTrue(attn.to_v is not None)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        self_and_cross_attn_out = attn(**forward_args)

        # only cross attention

        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=True)
        attn = Attention(**constructor_args)

        self.assertTrue(attn.to_k is None)
        self.assertTrue(attn.to_v is None)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        only_cross_attn_out = attn(**forward_args)

        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
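The new test can be run with the repository's usual test runner, e.g. `python -m pytest tests/models/test_attention_processor.py` (a standard pytest invocation, not something specified by this PR).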