From dc7cd893fdee3906b3223a623b0f6d884a2df7c4 Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Mon, 19 Dec 2022 12:01:46 +0100 Subject: [PATCH] Add resnet_time_scale_shift to VD layers (#1757) --- .../pipelines/versatile_diffusion/modeling_text_unet.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 2d35bced7d..0bf6cfd586 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -33,6 +33,7 @@ def get_down_block( use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + resnet_time_scale_shift="default", ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlockFlat": @@ -46,6 +47,7 @@ def get_down_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "CrossAttnDownBlockFlat": if cross_attention_dim is None: @@ -65,6 +67,7 @@ def get_down_block( dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, + resnet_time_scale_shift=resnet_time_scale_shift, ) raise ValueError(f"{down_block_type} is not supported.") @@ -86,6 +89,7 @@ def get_up_block( use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + resnet_time_scale_shift="default", ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlockFlat": @@ -99,6 +103,7 @@ def get_up_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "CrossAttnUpBlockFlat": if cross_attention_dim is None: @@ -118,6 +123,7 @@ def get_up_block( dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, + resnet_time_scale_shift=resnet_time_scale_shift, ) raise ValueError(f"{up_block_type} is not supported.")