From bb2c64a08c181b450afe61dd88b2f0a575bc414b Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Thu, 24 Nov 2022 21:57:27 +0100 Subject: [PATCH] Add the new SD2 attention params to the VD text unet (#1400) --- .../versatile_diffusion/modeling_text_unet.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 24e79729a5..e3c35dcb38 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -28,7 +28,9 @@ def get_down_block( resnet_groups=None, cross_attention_dim=None, downsample_padding=None, - dual_cross_attention=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlockFlat": @@ -58,6 +60,9 @@ def get_down_block( downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, ) raise ValueError(f"{down_block_type} is not supported.") @@ -75,7 +80,9 @@ def get_up_block( attn_num_head_channels, resnet_groups=None, cross_attention_dim=None, - dual_cross_attention=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlockFlat": @@ -105,6 +112,9 @@ def get_up_block( resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, ) raise ValueError(f"{up_block_type} is not supported.")