commit 0f599d9901
parent b8093e6665
Author: Dhruv Nair
Date:   2025-04-10 11:37:24 +02:00


@@ -100,6 +100,85 @@ DEFAULT_PROMPT_TEMPLATE = {
}
def _merge_input_ids_with_image_features(
    image_features, inputs_embeds, input_ids, attention_mask, image_token_index, pad_token_id
):
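    # Adapted from transformers' LLaVA `_merge_input_ids_with_image_features`: scatters text
    # embeddings and image patch features into one padded sequence and returns
    # (final_embedding, final_attention_mask, position_ids).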
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    # Infer the padding side: if no sequence ends in a pad token, inputs are left-padded.
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(pad_token_id))
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    batch_indices, non_image_indices = torch.where(input_ids != image_token_index)
    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in the merged image-text sequence.
    # `special_image_token_mask` identifies image tokens. Each image token is replaced by
    # `num_image_patches` embeddings, shifting every later token by `num_image_patches - 1` positions.
    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_image_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
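    # Worked toy example (hypothetical values): with num_image_patches = 3 and
    # input_ids = [t, <image>, t], the per-token widths are [1, 3, 1], so
    # cumsum(...) - 1 = [0, 3, 4]: text lands at positions 0 and 4 while the
    # three image patches fill positions 1..3.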
    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    batch_indices, non_image_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_image_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)
    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
    # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
    image_to_overwrite = torch.full((batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device)
    image_to_overwrite[batch_indices, text_to_overwrite] = False
    if left_padding:
        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
    else:
        mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
        padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
        image_to_overwrite &= padding_mask
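    # E.g. (hypothetical): with right padding and a row whose last real token sits at
    # position 4, `padding_mask` keeps only columns 0..4, so trailing pad slots are
    # never treated as image positions.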
    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )
    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= image_to_overwrite
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]
    final_embedding[batch_indices, indices_to_mask] = 0
    return final_embedding, final_attention_mask, position_ids
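

# A minimal usage sketch (toy, hypothetical values) of the merge above: one `<image>`
# token (id 32000 here) expands into three patch embeddings, so a three-token prompt
# becomes a five-position sequence.
def _example_merge_usage():
    image_features = torch.randn(1, 3, 8)  # (num_images, num_image_patches, embed_dim)
    input_ids = torch.tensor([[5, 32000, 6]])
    inputs_embeds = torch.randn(1, 3, 8)
    attention_mask = torch.ones(1, 3, dtype=torch.long)
    emb, mask, pos = _merge_input_ids_with_image_features(
        image_features, inputs_embeds, input_ids, attention_mask, image_token_index=32000, pad_token_id=0
    )
    assert emb.shape == (1, 5, 8)  # 2 text tokens + 3 image patches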


def _text_encoder_custom_forward():
    return


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
@@ -279,13 +358,30 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
        prompt_attention_mask = text_inputs.attention_mask.to(device=device)

        image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
        input_embeds = self.text_encoder.get_input_embeddings()(text_input_ids)
        # `_merge_input_ids_with_image_features` expects patch features of shape
        # (num_images, num_image_patches, embed_dim), not raw pixel values, so encode the
        # image first (assumes a LLaVA-style text encoder exposing `vision_tower` and
        # `multi_modal_projector`; penultimate layer, CLS token dropped, as in LLaVA's default).
        image_outputs = self.text_encoder.vision_tower(image_embeds, output_hidden_states=True)
        image_features = self.text_encoder.multi_modal_projector(image_outputs.hidden_states[-2][:, 1:])
        inputs_embeds, attention_mask, position_ids = _merge_input_ids_with_image_features(
            image_features,
            input_embeds,
            text_input_ids,
            prompt_attention_mask,
            self.text_encoder.config.image_token_index,
            self.text_encoder.pad_token_id,
        )
        prompt_embeds = self.text_encoder.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=True,
        ).hidden_states[-(num_hidden_layers_to_skip + 1)]
"""
prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_attention_mask,
pixel_values=image_embeds,
output_hidden_states=True,
).hidden_states[-(num_hidden_layers_to_skip + 1)]
"""
        prompt_embeds = prompt_embeds.to(dtype=dtype)

        image_emb_len = prompt_template.get("image_emb_len", 576)
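        # 576 matches a LLaVA-style CLIP-L/14 vision tower at 336px resolution:
        # (336 / 14) ** 2 = 576 patch embeddings per image.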