mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[Tencent Hunyuan Team] Add HunyuanDiT-v1.2 Support (#8747)
* add v1.2 support --------- Co-authored-by: xingchaoliu <xingchaoliu@tencent.com> Co-authored-by: yiyixuxu <yixu310@gmail.com>
This commit is contained in:
@@ -717,7 +717,14 @@ class HunyuanDiTAttentionPool(nn.Module):
|
||||
|
||||
|
||||
class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim,
|
||||
pooled_projection_dim=1024,
|
||||
seq_len=256,
|
||||
cross_attention_dim=2048,
|
||||
use_style_cond_and_image_meta_size=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
@@ -726,9 +733,15 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
|
||||
self.pooler = HunyuanDiTAttentionPool(
|
||||
seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
|
||||
)
|
||||
|
||||
# Here we use a default learned embedder layer for future extension.
|
||||
self.style_embedder = nn.Embedding(1, embedding_dim)
|
||||
extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
|
||||
self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
|
||||
if use_style_cond_and_image_meta_size:
|
||||
self.style_embedder = nn.Embedding(1, embedding_dim)
|
||||
extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
|
||||
else:
|
||||
extra_in_dim = pooled_projection_dim
|
||||
|
||||
self.extra_embedder = PixArtAlphaTextProjection(
|
||||
in_features=extra_in_dim,
|
||||
hidden_size=embedding_dim * 4,
|
||||
@@ -743,16 +756,20 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
|
||||
# extra condition1: text
|
||||
pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
|
||||
|
||||
# extra condition2: image meta size embdding
|
||||
image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
|
||||
image_meta_size = image_meta_size.to(dtype=hidden_dtype)
|
||||
image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
|
||||
if self.use_style_cond_and_image_meta_size:
|
||||
# extra condition2: image meta size embdding
|
||||
image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
|
||||
image_meta_size = image_meta_size.to(dtype=hidden_dtype)
|
||||
image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
|
||||
|
||||
# extra condition3: style embedding
|
||||
style_embedding = self.style_embedder(style) # (N, embedding_dim)
|
||||
# extra condition3: style embedding
|
||||
style_embedding = self.style_embedder(style) # (N, embedding_dim)
|
||||
|
||||
# Concatenate all extra vectors
|
||||
extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
|
||||
else:
|
||||
extra_cond = torch.cat([pooled_projections], dim=1)
|
||||
|
||||
# Concatenate all extra vectors
|
||||
extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
|
||||
conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
|
||||
|
||||
return conditioning
|
||||
|
||||
@@ -249,6 +249,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
The length of the clip text embedding.
|
||||
text_len_t5 (`int`, *optional*):
|
||||
The length of the T5 text embedding.
|
||||
use_style_cond_and_image_meta_size (`bool`, *optional*):
|
||||
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -270,6 +272,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
pooled_projection_dim: int = 1024,
|
||||
text_len: int = 77,
|
||||
text_len_t5: int = 256,
|
||||
use_style_cond_and_image_meta_size: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
@@ -301,6 +304,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
seq_len=text_len_t5,
|
||||
cross_attention_dim=cross_attention_dim_t5,
|
||||
use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
|
||||
)
|
||||
|
||||
# HunyuanDiT Blocks
|
||||
|
||||
Reference in New Issue
Block a user