
Update pipeline_glm_image.py

Author: zRzRzRzRzRzRzR
Date: 2026-01-07 22:55:16 +08:00
Parent: acd13d8769
Commit: e2b31f8b15


@@ -306,14 +306,14 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
     def generate_prior_tokens(
         self,
         prompt: str,
-        condition_images: Optional[List[PIL.Image.Image]] = None,
+        image: Optional[List[PIL.Image.Image]] = None,
     ) -> Tuple[torch.Tensor, int, int]:
         """
         Generate prior tokens using the AR (vision_language_encoder) model.

         Args:
             prompt: The text prompt with shape info (e.g., "description<sop>36 24<eop>")
-            condition_images: Optional list of condition images for i2i
+            image: Optional list of condition images for i2i

         Returns:
             Tuple of (prior_token_ids, pixel_height, pixel_width)
@@ -328,8 +328,8 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         # Build messages for processor
         content = []
-        if condition_images is not None:
-            for img in condition_images:
+        if image is not None:
+            for img in image:
                 content.append({"type": "image", "image": img})
         content.append({"type": "text", "text": expanded_prompt})
         messages = [{"role": "user", "content": content}]
@@ -579,7 +579,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
     def __call__(
         self,
         prompt: Optional[Union[str, List[str]]] = None,
-        condition_images: Optional[
+        image: Optional[
             Union[
                 torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray]
             ]
@@ -612,7 +612,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 The prompt or prompts to guide the image generation. Must contain shape info in the format '<sop>H
                 W<eop>' where H and W are token dimensions (d32). Example: "A beautiful sunset<sop>36 24<eop>"
                 generates a 1152x768 image.
-            condition_images: Optional condition images for image-to-image generation.
+            image: Optional condition images for image-to-image generation.
             height (`int`, *optional*):
                 The height in pixels. If not provided, derived from prompt shape info.
             width (`int`, *optional*):
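
For reference, the docstring's d32 convention means each shape token spans 32 pixels, so "<sop>36 24<eop>" yields a 1152x768 image. A minimal sketch of that mapping (shape_from_prompt is illustrative, not the pipeline's actual parser):

import re

def shape_from_prompt(prompt: str) -> tuple:
    # Pull "H W" out of the '<sop>H W<eop>' shape tokens described above.
    match = re.search(r"<sop>(\d+) (\d+)<eop>", prompt)
    if match is None:
        raise ValueError("prompt is missing '<sop>H W<eop>' shape info")
    h_tokens, w_tokens = int(match.group(1)), int(match.group(2))
    return h_tokens * 32, w_tokens * 32  # (height, width) in pixels

assert shape_from_prompt("A beautiful sunset<sop>36 24<eop>") == (1152, 768)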
@@ -661,11 +661,11 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         device = self._execution_device

         ar_condition_images = None
-        if condition_images is not None:
-            if not isinstance(condition_images, list):
-                condition_images = [condition_images]
+        if image is not None:
+            if not isinstance(image, list):
+                image = [image]
             ar_condition_images = []
-            for img in condition_images:
+            for img in image:
                 if isinstance(img, PIL.Image.Image):
                     ar_condition_images.append(img)
                 elif isinstance(img, torch.Tensor):
@@ -682,7 +682,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         prior_token_id, ar_height, ar_width = self.generate_prior_tokens(
             prompt=prompt[0] if isinstance(prompt, list) else prompt,
-            condition_images=ar_condition_images,
+            image=ar_condition_images,
         )

         height = height or ar_height
@@ -701,19 +701,19 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         # 4. process images
         condition_images_prior_token_id = None
-        if condition_images is not None:
+        if image is not None:
             preprocessed_condition_images = []
             condition_images_prior_token_id = []
-            for img in condition_images:
+            for img in image:
                 image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
                 multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
                 image_height = (image_height // multiple_of) * multiple_of
                 image_width = (image_width // multiple_of) * multiple_of
                 img = self.image_processor.preprocess(img, height=image_height, width=image_width)
                 preprocessed_condition_images.append(img)
-            condition_images = preprocessed_condition_images
+            image = preprocessed_condition_images

-        # 5. Prepare latents and (optional) condition_images kv cache
+        # 5. Prepare latents and (optional) image kv cache
         latent_channels = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size=batch_size * num_images_per_prompt,
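
The loop in this hunk snaps each condition image down to the nearest multiple of vae_scale_factor * patch_size so it tiles cleanly into latent patches. A worked sketch of that arithmetic with assumed values (vae_scale_factor=8 and patch_size=2 are illustrative, not read from a real config):

vae_scale_factor = 8
patch_size = 2
multiple_of = vae_scale_factor * patch_size  # 16

image_height, image_width = 1150, 770
image_height = (image_height // multiple_of) * multiple_of  # 71 * 16 = 1136
image_width = (image_width // multiple_of) * multiple_of    # 48 * 16 = 768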
@@ -726,7 +726,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
             latents=latents,
         )

-        if condition_images is not None and condition_images_prior_token_id is not None:
+        if image is not None and condition_images_prior_token_id is not None:
             self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV)
             latents_mean = (
                 torch.tensor(self.vae.config.latents_mean)
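
This branch is where editing runs encode the condition images once and cache their attention KV. A hedged sketch of the three processor states that appear across these hunks (the enum member names come from the diff; the definition here is a paraphrase, not the real class):

from enum import Enum, auto

class GlmImageAttenProcessorState(Enum):
    ImageEditWriteKV = auto()     # encode condition images once, caching their KV
    ImageEditReadKV = auto()      # conditional denoising passes attend to the cache
    ImageEditDontReadKV = auto()  # the unconditional (CFG) pass ignores the cache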
@@ -740,7 +740,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
             )
             empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...]
             for condition_image, condition_image_prior_token_id in zip(
-                condition_images, condition_images_prior_token_id
+                image, condition_images_prior_token_id
             ):
                 condition_image = condition_image.to(device=device, dtype=self.vae.dtype)
                 condition_latent = retrieve_latents(
@@ -804,7 +804,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 timestep = t.expand(latents.shape[0]) - 1

-                if condition_images is not None:
+                if image is not None:
                     self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV)

                 noise_pred_cond = self.transformer(
@@ -821,7 +821,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 # perform guidance
                 if self.do_classifier_free_guidance:
-                    if condition_images is not None:
+                    if image is not None:
                         self.transformer.set_attention_processors_state(
                             GlmImageAttenProcessorState.ImageEditDontReadKV
                         )
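
For the unconditional pass the cached condition KV is masked off, and the two predictions are then combined. The hunk cuts off before the combination itself; a sketch of the standard classifier-free guidance step used throughout diffusers (tensors here are placeholders):

import torch

guidance_scale = 5.0
noise_pred_cond = torch.randn(1, 16, 64, 64)    # placeholder shapes
noise_pred_uncond = torch.randn(1, 16, 64, 64)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)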
@@ -874,16 +874,16 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 .to(latents.device, latents.dtype)
             )
             latents = latents * latents_std + latents_mean
-            condition_images = self.vae.decode(latents, return_dict=False, generator=generator)[0]
+            image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
         else:
-            condition_images = latents
-        condition_images = self.image_processor.postprocess(condition_images, output_type=output_type)
+            image = latents
+        image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload all models
         self.maybe_free_model_hooks()

         if not return_dict:
-            return (condition_images,)
+            return (image,)

-        return GlmImagePipelineOutput(images=condition_images)
+        return GlmImagePipelineOutput(images=image)
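
Net effect of the commit: the public keyword for condition images is now image, matching the convention used by other diffusers image-to-image pipelines. A hedged usage sketch after this rename (the checkpoint path and input URL are placeholders):

import torch
from diffusers import GlmImagePipeline
from diffusers.utils import load_image

pipe = GlmImagePipeline.from_pretrained("path/to/glm-image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

condition = load_image("https://example.com/input.png")
result = pipe(
    prompt="A beautiful sunset<sop>36 24<eop>",  # 36*32 x 24*32 -> 1152x768
    image=[condition],                           # was condition_images before this commit
)
result.images[0].save("result.png")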