From 6f2b310a1765177f0d7e9b4b6c8bcfe7e5d3a8a8 Mon Sep 17 00:00:00 2001 From: Xiaodong Wang <44627574+Wang-Xiaodong1899@users.noreply.github.com> Date: Sat, 9 Mar 2024 12:29:06 +0800 Subject: [PATCH] [UNet_Spatio_Temporal_Condition] fix default num_attention_heads in unet_spatio_temporal_condition (#7205) fix default num_attention_heads in unet_spatio_temporal_condition Co-authored-by: Sayak Paul --- src/diffusers/models/unets/unet_spatio_temporal_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py index 39a8009d5a..5fe265e63f 100644 --- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -90,7 +90,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL layers_per_block: Union[int, Tuple[int]] = 2, cross_attention_dim: Union[int, Tuple[int]] = 1024, transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), + num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20), num_frames: int = 25, ): super().__init__()