From 5e0369219f285a97db2c4ca2153defea9cc4f177 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Wed, 7 Dec 2022 18:33:29 +0100
Subject: [PATCH] Make cross-attention check more robust (#1560)

* Make cross-attention check more robust.

* Fix copies.
---
 src/diffusers/models/unet_2d_blocks.py                   | 3 +++
 src/diffusers/models/unet_2d_condition.py                | 4 ++--
 .../pipelines/versatile_diffusion/modeling_text_unet.py  | 7 +++++--
 3 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index 726f050e65..aa8d4c9849 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -343,6 +343,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
     ):
         super().__init__()

+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -526,6 +527,7 @@ class CrossAttnDownBlock2D(nn.Module):
         resnets = []
         attentions = []

+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels

@@ -1110,6 +1112,7 @@ class CrossAttnUpBlock2D(nn.Module):
         resnets = []
         attentions = []

+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels

diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 7d6db4aba6..0cfb152249 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -377,7 +377,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -403,7 +403,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             if not is_final_block and forward_upsample_size:
                 upsample_size = down_block_res_samples[-1].shape[2:]

-            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index f1cf46aaf6..0c30f21eb6 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -455,7 +455,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -481,7 +481,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             if not is_final_block and forward_upsample_size:
                 upsample_size = down_block_res_samples[-1].shape[2:]

-            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -726,6 +726,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         resnets = []
         attentions = []

+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels

@@ -924,6 +925,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         resnets = []
         attentions = []

+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels

@@ -1043,6 +1045,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
     ):
         super().__init__()

+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
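---

Note: the sketch below illustrates, outside of diffusers, the dispatch pattern this patch standardizes. Blocks that consume encoder_hidden_states set an explicit has_cross_attention flag in __init__, and the forward pass branches on that flag instead of probing for an attentions attribute, which other block types may also define (or define as None). PlainBlock, CrossAttnBlock, and run_blocks are hypothetical stand-ins for illustration only, not diffusers classes.

import torch
from torch import nn


class PlainBlock(nn.Module):
    # Hypothetical stand-in for a block without cross-attention (e.g. a plain ResNet block).
    def forward(self, hidden_states, temb=None):
        return hidden_states


class CrossAttnBlock(nn.Module):
    # Hypothetical stand-in for a block that consumes text embeddings via cross-attention.
    def __init__(self):
        super().__init__()
        # Explicit flag, analogous to what the patch adds to the cross-attention blocks.
        self.has_cross_attention = True

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        return hidden_states


def run_blocks(blocks, sample, emb, encoder_hidden_states):
    # Branch on the explicit flag rather than on the presence of `attentions`,
    # mirroring the check the patch installs in the UNet forward passes.
    for block in blocks:
        if hasattr(block, "has_cross_attention") and block.has_cross_attention:
            sample = block(sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
        else:
            sample = block(sample, temb=emb)
    return sample


blocks = nn.ModuleList([PlainBlock(), CrossAttnBlock()])
sample = run_blocks(blocks, torch.randn(1, 4, 8, 8), emb=None, encoder_hidden_states=torch.randn(1, 77, 768))
print(sample.shape)  # torch.Size([1, 4, 8, 8])

Keying the check to a per-class flag keeps the dispatch stable even if a block without cross-attention happens to expose an attentions submodule, which is the robustness issue the commit title refers to.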