mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
adding required parameters while calling the get_up_block and get_down_block (#3210)
* removed unnecessary parameters from get_up_block and get_down_block functions * adding resnet_skip_time_act, resnet_out_scale_factor and cross_attention_norm to get_up_block and get_down_block functions --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -42,6 +42,9 @@ def get_down_block(
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_skip_time_act=False,
|
||||
resnet_out_scale_factor=1.0,
|
||||
cross_attention_norm=None,
|
||||
):
|
||||
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
||||
if down_block_type == "DownBlockFlat":
|
||||
@@ -98,6 +101,9 @@ def get_up_block(
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_skip_time_act=False,
|
||||
resnet_out_scale_factor=1.0,
|
||||
cross_attention_norm=None,
|
||||
):
|
||||
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
||||
if up_block_type == "UpBlockFlat":
|
||||
|
||||
Reference in New Issue
Block a user