From 4c8a05f1159b6bacb78b608f32fd97ffe80ea59d Mon Sep 17 00:00:00 2001
From: dg845 <58458699+dg845@users.noreply.github.com>
Date: Fri, 15 Sep 2023 03:27:51 -0700
Subject: [PATCH] Fix Consistency Models UNet2DMidBlock2D Attention GroupNorm
 Bug (#4863)

* Add attn_groups argument to UNet2DMidBlock2D to control the internal
  Attention block's GroupNorm.

* Add docstring for attn_norm_num_groups in UNet2DModel.

* Since the test UNet config uses resnet_time_scale_shift == 'scale_shift',
  also set attn_norm_num_groups to 32.

* Add test for attn_norm_num_groups to UNet2DModelTests.

* Fix expected slices for slow tests.

* Also fix tolerances for slow tests.

---------

Co-authored-by: Sayak Paul
---
 scripts/convert_consistency_to_diffusers.py |  2 ++
 src/diffusers/models/unet_2d.py             |  6 ++++
 src/diffusers/models/unet_2d_blocks.py      |  6 +++-
 tests/models/test_models_unet_2d.py         | 30 +++++++++++++++++++
 .../test_consistency_models.py              | 12 ++++----
 5 files changed, 49 insertions(+), 7 deletions(-)

diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py
index 5a6158bb98..0f8b4ddca8 100644
--- a/scripts/convert_consistency_to_diffusers.py
+++ b/scripts/convert_consistency_to_diffusers.py
@@ -27,6 +27,7 @@ TEST_UNET_CONFIG = {
         "ResnetUpsampleBlock2D",
     ],
     "resnet_time_scale_shift": "scale_shift",
+    "attn_norm_num_groups": 32,
     "upsample_type": "resnet",
     "downsample_type": "resnet",
 }
@@ -52,6 +53,7 @@ IMAGENET_64_UNET_CONFIG = {
         "ResnetUpsampleBlock2D",
     ],
     "resnet_time_scale_shift": "scale_shift",
+    "attn_norm_num_groups": 32,
     "upsample_type": "resnet",
     "downsample_type": "resnet",
 }
diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py
index 3058dad3d6..db6d3a5dce 100644
--- a/src/diffusers/models/unet_2d.py
+++ b/src/diffusers/models/unet_2d.py
@@ -74,6 +74,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
         norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
+        attn_norm_num_groups (`int`, *optional*, defaults to `None`):
+            If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
+            given number of groups. If left as `None`, the group norm layer will only be created if
+            `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
         norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet
             blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
@@ -107,6 +111,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         act_fn: str = "silu",
         attention_head_dim: Optional[int] = 8,
         norm_num_groups: int = 32,
+        attn_norm_num_groups: Optional[int] = None,
         norm_eps: float = 1e-5,
         resnet_time_scale_shift: str = "default",
         add_attention: bool = True,
@@ -192,6 +197,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             resnet_time_scale_shift=resnet_time_scale_shift,
             attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
             resnet_groups=norm_num_groups,
+            attn_groups=attn_norm_num_groups,
             add_attention=add_attention,
         )
 
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index d6066e92b7..8aebb3aad6 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -485,6 +485,7 @@ class UNetMidBlock2D(nn.Module):
         resnet_time_scale_shift: str = "default",  # default, spatial
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        attn_groups: Optional[int] = None,
         resnet_pre_norm: bool = True,
         add_attention: bool = True,
         attention_head_dim=1,
@@ -494,6 +495,9 @@ class UNetMidBlock2D(nn.Module):
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
         self.add_attention = add_attention
 
+        if attn_groups is None:
+            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock2D(
@@ -526,7 +530,7 @@ class UNetMidBlock2D(nn.Module):
                         dim_head=attention_head_dim,
                         rescale_output_factor=output_scale_factor,
                         eps=resnet_eps,
-                        norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
+                        norm_num_groups=attn_groups,
                         spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                         residual_connection=True,
                         bias=True,
diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py
index c5289a54b4..4fd991b3fc 100644
--- a/tests/models/test_models_unet_2d.py
+++ b/tests/models/test_models_unet_2d.py
@@ -74,6 +74,36 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    def test_mid_block_attn_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["norm_num_groups"] = 16
+        init_dict["add_attention"] = True
+        init_dict["attn_norm_num_groups"] = 8
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        self.assertIsNotNone(
+            model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
+        )
+        self.assertEqual(
+            model.mid_block.attentions[0].group_norm.num_groups,
+            init_dict["attn_norm_num_groups"],
+            "Mid block Attention group norm does not have the expected number of groups.",
+        )
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
 
 class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py
index 59be333b62..2cf7c0adb4 100644
--- a/tests/pipelines/consistency_models/test_consistency_models.py
+++ b/tests/pipelines/consistency_models/test_consistency_models.py
@@ -216,9 +216,9 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
 
         image_slice = image[0, -3:, -3:, -1]
 
-        expected_slice = np.array([0.0888, 0.0881, 0.0666, 0.0479, 0.0292, 0.0195, 0.0201, 0.0163, 0.0254])
+        expected_slice = np.array([0.0146, 0.0158, 0.0092, 0.0086, 0.0000, 0.0000, 0.0000, 0.0000, 0.0058])
 
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
     def test_consistency_model_cd_onestep(self):
         unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2")
@@ -239,9 +239,9 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
 
         image_slice = image[0, -3:, -3:, -1]
 
-        expected_slice = np.array([0.0340, 0.0152, 0.0063, 0.0267, 0.0221, 0.0107, 0.0416, 0.0186, 0.0217])
+        expected_slice = np.array([0.0059, 0.0003, 0.0000, 0.0023, 0.0052, 0.0007, 0.0165, 0.0081, 0.0095])
 
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
     @require_torch_2
     def test_consistency_model_cd_multistep_flash_attn(self):
@@ -263,7 +263,7 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
 
         image_slice = image[0, -3:, -3:, -1]
 
-        expected_slice = np.array([0.1875, 0.1428, 0.1289, 0.2151, 0.2092, 0.1477, 0.1877, 0.1641, 0.1353])
+        expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
@@ -289,6 +289,6 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
 
         image_slice = image[0, -3:, -3:, -1]
 
-        expected_slice = np.array([0.1663, 0.1948, 0.2275, 0.1680, 0.1204, 0.1245, 0.1858, 0.1338, 0.2095])
+        expected_slice = np.array([0.1623, 0.2009, 0.2387, 0.1731, 0.1168, 0.1202, 0.2031, 0.1327, 0.2447])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
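
Note (not part of the patch): a minimal sketch of how the new attn_norm_num_groups argument can be exercised after this change, mirroring the added test_mid_block_attn_groups test. The block types, channel sizes, and group counts below are illustrative choices for a small model, not values taken from this PR.

    import torch
    from diffusers import UNet2DModel

    # With resnet_time_scale_shift="scale_shift", the mid block's Attention previously
    # received no GroupNorm; attn_norm_num_groups now controls it explicitly.
    model = UNet2DModel(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        resnet_time_scale_shift="scale_shift",
        norm_num_groups=16,
        attn_norm_num_groups=8,
    )

    # The mid block attention now carries a GroupNorm with the requested group count.
    assert model.mid_block.attentions[0].group_norm is not None
    assert model.mid_block.attentions[0].group_norm.num_groups == 8

    # Forward pass sanity check: output shape matches the input sample shape.
    with torch.no_grad():
        sample = torch.randn(1, 3, 32, 32)
        out = model(sample, timestep=10).sample
    assert out.shape == sample.shape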