From 405a1facd20a673833d5972c6d0b36e538f411cd Mon Sep 17 00:00:00 2001
From: Yuanhao Zhai
Date: Wed, 20 Mar 2024 22:16:32 -0400
Subject: [PATCH] fix: enable unet_3d_condition to support time_cond_proj_dim
 (#7364)

Co-authored-by: Sayak Paul
---
 src/diffusers/models/unets/unet_3d_condition.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py
index b7641a96a7..a827b4ddc5 100644
--- a/src/diffusers/models/unets/unet_3d_condition.py
+++ b/src/diffusers/models/unets/unet_3d_condition.py
@@ -91,6 +91,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
         attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
         num_attention_heads (`int`, *optional*): The number of attention heads.
+        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+            The dimension of `cond_proj` layer in the timestep embedding.
     """

     _supports_gradient_checkpointing = False
@@ -123,6 +125,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         cross_attention_dim: int = 1024,
         attention_head_dim: Union[int, Tuple[int]] = 64,
         num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+        time_cond_proj_dim: Optional[int] = None,
     ):
         super().__init__()

@@ -174,6 +177,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             timestep_input_dim,
             time_embed_dim,
             act_fn=act_fn,
+            cond_proj_dim=time_cond_proj_dim,
         )

         self.transformer_in = TransformerTemporalModel(
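
Usage note (not part of the patch): a minimal sketch of how the new argument
can be exercised once the change lands. The `time_cond_proj_dim` constructor
argument and the `timestep_cond` forward argument are existing diffusers
names; the tiny block configuration and tensor shapes below are illustrative
values chosen only to keep the example light, not library defaults.

import torch

from diffusers import UNet3DConditionModel

# Tiny illustrative configuration; real checkpoints use much larger dims.
unet = UNet3DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
    layers_per_block=1,
    norm_num_groups=8,
    cross_attention_dim=64,
    attention_head_dim=8,
    time_cond_proj_dim=256,  # the new argument: enables cond_proj in TimestepEmbedding
)

batch, frames = 1, 4
sample = torch.randn(batch, 4, frames, 32, 32)     # (B, C, F, H, W)
encoder_hidden_states = torch.randn(batch, 8, 64)  # (B, seq, cross_attention_dim)

# timestep_cond is projected by the cond_proj layer and added to the
# timestep embedding; its last dim must equal time_cond_proj_dim.
timestep_cond = torch.randn(batch, 256)

out = unet(
    sample,
    timestep=10,
    encoder_hidden_states=encoder_hidden_states,
    timestep_cond=timestep_cond,
).sample
print(out.shape)  # torch.Size([1, 4, 4, 32, 32])

This mirrors how the 2D UNet already uses cond_proj, e.g. to inject a
guidance-scale embedding in latent-consistency-distilled models; before this
patch, passing timestep_cond to the 3D UNet had no projection layer to go
through.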