diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 7d12c81395..29d1d707f5 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -57,11 +57,11 @@ class UNet1DModel(ModelMixin, ConfigMixin): obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to : obj:`(32, 32, 64)`): Tuple of block output channels. - layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block. mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet. out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet. act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks. norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks. + layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block. downsample_each_block (`int`, *optional*, defaults to False: experimental feature for using a UNet without upsampling. """ @@ -83,9 +83,9 @@ class UNet1DModel(ModelMixin, ConfigMixin): mid_block_type: Tuple[str] = "UNetMidBlock1D", out_block_type: str = None, block_out_channels: Tuple[int] = (32, 32, 64), - layers_per_block: int = 1, act_fn: str = None, norm_num_groups: int = 8, + layers_per_block: int = 1, downsample_each_block: bool = False, ): super().__init__() diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py index c38a2e3608..1779279ba6 100644 --- a/tests/pipelines/dance_diffusion/test_dance_diffusion.py +++ b/tests/pipelines/dance_diffusion/test_dance_diffusion.py @@ -118,5 +118,5 @@ class PipelineIntegrationTests(unittest.TestCase): audio_slice = audio[0, -3:, -3:] assert audio.shape == (1, 2, pipe.unet.sample_size) - expected_slice = np.array([-0.1693, -0.1698, -0.1447, -0.3044, -0.3203, -0.2937]) + expected_slice = np.array([-0.1576, -0.1526, -0.127, -0.2699, -0.2762, -0.2487]) assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2