
Fix Consistency Models UNet2DMidBlock2D Attention GroupNorm Bug (#4863)

* Add attn_groups argument to UNet2DMidBlock2D to control the internal Attention block's GroupNorm (illustrated in the sketch after the commit metadata below).

* Add docstring for attn_norm_num_groups in UNet2DModel.

* Since the test UNet config uses resnet_time_scale_shift == 'scale_shift', also set attn_norm_num_groups to 32.

* Add test for attn_norm_num_groups to UNet2DModelTests.

* Fix expected slices for slow tests.

* Also fix tolerances for slow tests.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: dg845
Date: 2023-09-15 03:27:51 -07:00
Committed by: GitHub
Commit: 4c8a05f115
Parent: 5fd42e5d61

5 changed files with 49 additions and 7 deletions
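
A minimal, hedged sketch of what the fix means in practice: it loads the same checkpoint the slow tests below use and inspects the mid block's Attention. Whether the Hub config for this checkpoint already carries attn_norm_num_groups is an assumption here, so the comments describe the expected state after the fix, not a guaranteed output.

from diffusers import UNet2DModel

# Checkpoint and subfolder are taken from the slow tests further down.
unet = UNet2DModel.from_pretrained(
    "diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2"
)

# Consistency models use "scale_shift"; before this fix that setting alone decided
# whether the mid block Attention got a GroupNorm (it did not).
print(unet.config.resnet_time_scale_shift)

# With attn_norm_num_groups=32 in the config, this is a GroupNorm module
# rather than None, matching the original consistency-model weights.
print(unet.mid_block.attentions[0].group_norm)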


@@ -27,6 +27,7 @@ TEST_UNET_CONFIG = {
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "scale_shift",
"attn_norm_num_groups": 32,
"upsample_type": "resnet",
"downsample_type": "resnet",
}
@@ -52,6 +53,7 @@ IMAGENET_64_UNET_CONFIG = {
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "scale_shift",
"attn_norm_num_groups": 32,
"upsample_type": "resnet",
"downsample_type": "resnet",
}


@@ -74,6 +74,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
+        attn_norm_num_groups (`int`, *optional*, defaults to `None`):
+            If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
+            given number of groups. If left as `None`, the group norm layer will only be created if
+            `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
@@ -107,6 +111,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
+        attn_norm_num_groups: Optional[int] = None,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
@@ -192,6 +197,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
            resnet_groups=norm_num_groups,
+            attn_groups=attn_norm_num_groups,
            add_attention=add_attention,
        )

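To illustrate the new argument, here is a hedged construction sketch; the block types and channel counts are illustrative values loosely modeled on the test configs above, not an exact configuration from the repository.

from diffusers import UNet2DModel

# "scale_shift" alone would previously leave the mid block Attention without a
# GroupNorm; attn_norm_num_groups=8 re-enables it with 8 groups.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
    up_block_types=("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
    resnet_time_scale_shift="scale_shift",
    norm_num_groups=16,
    attn_norm_num_groups=8,
)

assert model.mid_block.attentions[0].group_norm.num_groups == 8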

@@ -485,6 +485,7 @@ class UNetMidBlock2D(nn.Module):
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
+        attn_groups: Optional[int] = None,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim=1,
@@ -494,6 +495,9 @@ class UNetMidBlock2D(nn.Module):
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

+        if attn_groups is None:
+            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
@@ -526,7 +530,7 @@ class UNetMidBlock2D(nn.Module):
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
-                        norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
+                        norm_num_groups=attn_groups,
                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                        residual_connection=True,
                        bias=True,

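A hedged sketch of the fallback above, constructing the mid block directly; the import path and channel sizes are assumptions for illustration, not taken from the commit.

from diffusers.models.unet_2d_blocks import UNetMidBlock2D

# attn_groups left as None: with "scale_shift" the Attention layer is still
# built without a GroupNorm, so existing configs keep their old behavior.
legacy_block = UNetMidBlock2D(in_channels=64, temb_channels=128, resnet_time_scale_shift="scale_shift")
assert legacy_block.attentions[0].group_norm is None

# attn_groups set explicitly: the GroupNorm is created regardless of
# resnet_time_scale_shift, which is what the consistency-model weights expect.
fixed_block = UNetMidBlock2D(
    in_channels=64, temb_channels=128, resnet_time_scale_shift="scale_shift", attn_groups=32
)
assert fixed_block.attentions[0].group_norm.num_groups == 32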

@@ -74,6 +74,36 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
+
+    def test_mid_block_attn_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["norm_num_groups"] = 16
+        init_dict["add_attention"] = True
+        init_dict["attn_norm_num_groups"] = 8
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        self.assertIsNotNone(
+            model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
+        )
+        self.assertEqual(
+            model.mid_block.attentions[0].group_norm.num_groups,
+            init_dict["attn_norm_num_groups"],
+            "Mid block Attention group norm does not have the expected number of groups.",
+        )
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")


class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DModel


@@ -216,9 +216,9 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
        image_slice = image[0, -3:, -3:, -1]

-        expected_slice = np.array([0.0888, 0.0881, 0.0666, 0.0479, 0.0292, 0.0195, 0.0201, 0.0163, 0.0254])
+        expected_slice = np.array([0.0146, 0.0158, 0.0092, 0.0086, 0.0000, 0.0000, 0.0000, 0.0000, 0.0058])

-        assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    def test_consistency_model_cd_onestep(self):
        unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2")
@@ -239,9 +239,9 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
        image_slice = image[0, -3:, -3:, -1]

-        expected_slice = np.array([0.0340, 0.0152, 0.0063, 0.0267, 0.0221, 0.0107, 0.0416, 0.0186, 0.0217])
+        expected_slice = np.array([0.0059, 0.0003, 0.0000, 0.0023, 0.0052, 0.0007, 0.0165, 0.0081, 0.0095])

-        assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    @require_torch_2
    def test_consistency_model_cd_multistep_flash_attn(self):
@@ -263,7 +263,7 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
        image_slice = image[0, -3:, -3:, -1]

-        expected_slice = np.array([0.1875, 0.1428, 0.1289, 0.2151, 0.2092, 0.1477, 0.1877, 0.1641, 0.1353])
+        expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@@ -289,6 +289,6 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
        image_slice = image[0, -3:, -3:, -1]

-        expected_slice = np.array([0.1663, 0.1948, 0.2275, 0.1680, 0.1204, 0.1245, 0.1858, 0.1338, 0.2095])
+        expected_slice = np.array([0.1623, 0.2009, 0.2387, 0.1731, 0.1168, 0.1202, 0.2031, 0.1327, 0.2447])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3