From 64c5688284addc75619647d7a4e39d3c55b1f808 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 14 Nov 2022 14:12:47 -0800 Subject: [PATCH] revert change to make less breaking --- src/diffusers/models/unet_1d.py | 8 ++++---- tests/models/test_models_unet_1d.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index dcfe75cf87..bb29a48bb7 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -57,7 +57,7 @@ class UNet1DModel(ModelMixin, ConfigMixin): obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to : obj:`(32, 32, 64)`): Tuple of block output channels. - up_down_block_layers (`int`, defaults to 2): + layers_per_block (`int`, defaults to 2): number of resnet, attention, or other layers in the up and down blocks. mid_block_layers (`int`, defaults to 5): number of resnet, attention, or other layers in the mid block. mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet. @@ -85,7 +85,7 @@ class UNet1DModel(ModelMixin, ConfigMixin): mid_block_type: Tuple[str] = "UNetMidBlock1D", out_block_type: str = None, block_out_channels: Tuple[int] = (32, 32, 64), - up_down_block_layers: int = 2, + layers_per_block: int = 2, mid_block_layers: int = 5, act_fn: str = None, norm_num_groups: int = 8, @@ -135,7 +135,7 @@ class UNet1DModel(ModelMixin, ConfigMixin): down_block_type, in_channels=input_channel, out_channels=output_channel, - num_layers=up_down_block_layers, + num_layers=layers_per_block, temb_channels=block_out_channels[0], add_downsample=not is_final_block or downsample_each_block, ) @@ -172,7 +172,7 @@ class UNet1DModel(ModelMixin, ConfigMixin): up_block_type, in_channels=prev_output_channel, out_channels=output_channel, - num_layers=up_down_block_layers, + num_layers=layers_per_block, temb_channels=block_out_channels[0], add_upsample=not is_final_block, ) diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 41c4fdecfa..202936324a 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -224,6 +224,7 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): "mid_block_type": "ValueFunctionMidBlock1D", "block_out_channels": [32, 64, 128, 256], "layers_per_block": 1, + "mid_block_layers": 1, "downsample_each_block": True, "use_timestep_embedding": True, "freq_shift": 1.0,