mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix Consistency Models UNet2DMidBlock2D Attention GroupNorm Bug (#4863)
* Add attn_groups argument to UNet2DMidBlock2D to control theinternal Attention block's GroupNorm. * Add docstring for attn_norm_num_groups in UNet2DModel. * Since the test UNet config uses resnet_time_scale_shift == 'scale_shift', also set attn_norm_num_groups to 32. * Add test for attn_norm_num_groups to UNet2DModelTests. * Fix expected slices for slow tests. * Also fix tolerances for slow tests. --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -27,6 +27,7 @@ TEST_UNET_CONFIG = {
|
||||
"ResnetUpsampleBlock2D",
|
||||
],
|
||||
"resnet_time_scale_shift": "scale_shift",
|
||||
"attn_norm_num_groups": 32,
|
||||
"upsample_type": "resnet",
|
||||
"downsample_type": "resnet",
|
||||
}
|
||||
@@ -52,6 +53,7 @@ IMAGENET_64_UNET_CONFIG = {
|
||||
"ResnetUpsampleBlock2D",
|
||||
],
|
||||
"resnet_time_scale_shift": "scale_shift",
|
||||
"attn_norm_num_groups": 32,
|
||||
"upsample_type": "resnet",
|
||||
"downsample_type": "resnet",
|
||||
}
|
||||
|
||||
@@ -74,6 +74,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
|
||||
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
|
||||
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
|
||||
given number of groups. If left as `None`, the group norm layer will only be created if
|
||||
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
|
||||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
||||
@@ -107,6 +111,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
act_fn: str = "silu",
|
||||
attention_head_dim: Optional[int] = 8,
|
||||
norm_num_groups: int = 32,
|
||||
attn_norm_num_groups: Optional[int] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
add_attention: bool = True,
|
||||
@@ -192,6 +197,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_groups=attn_norm_num_groups,
|
||||
add_attention=add_attention,
|
||||
)
|
||||
|
||||
|
||||
@@ -485,6 +485,7 @@ class UNetMidBlock2D(nn.Module):
|
||||
resnet_time_scale_shift: str = "default", # default, spatial
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
attn_groups: Optional[int] = None,
|
||||
resnet_pre_norm: bool = True,
|
||||
add_attention: bool = True,
|
||||
attention_head_dim=1,
|
||||
@@ -494,6 +495,9 @@ class UNetMidBlock2D(nn.Module):
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
self.add_attention = add_attention
|
||||
|
||||
if attn_groups is None:
|
||||
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock2D(
|
||||
@@ -526,7 +530,7 @@ class UNetMidBlock2D(nn.Module):
|
||||
dim_head=attention_head_dim,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
|
||||
norm_num_groups=attn_groups,
|
||||
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
|
||||
residual_connection=True,
|
||||
bias=True,
|
||||
|
||||
@@ -74,6 +74,36 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_mid_block_attn_groups(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["norm_num_groups"] = 16
|
||||
init_dict["add_attention"] = True
|
||||
init_dict["attn_norm_num_groups"] = 8
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
self.assertIsNotNone(
|
||||
model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
|
||||
)
|
||||
self.assertEqual(
|
||||
model.mid_block.attentions[0].group_norm.num_groups,
|
||||
init_dict["attn_norm_num_groups"],
|
||||
"Mid block Attention group norm does not have the expected number of groups.",
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
|
||||
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
|
||||
@@ -216,9 +216,9 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
expected_slice = np.array([0.0888, 0.0881, 0.0666, 0.0479, 0.0292, 0.0195, 0.0201, 0.0163, 0.0254])
|
||||
expected_slice = np.array([0.0146, 0.0158, 0.0092, 0.0086, 0.0000, 0.0000, 0.0000, 0.0000, 0.0058])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_consistency_model_cd_onestep(self):
|
||||
unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2")
|
||||
@@ -239,9 +239,9 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
expected_slice = np.array([0.0340, 0.0152, 0.0063, 0.0267, 0.0221, 0.0107, 0.0416, 0.0186, 0.0217])
|
||||
expected_slice = np.array([0.0059, 0.0003, 0.0000, 0.0023, 0.0052, 0.0007, 0.0165, 0.0081, 0.0095])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
@require_torch_2
|
||||
def test_consistency_model_cd_multistep_flash_attn(self):
|
||||
@@ -263,7 +263,7 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
expected_slice = np.array([0.1875, 0.1428, 0.1289, 0.2151, 0.2092, 0.1477, 0.1877, 0.1641, 0.1353])
|
||||
expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
@@ -289,6 +289,6 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
expected_slice = np.array([0.1663, 0.1948, 0.2275, 0.1680, 0.1204, 0.1245, 0.1858, 0.1338, 0.2095])
|
||||
expected_slice = np.array([0.1623, 0.2009, 0.2387, 0.1731, 0.1168, 0.1202, 0.2031, 0.1327, 0.2447])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
Reference in New Issue
Block a user