1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Nathan Lambert
2022-10-26 16:00:39 -07:00
parent 5df4c8b81f
commit 741122e722
2 changed files with 5 additions and 2 deletions

View File

@@ -43,7 +43,8 @@ class UNet1DModel(ModelMixin, ConfigMixin):
obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(32, 32, 64)`): Tuple of block output channels.
up_down_block_layers (`int`, defaults to 2): number of resnet, attention, or other layers in the up and down blocks.
up_down_block_layers (`int`, defaults to 2):
number of resnet, attention, or other layers in the up and down blocks.
mid_block_layers (`int`, defaults to 5): number of resnet, attention, or other layers in the mid block.
"""

View File

@@ -200,7 +200,9 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels):
def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels):
if mid_block_type == "UNetMidBlock1D":
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels, num_layers=num_layers)
return UNetMidBlock1D(
in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels, num_layers=num_layers
)
raise ValueError(f"{mid_block_type} does not exist.")