mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix Consistency Models UNet2DMidBlock2D Attention GroupNorm Bug (#4863)
* Add attn_groups argument to UNet2DMidBlock2D to control theinternal Attention block's GroupNorm. * Add docstring for attn_norm_num_groups in UNet2DModel. * Since the test UNet config uses resnet_time_scale_shift == 'scale_shift', also set attn_norm_num_groups to 32. * Add test for attn_norm_num_groups to UNet2DModelTests. * Fix expected slices for slow tests. * Also fix tolerances for slow tests. --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -27,6 +27,7 @@ TEST_UNET_CONFIG = {
|
||||
"ResnetUpsampleBlock2D",
|
||||
],
|
||||
"resnet_time_scale_shift": "scale_shift",
|
||||
"attn_norm_num_groups": 32,
|
||||
"upsample_type": "resnet",
|
||||
"downsample_type": "resnet",
|
||||
}
|
||||
@@ -52,6 +53,7 @@ IMAGENET_64_UNET_CONFIG = {
|
||||
"ResnetUpsampleBlock2D",
|
||||
],
|
||||
"resnet_time_scale_shift": "scale_shift",
|
||||
"attn_norm_num_groups": 32,
|
||||
"upsample_type": "resnet",
|
||||
"downsample_type": "resnet",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user