mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -100,6 +100,85 @@ DEFAULT_PROMPT_TEMPLATE = {
|
||||
}
|
||||
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, image_token_index, pad_token_id
|
||||
):
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
special_image_token_mask = input_ids == image_token_index
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == image_token_index
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
batch_indices, non_image_indices = torch.where(input_ids != image_token_index)
|
||||
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
||||
batch_indices, non_image_indices = torch.where(input_ids != image_token_index)
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
|
||||
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
||||
if left_padding:
|
||||
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
final_embedding = torch.zeros(
|
||||
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
batch_indices, non_image_indices, text_to_overwrite = (
|
||||
batch_indices.to(target_device),
|
||||
non_image_indices.to(target_device),
|
||||
text_to_overwrite.to(target_device),
|
||||
)
|
||||
attention_mask = attention_mask.to(target_device)
|
||||
|
||||
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||
image_to_overwrite = torch.full((batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device)
|
||||
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
if left_padding:
|
||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
||||
else:
|
||||
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
|
||||
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
|
||||
image_to_overwrite &= padding_mask
|
||||
|
||||
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||
raise ValueError(
|
||||
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
||||
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
||||
)
|
||||
|
||||
batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
|
||||
return final_embedding, final_attention_mask, position_ids
|
||||
|
||||
|
||||
def _text_encoder_custom_forward():
|
||||
return
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
@@ -279,13 +358,30 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
|
||||
|
||||
image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
|
||||
input_embeds = self.text_encoder.get_input_embeddings()(text_input_ids)
|
||||
|
||||
inputs_embeds, attention_mask, position_ids = _merge_input_ids_with_image_features(
|
||||
image_embeds,
|
||||
input_embeds,
|
||||
text_input_ids,
|
||||
prompt_attention_mask,
|
||||
self.text_encoder.config.image_token_index,
|
||||
self.text_encoder.pad_token_id,
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
).hidden_states[-(num_hidden_layers_to_skip + 1)]
|
||||
"""
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_attention_mask,
|
||||
pixel_values=image_embeds,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-(num_hidden_layers_to_skip + 1)]
|
||||
"""
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
|
||||
image_emb_len = prompt_template.get("image_emb_len", 576)
|
||||
|
||||
Reference in New Issue
Block a user