1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add the new SD2 attention params to the VD text unet (#1400)

This commit is contained in:
Anton Lozhkov
2022-11-24 21:57:27 +01:00
committed by GitHub
parent 05a36d5c1a
commit bb2c64a08c

View File

@@ -28,7 +28,9 @@ def get_down_block(
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
dual_cross_attention=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlockFlat":
@@ -58,6 +60,9 @@ def get_down_block(
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
raise ValueError(f"{down_block_type} is not supported.")
@@ -75,7 +80,9 @@ def get_up_block(
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlockFlat":
@@ -105,6 +112,9 @@ def get_up_block(
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
raise ValueError(f"{up_block_type} is not supported.")