mirror of
https://github.com/huggingface/diffusers.git
update
@@ -101,6 +101,50 @@ DEFAULT_PROMPT_TEMPLATE = {
 }
+
+
+def _expand_input_ids_with_image_tokens(
+    text_input_ids,
+    prompt_attention_mask,
+    max_sequence_length,
+    image_token_index,
+    image_emb_len,
+    image_emb_start,
+    image_emb_end,
+    pad_token_id,
+):
+    special_image_token_mask = text_input_ids == image_token_index
+    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+    batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
+
+    max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
+    new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
+    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
+    expanded_input_ids = torch.full(
+        (text_input_ids.shape[0], max_expanded_length),
+        pad_token_id,
+        dtype=text_input_ids.dtype,
+        device=text_input_ids.device,
+    )
+    expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
+    expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
+
+    expanded_attention_mask = torch.zeros(
+        (text_input_ids.shape[0], max_expanded_length),
+        dtype=prompt_attention_mask.dtype,
+        device=prompt_attention_mask.device,
+    )
+    attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
+    expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
+    expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
+    position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
+
+    return {
+        "input_ids": expanded_input_ids,
+        "attention_mask": expanded_attention_mask,
+        "position_ids": position_ids,
+    }
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
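As a sanity check on the new helper, here is a minimal usage sketch. The toy values (placeholder id 99, a 3-slot image, a 1x6 batch) are invented for illustration and are much smaller than the pipeline's real defaults, but the call shape and the returned dict match the function above:

import torch

text_input_ids = torch.tensor([[1, 2, 99, 3, 4, 0]])       # 99 = image placeholder, 0 = pad
prompt_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0]])

out = _expand_input_ids_with_image_tokens(
    text_input_ids,
    prompt_attention_mask,
    max_sequence_length=6,
    image_token_index=99,
    image_emb_len=3,     # each placeholder expands from 1 to 3 positions
    image_emb_start=2,   # slice [2:5] is overwritten with the image token id
    image_emb_end=5,
    pad_token_id=0,
)
# out["input_ids"]      -> [[1, 2, 99, 99, 99, 3, 4, 0]]   shape (1, 6 + 1*(3-1))
# out["attention_mask"] -> [[1, 1,  1,  1,  1, 1, 1, 0]]   pad slots stay masked
# out["position_ids"]   -> [[0, 1,  2,  3,  4, 5, 6, 1]]   masked slots forced to 1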
@@ -259,6 +303,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
         prompt = [prompt_template["template"].format(p) for p in prompt]
 
         crop_start = prompt_template.get("crop_start", None)
+
+        image_emb_len = prompt_template.get("image_emb_len", 576)
+        image_emb_start = prompt_template.get("image_emb_start", 5)
+        image_emb_end = prompt_template.get("image_emb_end", 581)
+        double_return_token_id = prompt_template.get("double_return_token_id", 271)
+
         if crop_start is None:
             prompt_template_input = self.tokenizer(
                 prompt_template["template"],
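The defaults above are mutually consistent: image_emb_end - image_emb_start = 581 - 5 = 576 = image_emb_len. A small sketch making that relationship explicit; the key names are exactly those read via prompt_template.get(...) above, and the comment on the last value is an inference from its name:

image_emb_len = 576           # number of positions the image slot occupies
image_emb_start = 5           # slot begins after the first template tokens
image_emb_end = 581           # 5 + 576: one past the last slot position
double_return_token_id = 271  # id of the "double return" token (presumably "\n\n")

assert image_emb_end - image_emb_start == image_emb_len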
@@ -288,69 +338,25 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
 
         image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
 
-        _, _, image_height, image_width = image_embeds.shape
-        patch_size = self.text_encoder.config.vision_config.patch_size
-        num_image_tokens = (image_height // patch_size) * (image_width // patch_size)
-        if self.text_encoder.config.vision_config.vision_feature_select_strategy == "default":
-            num_image_tokens -= 1
-
         image_token_index = self.text_encoder.config.image_token_index
         pad_token_id = self.text_encoder.config.pad_token_id
-        batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
-
-        special_image_token_mask = text_input_ids == image_token_index
-        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
-
-        max_expanded_length = max_sequence_length + (
-            num_special_image_tokens.max() * (prompt_template["image_emb_len"] - 1)
-        )
-        new_token_positions = (
-            torch.cumsum((special_image_token_mask * (prompt_template["image_emb_len"] - 1) + 1), -1) - 1
-        )
-        nb_image_pad = max_expanded_length - 1 - new_token_positions[:, -1]
-        if left_padding:
-            new_token_positions += nb_image_pad[:, None]
-
-        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
-        expanded_input_ids = torch.full(
-            (batch_size, max_expanded_length), pad_token_id, dtype=text_input_ids.dtype, device=device
-        )
-        expanded_attention_mask = torch.ones(
-            (batch_size, max_expanded_length), dtype=prompt_attention_mask.dtype, device=device
-        )
-
-        expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
-        expanded_inputs_ids[batch_indices, prompt_template["image_emb_start"] : prompt_template["image_emb_end"]] = (
-            image_token_index
-        )
-
-        inputs = self.llava_processor(
-            text=prompt,
-            images=image,
-            # max_length=max_sequence_length,
-            padding="max_length",
-            truncation=True,
-            return_length=False,
-            return_overflowing_tokens=False,
-            return_attention_mask=True,
-            return_tensors="pt",
-        ).to(device)
-
-        text_input_ids = inputs["input_ids"]
-        prompt_attention_mask = inputs["attention_mask"]
+        expanded_inputs = _expand_input_ids_with_image_tokens(
+            text_input_ids,
+            prompt_attention_mask,
+            max_sequence_length,
+            image_token_index,
+            image_emb_len,
+            image_emb_start,
+            image_emb_end,
+            pad_token_id,
+        )
 
         prompt_embeds = self.text_encoder(
-            **inputs,
-            pixel_value=image_embeds,
+            **expanded_inputs,
+            pixel_values=image_embeds,
             output_hidden_states=True,
         ).hidden_states[-(num_hidden_layers_to_skip + 1)]
         prompt_embeds = prompt_embeds.to(dtype=dtype)
 
-        image_emb_len = prompt_template.get("image_emb_len", 576)
-        image_emb_start = prompt_template.get("image_emb_start", 5)
-        image_emb_end = prompt_template.get("image_emb_end", 581)
-        double_return_token_id = prompt_template.get("double_return_token_id", 271)
-
         if crop_start is not None and crop_start > 0:
             text_crop_start = crop_start - 1 + image_emb_len
             batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
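The crop arithmetic on the surviving lines compensates for the expansion: the single image placeholder now occupies image_emb_len positions instead of one, so a crop point measured on the raw text shifts by image_emb_len - 1. A toy check of that formula, plus a toy run of the torch.where lookup on the hunk's final line; the crop_start value of 5 here is assumed purely for the arithmetic:

import torch

# Shift of the crop point: one placeholder grew to image_emb_len positions,
# so indices after it move by image_emb_len - 1. image_emb_len=576 is the
# default read earlier; crop_start=5 is an assumed example value.
crop_start, image_emb_len = 5, 576
text_crop_start = crop_start - 1 + image_emb_len
assert text_crop_start == 580

# torch.where on a 2-D boolean mask returns (row_indices, column_indices),
# here every occurrence of the double-return token per batch row.
ids = torch.tensor([[7, 271, 8, 271, 9]])
batch_indices, last_double_return_token_indices = torch.where(ids == 271)
# batch_indices -> tensor([0, 0]); last_double_return_token_indices -> tensor([1, 3])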