diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index cb64bc61f3..92e9a6d4f5 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -717,7 +717,14 @@ class HunyuanDiTAttentionPool(nn.Module):
 
 
 class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
+    def __init__(
+        self,
+        embedding_dim,
+        pooled_projection_dim=1024,
+        seq_len=256,
+        cross_attention_dim=2048,
+        use_style_cond_and_image_meta_size=True,
+    ):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -726,9 +733,15 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
         self.pooler = HunyuanDiTAttentionPool(
             seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
         )
+
         # Here we use a default learned embedder layer for future extension.
-        self.style_embedder = nn.Embedding(1, embedding_dim)
-        extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+        if use_style_cond_and_image_meta_size:
+            self.style_embedder = nn.Embedding(1, embedding_dim)
+            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        else:
+            extra_in_dim = pooled_projection_dim
+
         self.extra_embedder = PixArtAlphaTextProjection(
             in_features=extra_in_dim,
             hidden_size=embedding_dim * 4,
@@ -743,16 +756,20 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
         # extra condition1: text
         pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)
 
-        # extra condition2: image meta size embdding
-        image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
-        image_meta_size = image_meta_size.to(dtype=hidden_dtype)
-        image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+        if self.use_style_cond_and_image_meta_size:
+            # extra condition2: image meta size embdding
+            image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
+            image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+            image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
 
-        # extra condition3: style embedding
-        style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+            # extra condition3: style embedding
+            style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+
+            # Concatenate all extra vectors
+            extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        else:
+            extra_cond = torch.cat([pooled_projections], dim=1)
 
-        # Concatenate all extra vectors
-        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
         conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]
 
         return conditioning
diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
index d67b35586a..8313ffd87a 100644
--- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py
+++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -249,6 +249,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             The length of the clip text embedding.
         text_len_t5 (`int`, *optional*):
             The length of the T5 text embedding.
+        use_style_cond_and_image_meta_size (`bool`, *optional*):
+            Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
     """
 
     @register_to_config
@@ -270,6 +272,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
         pooled_projection_dim: int = 1024,
         text_len: int = 77,
         text_len_t5: int = 256,
+        use_style_cond_and_image_meta_size: bool = True,
     ):
         super().__init__()
         self.out_channels = in_channels * 2 if learn_sigma else in_channels
@@ -301,6 +304,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             pooled_projection_dim=pooled_projection_dim,
             seq_len=text_len_t5,
             cross_attention_dim=cross_attention_dim_t5,
+            use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
         )
 
         # HunyuanDiT Blocks
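
For context, below is a minimal sketch of how the new `use_style_cond_and_image_meta_size` flag changes the conditioning path of `HunyuanCombinedTimestepTextSizeStyleEmbedding`. Only the keyword arguments and parameter names visible in the diff are taken from the source; the tensor sizes (batch of 2, `embedding_dim=1408`) are illustrative assumptions rather than a real HunyuanDiT config, and the forward call follows the module's existing signature.

```python
import torch

from diffusers.models.embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding

# Illustrative sizes only; real checkpoints define their own dimensions.
embedder = HunyuanCombinedTimestepTextSizeStyleEmbedding(
    embedding_dim=1408,
    pooled_projection_dim=1024,
    seq_len=256,
    cross_attention_dim=2048,
    use_style_cond_and_image_meta_size=False,  # v1.2-style path: no style / image-meta-size conditioning
)

batch = 2
timestep = torch.randint(0, 1000, (batch,))
encoder_hidden_states = torch.randn(batch, 256, 2048)  # pooled by the attention pool into (N, 1024)

# With the flag disabled, only the pooled text projection feeds extra_embedder,
# so image_meta_size and style can simply be None.
conditioning = embedder(
    timestep,
    encoder_hidden_states,
    image_meta_size=None,
    style=None,
    hidden_dtype=torch.float32,
)
print(conditioning.shape)  # torch.Size([2, 1408])
```

An actual v1.2 checkpoint would set the flag through the `HunyuanDiT2DModel` config (it is forwarded to the embedding module in the second hunk above); the snippet is only meant to shape-check the new code path.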