
Make cross-attention check more robust (#1560)

* Make cross-attention check more robust.

* Fix copies.
Pedro Cuenca
2022-12-07 18:33:29 +01:00
committed by GitHub
parent bea7eb4314
commit 5e0369219f
3 changed files with 10 additions and 4 deletions
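For context: the previous check probed each block for an `attentions` attribute, but that probe is fragile — self-attention-only blocks (e.g. `AttnDownBlock2D`) also carry an `attentions` module list while not accepting `encoder_hidden_states`, so they would be dispatched down the cross-attention path. Setting an explicit `has_cross_attention = True` flag in each cross-attention block's `__init__` turns the dispatch into an unambiguous opt-in. Below is a minimal standalone sketch of the pattern, outside diffusers; the block classes and the `run_down_blocks` helper are hypothetical, for illustration only.

import torch.nn as nn

class PlainDownBlock(nn.Module):
    # Hypothetical self-attention-only block: it still owns an `attentions`
    # list, so the old hasattr(block, "attentions") check would match it.
    def __init__(self):
        super().__init__()
        self.attentions = nn.ModuleList([nn.Identity()])

    def forward(self, hidden_states, temb=None):
        return hidden_states

class CrossAttnDownBlock(nn.Module):
    # Hypothetical cross-attention block mirroring the flag this commit adds.
    def __init__(self):
        super().__init__()
        self.attentions = nn.ModuleList([nn.Identity()])
        self.has_cross_attention = True

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        return hidden_states

def run_down_blocks(blocks, sample, emb, encoder_hidden_states):
    for block in blocks:
        # New-style check: explicit opt-in flag instead of probing for an
        # `attentions` attribute that self-attention blocks also define.
        if hasattr(block, "has_cross_attention") and block.has_cross_attention:
            sample = block(sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
        else:
            sample = block(sample, temb=emb)
    return sample

With the old attribute probe, `PlainDownBlock` above would have been called with an `encoder_hidden_states` argument it does not accept. The `UNetFlat*` hunks further down mirror the 2D ones because diffusers keeps such duplicated modules in sync mechanically — hence the "Fix copies." line in the commit message.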

@@ -343,6 +343,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
     ):
         super().__init__()
 
+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -526,6 +527,7 @@ class CrossAttnDownBlock2D(nn.Module):
         resnets = []
         attentions = []
 
+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
@@ -1110,6 +1112,7 @@ class CrossAttnUpBlock2D(nn.Module):
         resnets = []
         attentions = []
 
+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels

@@ -377,7 +377,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -403,7 +403,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             if not is_final_block and forward_upsample_size:
                 upsample_size = down_block_res_samples[-1].shape[2:]
 
-            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,

@@ -455,7 +455,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -481,7 +481,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             if not is_final_block and forward_upsample_size:
                 upsample_size = down_block_res_samples[-1].shape[2:]
 
-            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -726,6 +726,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         resnets = []
         attentions = []
 
+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
@@ -924,6 +925,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         resnets = []
         attentions = []
 
+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
@@ -1043,6 +1045,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
     ):
         super().__init__()
 
+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)