
take pooled projections out of transformer

Author: Edna
Date: 2025-06-11 19:38:38 -06:00
Committed by: GitHub
Parent: df7fde7a6d
Commit: 68f771bf43


@@ -236,8 +236,6 @@ class ChromaTransformer2DModel(
         joint_attention_dim (`int`, defaults to `4096`):
             The number of dimensions to use for the joint attention (embedding/channel dimension of
             `encoder_hidden_states`).
-        pooled_projection_dim (`int`, defaults to `768`):
-            The number of dimensions to use for the pooled projection.
         guidance_embeds (`bool`, defaults to `False`):
             Whether to use guidance embeddings for guidance-distilled variant of the model.
         axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
@@ -259,7 +257,6 @@ class ChromaTransformer2DModel(
         attention_head_dim: int = 128,
         num_attention_heads: int = 24,
         joint_attention_dim: int = 4096,
-        pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
         approximator_in_factor: int = 16,
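
Note: with `pooled_projection_dim` gone from `__init__`, a model config serialized before this commit may still carry the key. Below is a minimal sketch of pruning it before re-instantiation, assuming a plain dict; the `old_config` values mirror the defaults above, and real loading in diffusers goes through `from_config`/`from_pretrained` rather than this hand-rolled filter.

# Illustrative only: a pre-commit config may still contain the removed key.
old_config = {
    "attention_head_dim": 128,
    "num_attention_heads": 24,
    "joint_attention_dim": 4096,
    "pooled_projection_dim": 768,  # no longer accepted by __init__
    "guidance_embeds": False,
    "axes_dims_rope": (16, 56, 56),
}
# Drop the stale key before passing the config on.
new_config = {k: v for k, v in old_config.items() if k != "pooled_projection_dim"}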
@@ -416,7 +413,6 @@ class ChromaTransformer2DModel(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor = None,
-        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         img_ids: torch.Tensor = None,
         txt_ids: torch.Tensor = None,
@@ -435,8 +431,6 @@ class ChromaTransformer2DModel(
                 Input `hidden_states`.
             encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
             block_controlnet_hidden_states: (`list` of `torch.Tensor`):
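
Callers that still pass `pooled_projections` to `forward` will now fail with a `TypeError` for an unexpected keyword. One generic way to bridge older pipeline code is to prune kwargs the callee no longer accepts; `prune_unaccepted_kwargs` below is a hypothetical helper, not part of diffusers, and the stand-in `forward` only mimics the shape of the new signature.

import inspect

def prune_unaccepted_kwargs(fn, kwargs: dict) -> dict:
    # Keep only keyword arguments that `fn` still declares; anything
    # removed upstream (here, `pooled_projections`) is silently dropped.
    accepted = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in accepted}

# Stand-in for the post-commit forward() signature:
def forward(hidden_states, encoder_hidden_states=None, timestep=None):
    return hidden_states

call_kwargs = {"hidden_states": 1.0, "timestep": 0, "pooled_projections": None}
result = forward(**prune_unaccepted_kwargs(forward, call_kwargs))

Note this filter would be too aggressive for callees that accept **kwargs; it is only meant to illustrate the migration.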
@@ -474,7 +468,7 @@ class ChromaTransformer2DModel(
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
-        input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
+        input_vec = self.time_text_embed(timestep, guidance)
         pooled_temb = self.distilled_guidance_layer(input_vec)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
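
After the change, `time_text_embed` is invoked with only `(timestep, guidance)` and its output still feeds `distilled_guidance_layer` unchanged. Below is a rough sketch of what a text-free embedder at that call site could look like; `TimestepGuidanceEmbed` and its sinusoidal helper are assumptions for illustration, not the module diffusers actually wires in.

import math
import torch
import torch.nn as nn

def sinusoidal_embedding(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Assumed helper in the spirit of diffusers' get_timestep_embedding.
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = x.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

class TimestepGuidanceEmbed(nn.Module):
    # Hypothetical stand-in for `time_text_embed` after this commit:
    # it consumes only timestep and guidance, with no pooled text projection.
    def __init__(self, embedding_dim: int = 256):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
        if guidance is None:  # guidance_embeds=False path
            guidance = torch.zeros_like(timestep)
        t_emb = sinusoidal_embedding(timestep, self.embedding_dim)
        g_emb = sinusoidal_embedding(guidance, self.embedding_dim)
        # The concatenated vector is what goes on to distilled_guidance_layer.
        return torch.cat([t_emb, g_emb], dim=-1)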