From adcc53206bcb75b045daeed889ded027c639f1e5 Mon Sep 17 00:00:00 2001
From: zRzRzRzRzRzRzR <2448370773@qq.com>
Date: Wed, 7 Jan 2026 17:07:13 +0800
Subject: [PATCH] 2

Rename the text projector to glyph_projector (dropping its explicit
inner_dim override), remove the unused einops import, and correct the
documented text_embed_dim default from 4096 to 1472.
---
 .../models/transformers/transformer_glm_image.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py
index 4c296b48aa..6eb9b5e803 100644
--- a/src/diffusers/models/transformers/transformer_glm_image.py
+++ b/src/diffusers/models/transformers/transformer_glm_image.py
@@ -18,7 +18,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
@@ -376,7 +375,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
             The number of heads to use for multi-head attention.
         out_channels (`int`, defaults to `16`):
             The number of channels in the output.
-        text_embed_dim (`int`, defaults to `4096`):
+        text_embed_dim (`int`, defaults to `1472`):
             Input dimension of text embeddings from the text encoder.
         time_embed_dim (`int`, defaults to `512`):
             Output dimension of timestep embeddings.
@@ -428,8 +427,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
 
         # 2. Patch & Text-timestep embedding
         self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
-        self.text_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
-
+        self.glyph_projector = FeedForward(text_embed_dim, inner_dim, activation_fn="gelu")
         self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
         self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")
 
@@ -482,7 +480,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         post_patch_width = width // p
 
         hidden_states = self.image_projector(hidden_states)
-        encoder_hidden_states = self.text_projector(encoder_hidden_states)
+        encoder_hidden_states = self.glyph_projector(encoder_hidden_states)
         prior_embedding = self.prior_token_embedding(prior_token_id)
         prior_embedding[prior_token_drop] *= 0.0
         prior_hidden_states = self.prior_projector(prior_embedding)
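
Note for reviewers (not part of the patch): a minimal sketch of the renamed
projector's post-patch behavior, using diffusers' FeedForward from
diffusers.models.attention. The inner_dim value of 4096 is hypothetical, for
illustration only; the model derives it from its attention config. The main
behavioral point: with the inner_dim= override removed, FeedForward falls back
to its default hidden width of dim * mult (4 * 1472 = 5888 here).

import torch
from diffusers.models.attention import FeedForward

text_embed_dim = 1472  # documented default after this patch
inner_dim = 4096       # hypothetical transformer width, for illustration only

# Post-patch construction: no explicit inner_dim override, so FeedForward
# uses its default hidden width of dim * mult (4 * 1472 = 5888).
glyph_projector = FeedForward(text_embed_dim, inner_dim, activation_fn="gelu")

encoder_hidden_states = torch.randn(2, 77, text_embed_dim)  # (batch, tokens, dim)
out = glyph_projector(encoder_hidden_states)
print(out.shape)  # torch.Size([2, 77, 4096])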