commit ec678a1fb7
parent adcc53206b
Author: zRzRzRzRzRzRzR
Date: 2026-01-07 17:16:39 +08:00

2 changed files with 4 additions and 16 deletions


@@ -427,7 +427,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
# 2. Patch & Text-timestep embedding
self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
- self.glyph_projector = FeedForward(text_embed_dim, inner_dim, activation_fn="gelu")
+ self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")


@@ -432,7 +432,6 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
output_type: str = "pil",
return_dict: bool = True,
@@ -496,11 +495,6 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
- original_size (`Tuple[int]`, *optional*, defaults to (2048, 2048)):
-     If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
-     `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
-     explained in section 2.2 of
-     [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
@@ -637,24 +631,20 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
condition_latent = (condition_latent - latents_mean) / latents_std
_ = self.transformer(
hidden_states=condition_latent,
- glyph_hidden_states=empty_glyph_hiddens,
+ encoder_hidden_states=empty_glyph_hiddens,
prior_token_id=condition_image_prior_token_id,
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
timestep=torch.zeros((1,), device=device),
- original_size=torch.tensor([condition_image.shape[-2:]], device=device),
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
crop_coords=torch.zeros((1, 2), device=device),
attention_kwargs=attention_kwargs,
)
# 6. Prepare additional timestep conditions
- original_size = original_size or (height, width)
target_size = (height, width)
- original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
- original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
@@ -702,11 +692,10 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
- glyph_hidden_states=prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
prior_token_id=prior_token_id,
prior_token_drop=prior_token_drop_cond,
timestep=timestep,
- original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
@@ -721,11 +710,10 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
)
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
- glyph_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
prior_token_id=prior_token_id,
prior_token_drop=prior_token_drop_uncond,
timestep=timestep,
- original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,