1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Attention processor cross attention norm group norm (#3021)

add group norm type to attention processor cross attention norm

This lets the cross attention norm use both a group norm block and a
layer norm block.

The group norm operates along the channels dimension
and requires input shape (batch size, channels, *) where as the layer norm with a single
`normalized_shape` dimension only operates over the least significant
dimension i.e. (*, channels).

The channels we want to normalize are the hidden dimension of the encoder hidden states.

By convention, the encoder hidden states are always passed as (batch size, sequence
length, hidden states).

This means the layer norm can operate on the tensor without modification, but the group
norm requires flipping the last two dimensions to operate on (batch size, hidden states, sequence length).

All existing attention processors will have the same logic and we can
consolidate it in a helper function `prepare_encoder_hidden_states`

prepare_encoder_hidden_states -> norm_encoder_hidden_states re: @patrickvonplaten

move norm_cross defined check to outside norm_encoder_hidden_states

add missing attn.norm_cross check
This commit is contained in:
Will Berman
2023-04-11 15:51:40 -07:00
committed by GitHub
parent 2d52e81cb9
commit 98c5e5da31
6 changed files with 96 additions and 21 deletions

View File

@@ -56,7 +56,8 @@ class Attention(nn.Module):
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: bool = False,
cross_attention_norm: Optional[str] = None,
cross_attention_norm_num_groups: int = 32,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
out_bias: bool = True,
@@ -69,7 +70,6 @@ class Attention(nn.Module):
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.cross_attention_norm = cross_attention_norm
self.scale = dim_head**-0.5 if scale_qk else 1.0
@@ -92,8 +92,28 @@ class Attention(nn.Module):
else:
self.group_norm = None
if cross_attention_norm:
if cross_attention_norm is None:
self.norm_cross = None
elif cross_attention_norm == "layer_norm":
self.norm_cross = nn.LayerNorm(cross_attention_dim)
elif cross_attention_norm == "group_norm":
if self.added_kv_proj_dim is not None:
# The given `encoder_hidden_states` are initially of shape
# (batch_size, seq_len, added_kv_proj_dim) before being projected
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
# before the projection, so we need to use `added_kv_proj_dim` as
# the number of channels for the group norm.
norm_cross_num_channels = added_kv_proj_dim
else:
norm_cross_num_channels = cross_attention_dim
self.norm_cross = nn.GroupNorm(
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
)
else:
raise ValueError(
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
@@ -304,6 +324,25 @@ class Attention(nn.Module):
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
return attention_mask
def norm_encoder_hidden_states(self, encoder_hidden_states):
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
if isinstance(self.norm_cross, nn.LayerNorm):
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
elif isinstance(self.norm_cross, nn.GroupNorm):
# Group norm norms along the channels dimension and expects
# input to be in the shape of (N, C, *). In this case, we want
# to norm along the hidden dimension, so we need to move
# (batch_size, sequence_length, hidden_size) ->
# (batch_size, hidden_size, sequence_length)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
else:
assert False
return encoder_hidden_states
class AttnProcessor:
def __call__(
@@ -321,8 +360,8 @@ class AttnProcessor:
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -388,7 +427,10 @@ class LoRAAttnProcessor(nn.Module):
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -416,6 +458,11 @@ class AttnAddedKVProcessor:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
@@ -467,8 +514,8 @@ class XFormersAttnProcessor:
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -511,8 +558,8 @@ class AttnProcessor2_0:
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -561,7 +608,10 @@ class LoRAXFormersAttnProcessor(nn.Module):
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -598,8 +648,8 @@ class SlicedAttnProcessor:
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -647,6 +697,11 @@ class SlicedAttnAddedKVProcessor:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)

View File

@@ -44,6 +44,7 @@ def get_down_block(
resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
@@ -126,6 +127,7 @@ def get_down_block(
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
@@ -223,6 +225,7 @@ def get_up_block(
resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
@@ -293,6 +296,7 @@ def get_up_block(
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
@@ -578,6 +582,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
cross_attention_dim=1280,
skip_time_act=False,
only_cross_attention=False,
cross_attention_norm=None,
):
super().__init__()
@@ -618,6 +623,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=AttnAddedKVProcessor(),
)
)
@@ -1361,6 +1367,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
add_downsample=True,
skip_time_act=False,
only_cross_attention=False,
cross_attention_norm=None,
):
super().__init__()
@@ -1400,6 +1407,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=AttnAddedKVProcessor(),
)
)
@@ -1580,7 +1588,7 @@ class KCrossAttnDownBlock2D(nn.Module):
temb_channels=temb_channels,
attention_bias=True,
add_self_attention=add_self_attention,
cross_attention_norm=True,
cross_attention_norm="layer_norm",
group_size=resnet_group_size,
)
)
@@ -2361,6 +2369,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
add_upsample=True,
skip_time_act=False,
only_cross_attention=False,
cross_attention_norm=None,
):
super().__init__()
resnets = []
@@ -2401,6 +2410,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=AttnAddedKVProcessor(),
)
)
@@ -2608,7 +2618,7 @@ class KCrossAttnUpBlock2D(nn.Module):
temb_channels=temb_channels,
attention_bias=True,
add_self_attention=add_self_attention,
cross_attention_norm=True,
cross_attention_norm="layer_norm",
upcast_attention=upcast_attention,
)
)
@@ -2703,7 +2713,7 @@ class KAttentionBlock(nn.Module):
upcast_attention: bool = False,
temb_channels: int = 768, # for ada_group_norm
add_self_attention: bool = False,
cross_attention_norm: bool = False,
cross_attention_norm: Optional[str] = None,
group_size: int = 32,
):
super().__init__()
@@ -2719,7 +2729,7 @@ class KAttentionBlock(nn.Module):
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
cross_attention_norm=False,
cross_attention_norm=None,
)
# 2. Cross-Attn

View File

@@ -169,6 +169,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
projection_class_embeddings_input_dim: Optional[int] = None,
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None,
):
super().__init__()
@@ -341,6 +342,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
)
self.down_blocks.append(down_block)
@@ -373,6 +375,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif mid_block_type is None:
self.mid_block = None
@@ -424,6 +427,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel

View File

@@ -243,8 +243,8 @@ class Pix2PixZeroAttnProcessor:
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

View File

@@ -65,8 +65,8 @@ class CrossAttnStoreProcessor:
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

View File

@@ -255,6 +255,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
projection_class_embeddings_input_dim: Optional[int] = None,
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None,
):
super().__init__()
@@ -433,6 +434,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
)
self.down_blocks.append(down_block)
@@ -465,6 +467,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif mid_block_type is None:
self.mid_block = None
@@ -516,6 +519,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -1511,6 +1515,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
cross_attention_dim=1280,
skip_time_act=False,
only_cross_attention=False,
cross_attention_norm=None,
):
super().__init__()
@@ -1551,6 +1556,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=AttnAddedKVProcessor(),
)
)