From 68f771bf43cc4732ddbb714341242f2ac37ce983 Mon Sep 17 00:00:00 2001
From: Edna <88869424+Ednaordinary@users.noreply.github.com>
Date: Wed, 11 Jun 2025 19:38:38 -0600
Subject: [PATCH] take pooled projections out of transformer

---
 src/diffusers/models/transformers/transformer_chroma.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index 7b46ef9c43..72cde1f60b 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -236,8 +236,6 @@ class ChromaTransformer2DModel(
         joint_attention_dim (`int`, defaults to `4096`):
             The number of dimensions to use for the joint attention (embedding/channel dimension of
             `encoder_hidden_states`).
-        pooled_projection_dim (`int`, defaults to `768`):
-            The number of dimensions to use for the pooled projection.
         guidance_embeds (`bool`, defaults to `False`):
             Whether to use guidance embeddings for guidance-distilled variant of the model.
         axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
@@ -259,7 +257,6 @@ class ChromaTransformer2DModel(
         attention_head_dim: int = 128,
         num_attention_heads: int = 24,
         joint_attention_dim: int = 4096,
-        pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
         approximator_in_factor: int = 16,
@@ -416,7 +413,6 @@ class ChromaTransformer2DModel(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor = None,
-        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         img_ids: torch.Tensor = None,
         txt_ids: torch.Tensor = None,
@@ -435,8 +431,6 @@ class ChromaTransformer2DModel(
                 Input `hidden_states`.
             encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
             block_controlnet_hidden_states: (`list` of `torch.Tensor`):
@@ -474,7 +468,7 @@ class ChromaTransformer2DModel(
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
 
-        input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
+        input_vec = self.time_text_embed(timestep, guidance)
         pooled_temb = self.distilled_guidance_layer(input_vec)
 
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
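
Usage note (not part of the patch): after this change, `ChromaTransformer2DModel.forward`
no longer accepts `pooled_projections`; the modulation input is built from timestep and
guidance alone via `self.time_text_embed(timestep, guidance)`. Below is a minimal
smoke-test sketch of the new call signature. The default-constructed model, the tensor
shapes (Flux-style packed latents with 64 channels), the `(seq_len, 3)` RoPE id layout,
and the `out.sample` access are illustrative assumptions, not values taken from the diff.

    import torch

    from diffusers.models.transformers.transformer_chroma import ChromaTransformer2DModel

    # Build the model with its default config. This instantiates the full-size
    # transformer; the point here is only to exercise the post-patch signature.
    model = ChromaTransformer2DModel()

    batch, img_seq, txt_seq = 1, 1024, 512
    hidden_states = torch.randn(batch, img_seq, 64)            # assumed packed-latent channel count
    encoder_hidden_states = torch.randn(batch, txt_seq, 4096)  # joint_attention_dim from the config
    timestep = torch.tensor([500])
    img_ids = torch.zeros(img_seq, 3)  # assumed (seq_len, 3) positional id layout
    txt_ids = torch.zeros(txt_seq, 3)

    # No `pooled_projections` keyword anymore: the timestep/guidance vector alone
    # feeds self.time_text_embed -> self.distilled_guidance_layer.
    out = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        timestep=timestep,
        img_ids=img_ids,
        txt_ids=txt_ids,
    )
    print(out.sample.shape)  # assumed Transformer2DModelOutput with a `sample` field

Passing `pooled_projections` after this patch should raise a TypeError for an
unexpected keyword argument, which is the intended break: callers (e.g. the Chroma
pipeline) are expected to stop supplying it.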