mirror of https://github.com/huggingface/diffusers.git
Update pipeline_glm_image.py
@@ -306,14 +306,14 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
     def generate_prior_tokens(
         self,
         prompt: str,
-        condition_images: Optional[List[PIL.Image.Image]] = None,
+        image: Optional[List[PIL.Image.Image]] = None,
     ) -> Tuple[torch.Tensor, int, int]:
         """
         Generate prior tokens using the AR (vision_language_encoder) model.
 
         Args:
             prompt: The text prompt with shape info (e.g., "description<sop>36 24<eop>")
-            condition_images: Optional list of condition images for i2i
+            image: Optional list of condition images for i2i
 
         Returns:
             Tuple of (prior_token_ids, pixel_height, pixel_width)
@@ -328,8 +328,8 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
 
         # Build messages for processor
         content = []
-        if condition_images is not None:
-            for img in condition_images:
+        if image is not None:
+            for img in image:
                 content.append({"type": "image", "image": img})
         content.append({"type": "text", "text": expanded_prompt})
         messages = [{"role": "user", "content": content}]
@@ -579,7 +579,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
     def __call__(
         self,
         prompt: Optional[Union[str, List[str]]] = None,
-        condition_images: Optional[
+        image: Optional[
             Union[
                 torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray]
             ]
@@ -612,7 +612,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 The prompt or prompts to guide the image generation. Must contain shape info in the format '<sop>H
                 W<eop>' where H and W are token dimensions (d32). Example: "A beautiful sunset<sop>36 24<eop>"
                 generates a 1152x768 image.
-            condition_images: Optional condition images for image-to-image generation.
+            image: Optional condition images for image-to-image generation.
             height (`int`, *optional*):
                 The height in pixels. If not provided, derived from prompt shape info.
             width (`int`, *optional*):
@@ -661,11 +661,11 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         device = self._execution_device
 
         ar_condition_images = None
-        if condition_images is not None:
-            if not isinstance(condition_images, list):
-                condition_images = [condition_images]
+        if image is not None:
+            if not isinstance(image, list):
+                image = [image]
             ar_condition_images = []
-            for img in condition_images:
+            for img in image:
                 if isinstance(img, PIL.Image.Image):
                     ar_condition_images.append(img)
                 elif isinstance(img, torch.Tensor):
@@ -682,7 +682,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
 
         prior_token_id, ar_height, ar_width = self.generate_prior_tokens(
             prompt=prompt[0] if isinstance(prompt, list) else prompt,
-            condition_images=ar_condition_images,
+            image=ar_condition_images,
         )
 
         height = height or ar_height
@@ -701,19 +701,19 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
 
         # 4. process images
        condition_images_prior_token_id = None
-        if condition_images is not None:
+        if image is not None:
             preprocessed_condition_images = []
             condition_images_prior_token_id = []
-            for img in condition_images:
+            for img in image:
                 image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
                 multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
                 image_height = (image_height // multiple_of) * multiple_of
                 image_width = (image_width // multiple_of) * multiple_of
                 img = self.image_processor.preprocess(img, height=image_height, width=image_width)
                 preprocessed_condition_images.append(img)
-            condition_images = preprocessed_condition_images
+            image = preprocessed_condition_images
 
-        # 5. Prepare latents and (optional) condition_images kv cache
+        # 5. Prepare latents and (optional) image kv cache
         latent_channels = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size=batch_size * num_images_per_prompt,
@@ -726,7 +726,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
             latents=latents,
         )
 
-        if condition_images is not None and condition_images_prior_token_id is not None:
+        if image is not None and condition_images_prior_token_id is not None:
             self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV)
             latents_mean = (
                 torch.tensor(self.vae.config.latents_mean)
@@ -740,7 +740,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
             )
             empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...]
             for condition_image, condition_image_prior_token_id in zip(
-                condition_images, condition_images_prior_token_id
+                image, condition_images_prior_token_id
             ):
                 condition_image = condition_image.to(device=device, dtype=self.vae.dtype)
                 condition_latent = retrieve_latents(
@@ -804,7 +804,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
 
                 timestep = t.expand(latents.shape[0]) - 1
 
-                if condition_images is not None:
+                if image is not None:
                     self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV)
 
                 noise_pred_cond = self.transformer(
@@ -821,7 +821,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
 
                 # perform guidance
                 if self.do_classifier_free_guidance:
-                    if condition_images is not None:
+                    if image is not None:
                         self.transformer.set_attention_processors_state(
                             GlmImageAttenProcessorState.ImageEditDontReadKV
                         )
@@ -874,16 +874,16 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 .to(latents.device, latents.dtype)
             )
             latents = latents * latents_std + latents_mean
-            condition_images = self.vae.decode(latents, return_dict=False, generator=generator)[0]
+            image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
         else:
-            condition_images = latents
+            image = latents
 
-        condition_images = self.image_processor.postprocess(condition_images, output_type=output_type)
+        image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload all models
         self.maybe_free_model_hooks()
 
         if not return_dict:
-            return (condition_images,)
+            return (image,)
 
-        return GlmImagePipelineOutput(images=condition_images)
+        return GlmImagePipelineOutput(images=image)
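
The docstring in this diff ties prompt shape info to output size via a d32 factor: 36 x 24 tokens yields a 1152 x 768 image. A minimal sketch of that arithmetic; the helper name and the <sop>/<eop> parsing below are illustrative, not part of this commit:

import re

def shape_info_to_pixels(prompt: str, d: int = 32) -> tuple:
    # Parse the '<sop>H W<eop>' token dimensions embedded in the prompt.
    match = re.search(r"<sop>(\d+) (\d+)<eop>", prompt)
    if match is None:
        raise ValueError("prompt carries no shape info")
    h_tokens, w_tokens = int(match.group(1)), int(match.group(2))
    # Each token spans d pixels per side (d32), so 36 x 24 -> 1152 x 768.
    return h_tokens * d, w_tokens * d

print(shape_info_to_pixels("A beautiful sunset<sop>36 24<eop>"))  # (1152, 768)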
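
After this change, callers pass condition images through the image argument rather than condition_images. A minimal usage sketch, assuming GlmImagePipeline is importable from diffusers and with "org/glm-image" as a placeholder checkpoint id (not one named by this commit):

import torch
from diffusers import GlmImagePipeline  # assumed top-level export

# Placeholder checkpoint id for illustration only.
pipe = GlmImagePipeline.from_pretrained("org/glm-image", torch_dtype=torch.bfloat16).to("cuda")

# Text-to-image: shape info travels inside the prompt.
out = pipe(prompt="A beautiful sunset<sop>36 24<eop>")

# Image-to-image: condition images now go through `image` (formerly `condition_images`).
edited = pipe(prompt="Make it snowy<sop>36 24<eop>", image=[out.images[0]])
edited.images[0].save("snowy_sunset.png")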