From 0f599d99010d96996e6c930c015fba1583ae3992 Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Thu, 10 Apr 2025 11:37:24 +0200
Subject: [PATCH] update

---
 .../pipeline_hunyuan_video_image2video.py | 105 ++++++++++++++++++++
 1 file changed, 105 insertions(+)

diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
index 774b72e6c7..aa54c72a5d 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
@@ -100,6 +100,85 @@ DEFAULT_PROMPT_TEMPLATE = {
 }
 
 
+def _merge_input_ids_with_image_features(
+    image_features, inputs_embeds, input_ids, attention_mask, image_token_index, pad_token_id
+):
+    num_images, num_image_patches, embed_dim = image_features.shape
+    batch_size, sequence_length = input_ids.shape
+    # 1. Create a mask to know where special image tokens are
+    special_image_token_mask = input_ids == image_token_index
+    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+    batch_indices, non_image_indices = torch.where(input_ids != image_token_index)
+    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
+    # Infer the padding side: assume left padding when no sequence ends with a pad token.
+    left_padding = not torch.sum(input_ids[:, -1] == pad_token_id)
+
+    # 2. Compute the positions where text should be written
+    # Calculate new positions for text tokens in the merged image-text sequence.
+    # `special_image_token_mask` identifies image tokens. Each image token shifts subsequent
+    # positions by `num_image_patches - 1`, which `torch.cumsum` accumulates.
+    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
+    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
+    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
+    if left_padding:
+        new_token_positions += nb_image_pad[:, None]  # offset for left padding
+    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
+    # 3. Create the full embedding, already padded to the maximum position
+    final_embedding = torch.zeros(
+        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+    )
+    final_attention_mask = torch.zeros(
+        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
+    )
+    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
+    # set the corresponding tensors into their correct target device.
+    target_device = inputs_embeds.device
+    batch_indices, non_image_indices, text_to_overwrite = (
+        batch_indices.to(target_device),
+        non_image_indices.to(target_device),
+        text_to_overwrite.to(target_device),
+    )
+    attention_mask = attention_mask.to(target_device)
+
+    # 4. Fill the embeddings based on the mask. If we have ["hey", "<image>", "how", "are"]
+    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
+    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
+    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
+
+    # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
+    image_to_overwrite = torch.full((batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device)
+    image_to_overwrite[batch_indices, text_to_overwrite] = False
+    if left_padding:
+        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+    else:
+        mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
+        padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
+        image_to_overwrite &= padding_mask
+
+    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
+        raise ValueError(
+            f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
+            f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
+        )
+
+    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
+    final_attention_mask |= image_to_overwrite
+    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
+
+    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+    batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
+    indices_to_mask = new_token_positions[batch_indices, pad_indices]
+
+    final_embedding[batch_indices, indices_to_mask] = 0
+
+    return final_embedding, final_attention_mask, position_ids
+
+
+def _text_encoder_custom_forward():
+    return
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -279,13 +358,39 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
         prompt_attention_mask = text_inputs.attention_mask.to(device=device)
 
         image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
 
+        # `_merge_input_ids_with_image_features` expects projected patch embeddings rather
+        # than raw pixel values, so run the vision tower and multi-modal projector first.
+        image_outputs = self.text_encoder.vision_tower(image_embeds, output_hidden_states=True)
+        image_features = image_outputs.hidden_states[self.text_encoder.config.vision_feature_layer]
+        if self.text_encoder.config.vision_feature_select_strategy == "default":
+            image_features = image_features[:, 1:]
+        image_features = self.text_encoder.multi_modal_projector(image_features)
+
+        input_embeds = self.text_encoder.get_input_embeddings()(text_input_ids)
+        inputs_embeds, attention_mask, position_ids = _merge_input_ids_with_image_features(
+            image_features,
+            input_embeds,
+            text_input_ids,
+            prompt_attention_mask,
+            self.text_encoder.config.image_token_index,
+            self.text_encoder.pad_token_id,
+        )
+
+        prompt_embeds = self.text_encoder.language_model(
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=True,
+        ).hidden_states[-(num_hidden_layers_to_skip + 1)]
+        """
         prompt_embeds = self.text_encoder(
             input_ids=text_input_ids,
             attention_mask=prompt_attention_mask,
             pixel_values=image_embeds,
             output_hidden_states=True,
         ).hidden_states[-(num_hidden_layers_to_skip + 1)]
+        """
         prompt_embeds = prompt_embeds.to(dtype=dtype)
 
         image_emb_len = prompt_template.get("image_emb_len", 576)
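
A quick way to sanity-check the merge helper is to run it on toy inputs. In this sketch every id and dimension is made up for illustration (`32000` as the image token, `0` as pad, one image of 4 patches with embed dim 8); real checkpoints define their own values, and `_merge_input_ids_with_image_features` is assumed to be in scope from the patch above. Each regular token advances the write position by one and each image token by `num_image_patches`, so the three text tokens land at positions [0, 5, 6] and the patches fill [1:5]:

```python
import torch

# Toy values for illustration only; real configs define their own ids and dims.
image_token_index, pad_token_id = 32000, 0
image_features = torch.randn(1, 4, 8)  # 1 image -> 4 projected patch embeddings of dim 8
input_ids = torch.tensor([[101, 32000, 7592, 102]])  # "<bos> <image> hey <eos>"
inputs_embeds = torch.randn(1, 4, 8)  # text-side embeddings, one row per input id
attention_mask = torch.ones_like(input_ids)

final_embedding, final_attention_mask, position_ids = _merge_input_ids_with_image_features(
    image_features, inputs_embeds, input_ids, attention_mask, image_token_index, pad_token_id
)

# One image token expands into 4 patch slots: 3 text tokens + 4 patches = 7 positions.
print(final_embedding.shape)  # torch.Size([1, 7, 8])
print(position_ids)  # tensor([[0, 1, 2, 3, 4, 5, 6]])
```

Note that the embedding row for the image-token placeholder is never copied; only the rows at the `text_to_overwrite` positions come from `inputs_embeds`, and every remaining slot is overwritten by the flattened `image_features`.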
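The `hidden_states[-(num_hidden_layers_to_skip + 1)]` indexing used by both the new path and the commented-out one relies on `hidden_states` holding the embedding output plus one entry per transformer layer, so the negative index selects the output `num_hidden_layers_to_skip` layers before the last. A tiny randomly initialized model (arbitrary sizes, purely for illustration) makes the convention concrete:

```python
import torch
from transformers import LlamaConfig, LlamaModel

# Tiny random-weight model; the sizes are arbitrary and only keep the example fast.
config = LlamaConfig(
    hidden_size=32, intermediate_size=64, num_hidden_layers=4, num_attention_heads=4, vocab_size=128
)
model = LlamaModel(config)

out = model(input_ids=torch.randint(0, 128, (1, 6)), output_hidden_states=True)
print(len(out.hidden_states))  # 5 = embedding output + one entry per layer

num_hidden_layers_to_skip = 2
prompt_embeds = out.hidden_states[-(num_hidden_layers_to_skip + 1)]
print(prompt_embeds.shape)  # torch.Size([1, 6, 32]) -> output of layer 2 of 4
```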