1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

adding required parameters while calling the get_up_block and get_down_block (#3210)

* removed unnecessary parameters from get_up_block and get_down_block functions

* adding resnet_skip_time_act, resnet_out_scale_factor and cross_attention_norm to get_up_block and get_down_block functions

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Isaac
2023-04-27 17:01:43 +05:30
committed by Daniel Gu
parent 97cf3866db
commit cf2bf70a4d

View File

@@ -42,6 +42,9 @@ def get_down_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlockFlat":
@@ -98,6 +101,9 @@ def get_up_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlockFlat":