From d36103a0897248ff288c3dff84991e18c6cc34a5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 31 Mar 2023 15:20:46 +0200 Subject: [PATCH] [Tests] Speed up test (#2919) speed up test --- tests/models/test_models_unet_3d_condition.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index ea71ae4af2..729367a0c1 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -88,19 +88,17 @@ class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase): def prepare_init_args_and_inputs_for_common(self): init_dict = { - "block_out_channels": (32, 64, 64, 64), + "block_out_channels": (32, 64), "down_block_types": ( - "CrossAttnDownBlock3D", - "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ), - "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D"), "cross_attention_dim": 32, - "attention_head_dim": 4, + "attention_head_dim": 8, "out_channels": 4, "in_channels": 4, - "layers_per_block": 2, + "layers_per_block": 1, "sample_size": 32, } inputs_dict = self.dummy_input