
add encoder_hid_dim to unet

`encoder_hid_dim` adds an optional linear projection that maps the input `encoder_hidden_states` from `encoder_hid_dim` to `cross_attention_dim`.
commit c413353e8e (parent 8db5e5b37d)
Author: William Berman
Date: 2023-04-08 22:27:30 -07:00
Committed by: Will Berman
2 changed files with 22 additions and 0 deletions
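
Before the diff, a minimal usage sketch (not part of the commit; the block layout and toy dimensions are chosen for illustration only): the UNet is built with a small `cross_attention_dim` and an `encoder_hid_dim` matching a hypothetical text-encoder width, and `encoder_hidden_states` are then passed in at `encoder_hid_dim` width rather than `cross_attention_dim`.

```python
import torch
from diffusers import UNet2DConditionModel

# Toy configuration, illustrative only; the real defaults build a full SD-sized UNet.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    encoder_hid_dim=1024,  # hypothetical text-encoder hidden size
)

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([10])
# With encoder_hid_dim set, encoder_hidden_states arrive at width 1024 and are
# projected to cross_attention_dim (32) by the new nn.Linear before cross attention.
encoder_hidden_states = torch.randn(1, 77, 1024)

out = unet(sample, timestep, encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])
```

Without the new argument, the same call would have to supply `encoder_hidden_states` with `cross_attention_dim` features directly.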


@@ -88,6 +88,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
+        encoder_hid_dim (`int`, *optional*, defaults to None):
+            If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
             for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
@@ -139,6 +141,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        encoder_hid_dim: Optional[int] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
@@ -224,6 +227,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             cond_proj_dim=time_cond_proj_dim,
         )
 
+        if encoder_hid_dim is not None:
+            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        else:
+            self.encoder_hid_proj = None
+
         # class embedding
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@@ -626,6 +634,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
else:
emb = emb + class_emb
if self.encoder_hid_proj is not None:
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
# 2. pre-process
sample = self.conv_in(sample)

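The forward-pass change above boils down to one linear map applied to `encoder_hidden_states` before any cross-attention layer sees it. A standalone sketch of that shape transformation, with hypothetical dimensions:

```python
import torch
from torch import nn

# Hypothetical dimensions: a 1024-wide text encoder feeding a UNet whose
# cross-attention layers expect 768-wide context.
encoder_hid_dim, cross_attention_dim = 1024, 768
encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)

encoder_hidden_states = torch.randn(2, 77, encoder_hid_dim)  # (batch, seq_len, encoder_hid_dim)
projected = encoder_hid_proj(encoder_hidden_states)          # (batch, seq_len, cross_attention_dim)
print(projected.shape)  # torch.Size([2, 77, 768])
```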

@@ -169,6 +169,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin)
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
+        encoder_hid_dim (`int`, *optional*, defaults to None):
+            If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
             for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
@@ -225,6 +227,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin)
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        encoder_hid_dim: Optional[int] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
@@ -316,6 +319,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin)
             cond_proj_dim=time_cond_proj_dim,
         )
 
+        if encoder_hid_dim is not None:
+            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        else:
+            self.encoder_hid_proj = None
+
         # class embedding
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@@ -718,6 +726,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin)
             else:
                 emb = emb + class_emb
 
+        if self.encoder_hid_proj is not None:
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+
         # 2. pre-process
         sample = self.conv_in(sample)
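
`UNetFlatConditionModel` receives the identical change, so both classes keep their previous behavior when `encoder_hid_dim` is left unset: `encoder_hid_proj` stays `None` and `encoder_hidden_states` must already have `cross_attention_dim` features. A sketch of that default path (same toy config as above, shown for `UNet2DConditionModel`):

```python
import torch
from diffusers import UNet2DConditionModel

# Default path: encoder_hid_dim is not passed, so no projection module is created.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
)
assert unet.encoder_hid_proj is None

# encoder_hidden_states must already be cross_attention_dim (32) wide.
out = unet(torch.randn(1, 4, 32, 32), 10, torch.randn(1, 77, 32)).sample
```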