diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 0508cc6f2e..72fc0b519e 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -88,6 +88,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
+        encoder_hid_dim (`int`, *optional*, defaults to None):
+            If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for resnet
             blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
@@ -139,6 +141,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        encoder_hid_dim: Optional[int] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
@@ -224,6 +227,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             cond_proj_dim=time_cond_proj_dim,
         )
 
+        if encoder_hid_dim is not None:
+            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        else:
+            self.encoder_hid_proj = None
+
         # class embedding
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@@ -626,6 +634,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             else:
                 emb = emb + class_emb
 
+        if self.encoder_hid_proj is not None:
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+
         # 2. pre-process
         sample = self.conv_in(sample)
 
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 1427f23636..7d68f6f06e 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -169,6 +169,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
+        encoder_hid_dim (`int`, *optional*, defaults to None):
+            If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for resnet
             blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
@@ -225,6 +227,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        encoder_hid_dim: Optional[int] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
@@ -316,6 +319,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             cond_proj_dim=time_cond_proj_dim,
         )
 
+        if encoder_hid_dim is not None:
+            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        else:
+            self.encoder_hid_proj = None
+
         # class embedding
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@@ -718,6 +726,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             else:
                 emb = emb + class_emb
 
+        if self.encoder_hid_proj is not None:
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+
         # 2. pre-process
         sample = self.conv_in(sample)
 
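For reference, the sketch below shows how the new `encoder_hid_dim` argument is expected to be used: the UNet's cross-attention width differs from the text encoder's hidden size, and `encoder_hidden_states` are passed in at the encoder's width and projected internally by `encoder_hid_proj` before reaching the cross-attention blocks. The tiny block/channel configuration and the 4096-dim encoder width are illustrative assumptions, not part of this diff.

# Usage sketch (illustrative, not part of the diff); config values are assumptions.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    layers_per_block=1,
    attention_head_dim=8,
    cross_attention_dim=64,  # width expected by the cross-attention layers
    encoder_hid_dim=4096,    # e.g. a large text encoder's hidden size (assumed value)
)

sample = torch.randn(1, 4, 32, 32)                # (batch, in_channels, height, width)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 4096)  # matches encoder_hid_dim, not cross_attention_dim

# encoder_hidden_states are projected 4096 -> 64 by encoder_hid_proj inside forward.
out = unet(sample, timestep, encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])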