mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Make cross-attention check more robust (#1560)
* Make cross-attention check more robust.
* Fix copies.
This commit is contained in:
@@ -343,6 +343,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
+self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
@@ -526,6 +527,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
+self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
@@ -1110,6 +1112,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
+self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
|
||||
@@ -377,7 +377,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
-if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
|
||||
+if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
@@ -403,7 +403,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
-if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
|
||||
+if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
|
||||
@@ -455,7 +455,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
-if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
|
||||
+if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
@@ -481,7 +481,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
-if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
|
||||
+if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
@@ -726,6 +726,7 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
+self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
@@ -924,6 +925,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
+self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
@@ -1043,6 +1045,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
+self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
Reference in New Issue
Block a user