From 1cf277d36d55e3942e34726b5f3e982006d2a615 Mon Sep 17 00:00:00 2001
From: zRzRzRzRzRzRzR <2448370773@qq.com>
Date: Wed, 7 Jan 2026 23:51:40 +0800
Subject: [PATCH] remove sop

---
 .../pipelines/glm_image/pipeline_glm_image.py | 84 +++++++++++++------
 1 file changed, 58 insertions(+), 26 deletions(-)

diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
index 7286f9ab27..e4dd7de809 100644
--- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
+++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
@@ -290,64 +290,96 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         Returns:
             Upsampled token IDs of shape [1, N*4] where N*4 = (token_h*2) * (token_w*2)
         """
-        # Reshape to spatial format: [1, 1, H, W]
         token_ids = token_ids.view(1, 1, token_h, token_w)
-
-        # 2x nearest-neighbor upsampling
         token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to(
             dtype=torch.long
         )
-        # Flatten back to [1, H*W*4]
         token_ids = token_ids.view(1, -1)
-
         return token_ids
 
+    def _build_prompt_with_shape(
+        self,
+        prompt: str,
+        height: int,
+        width: int,
+        is_text_to_image: bool,
+        factor: int = 32,
+    ) -> Tuple[str, int, int, int, int]:
+        """
+        Build prompt with shape info (H W) based on height and width.
+
+        Args:
+            prompt: The raw text prompt without shape info
+            height: Target image height in pixels
+            width: Target image width in pixels
+            is_text_to_image: Whether this is text-to-image (True) or image-to-image (False)
+
+        Returns:
+            Tuple of (expanded_prompt, token_h, token_w, prev_token_h, prev_token_w)
+        """
+        token_h = height // factor
+        token_w = width // factor
+        ratio = token_h / token_w
+        prev_token_h = int(sqrt(ratio) * (factor // 2))
+        prev_token_w = int(sqrt(1 / ratio) * (factor // 2))
+
+        if is_text_to_image:
+            expanded_prompt = f"{prompt}{token_h} {token_w}{prev_token_h} {prev_token_w}"
+        else:
+            expanded_prompt = f"{prompt}{token_h} {token_w}"
+
+        return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w
+
     def generate_prior_tokens(
         self,
         prompt: str,
+        height: int,
+        width: int,
         image: Optional[List[PIL.Image.Image]] = None,
+        factor: int = 32,
     ) -> Tuple[torch.Tensor, int, int]:
         """
         Generate prior tokens using the AR (vision_language_encoder) model.
 
+        Automatically builds the prompt with shape info based on height/width. Users only need to provide the raw text
+        prompt without ... tags.
+
         Args:
-            prompt: The text prompt with shape info (e.g., "description36 24")
-            image: Optional list of condition images for i2i
+            prompt: The raw text prompt (without shape info)
+            height: Target image height in pixels (must be divisible by 32)
+            width: Target image width in pixels (must be divisible by 32)
+            image: Optional list of condition images for image-to-image generation
 
         Returns:
             Tuple of (prior_token_ids, pixel_height, pixel_width)
             - prior_token_ids: Upsampled to d16 format, shape [1, token_h*token_w*4]
-            - pixel_height: Image height in pixels
-            - pixel_width: Image width in pixels
+            - pixel_height: Image height in pixels (aligned to 32)
+            - pixel_width: Image width in pixels (aligned to 32)
+
         """
        device = self.vision_language_encoder.device
-
-        # Parse and expand shape info
-        expanded_prompt, token_h, token_w, prev_h, prev_w = self._parse_and_expand_shape_info(prompt)
-
-        # Build messages for processor
+        height = (height // factor) * factor
+        width = (width // factor) * factor
+        is_text_to_image = image is None or len(image) == 0
+        expanded_prompt, token_h, token_w, prev_h, prev_w = self._build_prompt_with_shape(
+            prompt, height, width, is_text_to_image
+        )
         content = []
         if image is not None:
             for img in image:
                 content.append({"type": "image", "image": img})
         content.append({"type": "text", "text": expanded_prompt})
         messages = [{"role": "user", "content": content}]
-
-        # Process inputs
         inputs = self.processor.apply_chat_template(
             messages,
             add_generation_prompt=True,
             tokenize=True,
             return_dict=True,
-            return_tensors="pt",
+            return_tensors="pt"
         )
 
-        # Determine if text-to-image or image-to-image
         existing_grid = inputs.get("image_grid_thw")
-        is_text_to_image = existing_grid is None or existing_grid.numel() == 0
-
-        # Build image grid
         inputs["image_grid_thw"] = self._build_image_grid_thw(
             token_h,
             token_w,
@@ -378,8 +410,8 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         )
         prior_token_ids = self._upsample_d32_to_d16(prior_token_ids_d32, token_h, token_w)
 
-        pixel_height = token_h * 32
-        pixel_width = token_w * 32
+        pixel_height = token_h * factor
+        pixel_width = token_w * factor
 
         return prior_token_ids, pixel_height, pixel_width
 
@@ -683,6 +715,8 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
             prior_token_id, ar_height, ar_width = self.generate_prior_tokens(
                 prompt=prompt[0] if isinstance(prompt, list) else prompt,
                 image=ar_condition_images,
+                height=height,
+                width=width,
             )
 
             height = height or ar_height
@@ -739,9 +773,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 .to(self.vae.device, self.vae.dtype)
             )
             empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...]
-            for condition_image, condition_image_prior_token_id in zip(
-                image, condition_images_prior_token_id
-            ):
+            for condition_image, condition_image_prior_token_id in zip(image, condition_images_prior_token_id):
                 condition_image = condition_image.to(device=device, dtype=self.vae.dtype)
                 condition_latent = retrieve_latents(
                     self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
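For reference, a minimal usage sketch of the call pattern after this patch: the caller passes the raw prompt plus an explicit height/width, and the pipeline builds the shape-annotated prompt internally via _build_prompt_with_shape. The checkpoint id, device, and sizes below are placeholders for illustration and are not part of the patch; the import assumes the pipeline class is exported from diffusers as GlmImagePipeline.

    import torch
    from diffusers import GlmImagePipeline  # assumes the pipeline is exported under this name

    # Placeholder repo id; substitute a real GLM-Image checkpoint.
    pipe = GlmImagePipeline.from_pretrained("<glm-image-checkpoint>", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # Raw prompt only; no shape tags need to be appended by the user anymore.
    prior_token_ids, pixel_h, pixel_w = pipe.generate_prior_tokens(
        prompt="a red fox standing in fresh snow",
        height=1024,  # rounded down to a multiple of 32 inside generate_prior_tokens
        width=768,
    )
    # Per the docstring, prior_token_ids has shape [1, (height // 32) * (width // 32) * 4].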