mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
remove sop
@@ -290,64 +290,96 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
        Returns:
            Upsampled token IDs of shape [1, N*4] where N*4 = (token_h*2) * (token_w*2)
        """
        # Reshape to spatial format: [1, 1, H, W]
        token_ids = token_ids.view(1, 1, token_h, token_w)

        # 2x nearest-neighbor upsampling
        token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to(
            dtype=torch.long
        )

        # Flatten back to [1, H*W*4]
        token_ids = token_ids.view(1, -1)

        return token_ids
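
A minimal standalone sketch (not part of this diff) of what the 2x nearest-neighbor upsampling does to a toy d32 token grid; the values are illustrative only:

import torch

token_h, token_w = 2, 3
token_ids = torch.arange(token_h * token_w).view(1, 1, token_h, token_w)   # d32 grid: [1, 1, 2, 3]
token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to(dtype=torch.long)
token_ids = token_ids.view(1, -1)                                          # d16 tokens: [1, 2*3*4] == [1, 24]
# Each original token is repeated into a 2x2 block, so every d32 ID appears exactly four times.
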
    def _build_prompt_with_shape(
        self,
        prompt: str,
        height: int,
        width: int,
        is_text_to_image: bool,
        factor: int = 32,
    ) -> Tuple[str, int, int, int, int]:
        """
        Build prompt with shape info (<sop>H W<eop>) based on height and width.

        Args:
            prompt: The raw text prompt without shape info
            height: Target image height in pixels
            width: Target image width in pixels
            is_text_to_image: Whether this is text-to-image (True) or image-to-image (False)

        Returns:
            Tuple of (expanded_prompt, token_h, token_w, prev_token_h, prev_token_w)
        """
        token_h = height // factor
        token_w = width // factor
        ratio = token_h / token_w
        prev_token_h = int(sqrt(ratio) * (factor // 2))
        prev_token_w = int(sqrt(1 / ratio) * (factor // 2))

        if is_text_to_image:
            expanded_prompt = f"{prompt}<sop>{token_h} {token_w}<eop><sop>{prev_token_h} {prev_token_w}<eop>"
        else:
            expanded_prompt = f"{prompt}<sop>{token_h} {token_w}<eop>"

        return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w
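
A worked example of the shape arithmetic above for a 1024x768 text-to-image request (values computed from the formulas in the method, factor=32; the prompt text is invented):

from math import sqrt

height, width, factor = 1024, 768, 32
token_h, token_w = height // factor, width // factor            # 32, 24
ratio = token_h / token_w                                        # ~1.333
prev_token_h = int(sqrt(ratio) * (factor // 2))                  # int(1.1547 * 16) == 18
prev_token_w = int(sqrt(1 / ratio) * (factor // 2))              # int(0.8660 * 16) == 13
expanded_prompt = f"a photo of a cat<sop>{token_h} {token_w}<eop><sop>{prev_token_h} {prev_token_w}<eop>"
# -> "a photo of a cat<sop>32 24<eop><sop>18 13<eop>"
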
    def generate_prior_tokens(
        self,
        prompt: str,
        height: int,
        width: int,
        image: Optional[List[PIL.Image.Image]] = None,
        factor: int = 32,
    ) -> Tuple[torch.Tensor, int, int]:
        """
        Generate prior tokens using the AR (vision_language_encoder) model.

        Automatically builds the prompt with shape info based on height/width. Users only need to provide the raw text
        prompt without <sop>...<eop> tags.

        Args:
-            prompt: The text prompt with shape info (e.g., "description<sop>36 24<eop>")
-            image: Optional list of condition images for i2i
+            prompt: The raw text prompt (without shape info)
+            height: Target image height in pixels (must be divisible by 32)
+            width: Target image width in pixels (must be divisible by 32)
+            image: Optional list of condition images for image-to-image generation

        Returns:
            Tuple of (prior_token_ids, pixel_height, pixel_width)
            - prior_token_ids: Upsampled to d16 format, shape [1, token_h*token_w*4]
-            - pixel_height: Image height in pixels
-            - pixel_width: Image width in pixels
+            - pixel_height: Image height in pixels (aligned to 32)
+            - pixel_width: Image width in pixels (aligned to 32)
        """
        device = self.vision_language_encoder.device

-        # Parse and expand shape info
-        expanded_prompt, token_h, token_w, prev_h, prev_w = self._parse_and_expand_shape_info(prompt)

        # Build messages for processor
+        height = (height // factor) * factor
+        width = (width // factor) * factor
+        is_text_to_image = image is None or len(image) == 0
+        expanded_prompt, token_h, token_w, prev_h, prev_w = self._build_prompt_with_shape(
+            prompt, height, width, is_text_to_image
+        )
        content = []
        if image is not None:
            for img in image:
                content.append({"type": "image", "image": img})
        content.append({"type": "text", "text": expanded_prompt})
        messages = [{"role": "user", "content": content}]
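
For reference, a small illustration (values invented, not taken from this diff) of the messages structure built above for an image-to-image call; for text-to-image the image entry is simply absent:

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": condition_image},        # a PIL.Image.Image condition input
            {"type": "text", "text": "a cat<sop>32 32<eop>"},   # expanded prompt with shape tokens
        ],
    }
]
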
        # Process inputs
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
-            return_tensors="pt",
+            return_tensors="pt"
        )

        # Determine if text-to-image or image-to-image
        existing_grid = inputs.get("image_grid_thw")
        is_text_to_image = existing_grid is None or existing_grid.numel() == 0

        # Build image grid
        inputs["image_grid_thw"] = self._build_image_grid_thw(
            token_h,
            token_w,
@@ -378,8 +410,8 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
        )
        prior_token_ids = self._upsample_d32_to_d16(prior_token_ids_d32, token_h, token_w)

-        pixel_height = token_h * 32
-        pixel_width = token_w * 32
+        pixel_height = token_h * factor
+        pixel_width = token_w * factor

        return prior_token_ids, pixel_height, pixel_width
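
A minimal call sketch (assumes an already loaded pipeline instance named pipe; not part of the diff):

prior_token_ids, pixel_height, pixel_width = pipe.generate_prior_tokens(
    prompt="a photo of a cat",   # raw prompt, no <sop>...<eop> tags
    height=1024,
    width=768,
)
# pixel_height and pixel_width come back 32-aligned;
# prior_token_ids has shape [1, (1024 // 32) * (768 // 32) * 4] == [1, 3072]
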
@@ -683,6 +715,8 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
        prior_token_id, ar_height, ar_width = self.generate_prior_tokens(
            prompt=prompt[0] if isinstance(prompt, list) else prompt,
            image=ar_condition_images,
            height=height,
            width=width,
        )

        height = height or ar_height
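
For orientation, a heavily hedged end-to-end sketch of how this path is exercised from user code; the checkpoint path, dtype choice, top-level import, and the .images attribute on the output are assumptions, not taken from this diff:

import torch
from diffusers import GlmImagePipeline  # assumes the pipeline class is exported under this name

pipe = GlmImagePipeline.from_pretrained("path/to/glm-image-checkpoint", torch_dtype=torch.bfloat16)  # placeholder path
pipe.to("cuda")

# Text-to-image: only the raw prompt is passed; the pipeline appends the <sop>H W<eop> shape tokens itself.
result = pipe(prompt="a photo of a cat", height=1024, width=1024)
image = result.images[0]  # assumes the standard diffusers pipeline output
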
@@ -739,9 +773,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                .to(self.vae.device, self.vae.dtype)
            )
            empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...]
-            for condition_image, condition_image_prior_token_id in zip(
-                image, condition_images_prior_token_id
-            ):
+            for condition_image, condition_image_prior_token_id in zip(image, condition_images_prior_token_id):
                condition_image = condition_image.to(device=device, dtype=self.vae.dtype)
                condition_latent = retrieve_latents(
                    self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
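
A short sketch of what sample_mode="argmax" means here (assumes retrieve_latents behaves like the shared diffusers helper and that vae is an AutoencoderKL; not taken verbatim from this file):

posterior = vae.encode(condition_image).latent_dist
condition_latent = posterior.mode()         # "argmax": deterministic mode of the VAE posterior
# with sample_mode="sample" it would be posterior.sample(generator) instead
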