From cf2bf70a4d19797f16063100ade9f321efbb6bff Mon Sep 17 00:00:00 2001 From: Isaac <34376531+init-22@users.noreply.github.com> Date: Thu, 27 Apr 2023 17:01:43 +0530 Subject: [PATCH] adding required parameters while calling the get_up_block and get_down_block (#3210) * removed unnecessary parameters from get_up_block and get_down_block functions * adding resnet_skip_time_act, resnet_out_scale_factor and cross_attention_norm to get_up_block and get_down_block functions --------- Co-authored-by: Sayak Paul --- .../pipelines/versatile_diffusion/modeling_text_unet.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 57e1abc731..0959e2bb3a 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -42,6 +42,9 @@ def get_down_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlockFlat": @@ -98,6 +101,9 @@ def get_up_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlockFlat":