diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 37c3189312..34c60110fa 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -43,7 +43,8 @@ class UNet1DModel(ModelMixin, ConfigMixin): obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to : obj:`(32, 32, 64)`): Tuple of block output channels. - up_down_block_layers (`int`, defaults to 2): number of resnet, attention, or other layers in the up and down blocks. + up_down_block_layers (`int`, defaults to 2): + number of resnet, attention, or other layers in the up and down blocks. mid_block_layers (`int`, defaults to 5): number of resnet, attention, or other layers in the mid block. """ diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 71a9b568c1..7b2ef2ef7d 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -200,7 +200,9 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels): def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels): if mid_block_type == "UNetMidBlock1D": - return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels, num_layers=num_layers) + return UNetMidBlock1D( + in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels, num_layers=num_layers + ) raise ValueError(f"{mid_block_type} does not exist.")