1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix Consistency Models UNet2DMidBlock2D Attention GroupNorm Bug (#4863)

* Add attn_groups argument to UNet2DMidBlock2D to control theinternal Attention block's GroupNorm.

* Add docstring for attn_norm_num_groups in UNet2DModel.

* Since the test UNet config uses resnet_time_scale_shift == 'scale_shift', also set attn_norm_num_groups to 32.

* Add test for attn_norm_num_groups to UNet2DModelTests.

* Fix expected slices for slow tests.

* Also fix tolerances for slow tests.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
dg845
2023-09-15 03:27:51 -07:00
committed by GitHub
parent 5fd42e5d61
commit 4c8a05f115
5 changed files with 49 additions and 7 deletions

View File

@@ -27,6 +27,7 @@ TEST_UNET_CONFIG = {
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "scale_shift",
"attn_norm_num_groups": 32,
"upsample_type": "resnet",
"downsample_type": "resnet",
}
@@ -52,6 +53,7 @@ IMAGENET_64_UNET_CONFIG = {
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "scale_shift",
"attn_norm_num_groups": 32,
"upsample_type": "resnet",
"downsample_type": "resnet",
}