From ae84e405a3f54518e1c278080ef03733144cb16f Mon Sep 17 00:00:00 2001 From: Stephen Date: Mon, 26 Feb 2024 00:19:14 -0500 Subject: [PATCH] Pass use_linear_projection parameter to mid block in UNetMotionModel (#7035) * pass linear projection parameter to mid block * add cond_proj_dim to motion UNet * run style and quality checks --- src/diffusers/models/unets/unet_motion_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 9cb0f42c85..246a4b8124 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -217,6 +217,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): use_motion_mid_block: int = True, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, ): super().__init__() @@ -252,9 +253,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, + timestep_input_dim, time_embed_dim, act_fn=act_fn, cond_proj_dim=time_cond_proj_dim ) if encoder_hid_dim_type is None: @@ -306,6 +305,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, + use_linear_projection=use_linear_projection, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, ) @@ -321,6 +321,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, + use_linear_projection=use_linear_projection, ) # count how many layers upsample the images