From 98c5e5da31dd70facf92970074be49501cd5e20b Mon Sep 17 00:00:00 2001 From: Will Berman Date: Tue, 11 Apr 2023 15:51:40 -0700 Subject: [PATCH] Attention processor cross attention norm group norm (#3021) add group norm type to attention processor cross attention norm This lets the cross attention norm use both a group norm block and a layer norm block. The group norm operates along the channels dimension and requires input shape (batch size, channels, *) where as the layer norm with a single `normalized_shape` dimension only operates over the least significant dimension i.e. (*, channels). The channels we want to normalize are the hidden dimension of the encoder hidden states. By convention, the encoder hidden states are always passed as (batch size, sequence length, hidden states). This means the layer norm can operate on the tensor without modification, but the group norm requires flipping the last two dimensions to operate on (batch size, hidden states, sequence length). All existing attention processors will have the same logic and we can consolidate it in a helper function `prepare_encoder_hidden_states` prepare_encoder_hidden_states -> norm_encoder_hidden_states re: @patrickvonplaten move norm_cross defined check to outside norm_encoder_hidden_states add missing attn.norm_cross check --- src/diffusers/models/attention_processor.py | 81 ++++++++++++++++--- src/diffusers/models/unet_2d_blocks.py | 18 ++++- src/diffusers/models/unet_2d_condition.py | 4 + .../pipeline_stable_diffusion_pix2pix_zero.py | 4 +- .../pipeline_stable_diffusion_sag.py | 4 +- .../versatile_diffusion/modeling_text_unet.py | 6 ++ 6 files changed, 96 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 864b042c24..41baf99999 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -56,7 +56,8 @@ class Attention(nn.Module): bias=False, upcast_attention: bool = False, upcast_softmax: bool = False, - cross_attention_norm: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, out_bias: bool = True, @@ -69,7 +70,6 @@ class Attention(nn.Module): cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax - self.cross_attention_norm = cross_attention_norm self.scale = dim_head**-0.5 if scale_qk else 1.0 @@ -92,8 +92,28 @@ class Attention(nn.Module): else: self.group_norm = None - if cross_attention_norm: + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": self.norm_cross = nn.LayerNorm(cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) @@ -304,6 +324,25 @@ class Attention(nn.Module): attention_mask = attention_mask.repeat_interleave(head_size, dim=0) return attention_mask + def norm_encoder_hidden_states(self, encoder_hidden_states): + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + class AttnProcessor: def __call__( @@ -321,8 +360,8 @@ class AttnProcessor: if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -388,7 +427,10 @@ class LoRAAttnProcessor(nn.Module): query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query) - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) @@ -416,6 +458,11 @@ class AttnAddedKVProcessor: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) @@ -467,8 +514,8 @@ class XFormersAttnProcessor: if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -511,8 +558,8 @@ class AttnProcessor2_0: if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -561,7 +608,10 @@ class LoRAXFormersAttnProcessor(nn.Module): query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query).contiguous() - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) @@ -598,8 +648,8 @@ class SlicedAttnProcessor: if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -647,6 +697,11 @@ class SlicedAttnAddedKVProcessor: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 540059b107..08578c8109 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -44,6 +44,7 @@ def get_down_block( resnet_time_scale_shift="default", resnet_skip_time_act=False, resnet_out_scale_factor=1.0, + cross_attention_norm=None, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": @@ -126,6 +127,7 @@ def get_down_block( skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( @@ -223,6 +225,7 @@ def get_up_block( resnet_time_scale_shift="default", resnet_skip_time_act=False, resnet_out_scale_factor=1.0, + cross_attention_norm=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": @@ -293,6 +296,7 @@ def get_up_block( skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif up_block_type == "AttnUpBlock2D": return AttnUpBlock2D( @@ -578,6 +582,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): cross_attention_dim=1280, skip_time_act=False, only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() @@ -618,6 +623,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, processor=AttnAddedKVProcessor(), ) ) @@ -1361,6 +1367,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module): add_downsample=True, skip_time_act=False, only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() @@ -1400,6 +1407,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module): bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, processor=AttnAddedKVProcessor(), ) ) @@ -1580,7 +1588,7 @@ class KCrossAttnDownBlock2D(nn.Module): temb_channels=temb_channels, attention_bias=True, add_self_attention=add_self_attention, - cross_attention_norm=True, + cross_attention_norm="layer_norm", group_size=resnet_group_size, ) ) @@ -2361,6 +2369,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): add_upsample=True, skip_time_act=False, only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() resnets = [] @@ -2401,6 +2410,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, processor=AttnAddedKVProcessor(), ) ) @@ -2608,7 +2618,7 @@ class KCrossAttnUpBlock2D(nn.Module): temb_channels=temb_channels, attention_bias=True, add_self_attention=add_self_attention, - cross_attention_norm=True, + cross_attention_norm="layer_norm", upcast_attention=upcast_attention, ) ) @@ -2703,7 +2713,7 @@ class KAttentionBlock(nn.Module): upcast_attention: bool = False, temb_channels: int = 768, # for ada_group_norm add_self_attention: bool = False, - cross_attention_norm: bool = False, + cross_attention_norm: Optional[str] = None, group_size: int = 32, ): super().__init__() @@ -2719,7 +2729,7 @@ class KAttentionBlock(nn.Module): dropout=dropout, bias=attention_bias, cross_attention_dim=None, - cross_attention_norm=False, + cross_attention_norm=None, ) # 2. Cross-Attn diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 9243dc66d3..1b982aedc5 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -169,6 +169,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) projection_class_embeddings_input_dim: Optional[int] = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, ): super().__init__() @@ -341,6 +342,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, ) self.down_blocks.append(down_block) @@ -373,6 +375,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None @@ -424,6 +427,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, ) self.up_blocks.append(up_block) prev_output_channel = output_channel diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index e457ad2b3a..0239c81281 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -243,8 +243,8 @@ class Pix2PixZeroAttnProcessor: if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index 0638822847..c6d67c6148 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -65,8 +65,8 @@ class CrossAttnStoreProcessor: if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index cc8cde4daa..4c0a4d89dc 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -255,6 +255,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): projection_class_embeddings_input_dim: Optional[int] = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, ): super().__init__() @@ -433,6 +434,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, ) self.down_blocks.append(down_block) @@ -465,6 +467,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None @@ -516,6 +519,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -1511,6 +1515,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): cross_attention_dim=1280, skip_time_act=False, only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() @@ -1551,6 +1556,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, processor=AttnAddedKVProcessor(), ) )