mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add encoder_hid_dim to unet
`encoder_hid_dim` provides an additional projection for the input `encoder_hidden_states` from `encoder_hidden_dim` to `cross_attention_dim`
This commit is contained in:
committed by
Will Berman
parent
8db5e5b37d
commit
c413353e8e
@@ -88,6 +88,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
@@ -139,6 +141,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
@@ -224,6 +227,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
cond_proj_dim=time_cond_proj_dim,
|
||||
)
|
||||
|
||||
if encoder_hid_dim is not None:
|
||||
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
||||
else:
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
@@ -626,6 +634,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
else:
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.encoder_hid_proj is not None:
|
||||
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
|
||||
@@ -169,6 +169,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
|
||||
@@ -225,6 +227,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
@@ -316,6 +319,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
cond_proj_dim=time_cond_proj_dim,
|
||||
)
|
||||
|
||||
if encoder_hid_dim is not None:
|
||||
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
||||
else:
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
@@ -718,6 +726,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
else:
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.encoder_hid_proj is not None:
|
||||
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user