mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
take pooled projections out of transformer
This commit is contained in:
@@ -236,8 +236,6 @@ class ChromaTransformer2DModel(
         joint_attention_dim (`int`, defaults to `4096`):
             The number of dimensions to use for the joint attention (embedding/channel dimension of
             `encoder_hidden_states`).
-        pooled_projection_dim (`int`, defaults to `768`):
-            The number of dimensions to use for the pooled projection.
         guidance_embeds (`bool`, defaults to `False`):
             Whether to use guidance embeddings for guidance-distilled variant of the model.
         axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
@@ -259,7 +257,6 @@ class ChromaTransformer2DModel(
         attention_head_dim: int = 128,
         num_attention_heads: int = 24,
         joint_attention_dim: int = 4096,
-        pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
         approximator_in_factor: int = 16,
@@ -416,7 +413,6 @@ class ChromaTransformer2DModel(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor = None,
-        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         img_ids: torch.Tensor = None,
         txt_ids: torch.Tensor = None,
@@ -435,8 +431,6 @@ class ChromaTransformer2DModel(
                 Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
             block_controlnet_hidden_states: (`list` of `torch.Tensor`):
@@ -474,7 +468,7 @@ class ChromaTransformer2DModel(
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000

-        input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
+        input_vec = self.time_text_embed(timestep, guidance)
         pooled_temb = self.distilled_guidance_layer(input_vec)

         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
||||
Reference in New Issue
Block a user