diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py
index bb29a48bb7..7d12c81395 100644
--- a/src/diffusers/models/unet_1d.py
+++ b/src/diffusers/models/unet_1d.py
@@ -57,9 +57,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
             obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
         block_out_channels (`Tuple[int]`, *optional*, defaults to :
             obj:`(32, 32, 64)`): Tuple of block output channels.
-        layers_per_block (`int`, defaults to 2):
-            number of resnet, attention, or other layers in the up and down blocks.
-        mid_block_layers (`int`, defaults to 5): number of resnet, attention, or other layers in the mid block.
+        layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block.
         mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet.
         out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet.
         act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks.
@@ -85,8 +83,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
         mid_block_type: Tuple[str] = "UNetMidBlock1D",
         out_block_type: str = None,
         block_out_channels: Tuple[int] = (32, 32, 64),
-        layers_per_block: int = 2,
-        mid_block_layers: int = 5,
+        layers_per_block: int = 1,
         act_fn: str = None,
         norm_num_groups: int = 8,
         downsample_each_block: bool = False,
@@ -133,9 +130,9 @@ class UNet1DModel(ModelMixin, ConfigMixin):

             down_block = get_down_block(
                 down_block_type,
+                num_layers=layers_per_block,
                 in_channels=input_channel,
                 out_channels=output_channel,
-                num_layers=layers_per_block,
                 temb_channels=block_out_channels[0],
                 add_downsample=not is_final_block or downsample_each_block,
             )
@@ -145,10 +142,10 @@ class UNet1DModel(ModelMixin, ConfigMixin):
         self.mid_block = get_mid_block(
             mid_block_type,
             in_channels=block_out_channels[-1],
-            num_layers=mid_block_layers,
             mid_channels=block_out_channels[-1],
             out_channels=block_out_channels[-1],
             embed_dim=block_out_channels[0],
+            num_layers=layers_per_block,
             add_downsample=downsample_each_block,
         )

@@ -170,9 +167,9 @@ class UNet1DModel(ModelMixin, ConfigMixin):

             up_block = get_up_block(
                 up_block_type,
+                num_layers=layers_per_block,
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
-                num_layers=layers_per_block,
                 temb_channels=block_out_channels[0],
                 add_upsample=not is_final_block,
             )
diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py
index 11c3228f75..45ef3df36d 100644
--- a/src/diffusers/models/unet_1d_blocks.py
+++ b/src/diffusers/models/unet_1d_blocks.py
@@ -429,29 +429,29 @@ class ResConvBlock(nn.Module):


 class UNetMidBlock1D(nn.Module):
-    def __init__(self, mid_channels: int, in_channels: int, num_layers: int = 5, out_channels: int = None):
+    def __init__(self, mid_channels: int, in_channels: int, out_channels: int = None):
         super().__init__()

-        if num_layers < 1:
-            raise ValueError("UNetMidBlock1D requires added num_layers >= 1")
-
         out_channels = in_channels if out_channels is None else out_channels

         # there is always at least one resnet
         self.down = KernelDownsample1D("cubic")
         resnets = [
             ResConvBlock(in_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, out_channels),
         ]
         attentions = [
             SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(out_channels, out_channels // 32),
         ]
-        for i in range(num_layers):
-            if i < (num_layers - 1):
-                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
-                attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
-            else:
-                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
-                attentions.append(SelfAttention1d(out_channels, out_channels // 32))

         self.up = KernelUpsample1D(kernel="cubic")
         self.attentions = nn.ModuleList(attentions)
diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py
index 7f61cbfb03..089d935651 100644
--- a/tests/models/test_models_unet_1d.py
+++ b/tests/models/test_models_unet_1d.py
@@ -224,7 +224,6 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
             "mid_block_type": "ValueFunctionMidBlock1D",
             "block_out_channels": [32, 64, 128, 256],
             "layers_per_block": 1,
-            "mid_block_layers": 1,
             "downsample_each_block": True,
             "use_timestep_embedding": True,
             "freq_shift": 1.0,
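For reference, a minimal usage sketch of what construction looks like after this patch: `mid_block_layers` is no longer a config argument, the `UNetMidBlock1D` depth is fixed at six `ResConvBlock`/`SelfAttention1d` pairs, and only `layers_per_block` (default 1) is forwarded to the down, mid, and up block builders. This snippet is not part of the patch; the config values mirror the documented defaults, while the sample shape, timestep value, and printed expectations are illustrative assumptions.

```python
import torch
from diffusers import UNet1DModel

# Assumed defaults from the docstring above: block_out_channels=(32, 32, 64),
# layers_per_block=1, mid_block_type="UNetMidBlock1D". Note there is no
# `mid_block_layers` keyword anymore.
model = UNet1DModel(
    block_out_channels=(32, 32, 64),
    layers_per_block=1,
    mid_block_type="UNetMidBlock1D",
)

# With the num_layers parameterization removed, the mid block always holds the
# six hardcoded ResConvBlock / SelfAttention1d pairs shown in the patch.
print(len(model.mid_block.resnets), len(model.mid_block.attentions))  # expected: 6 6

# Illustrative forward pass: batch of 1, the default in_channels of 2, and a
# length chosen (assumption) to be divisible by the stack's downsampling factor.
sample = torch.randn(1, 2, 256)
timestep = torch.tensor([10])
with torch.no_grad():
    out = model(sample, timestep).sample
print(out.shape)  # expected: torch.Size([1, 2, 256])
```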