
add i2v pipeline

This commit is contained in:
yiyi@huggingface.co
2025-11-29 00:49:18 +00:00
parent 090ceb5d4f
commit e3301cbda4
7 changed files with 1032 additions and 26 deletions

View File

@@ -1,3 +1,4 @@
# to convert only transformer
"""
python scripts/convert_hunyuan_video1_5_to_diffusers.py \
--original_state_dict_repo_id tencent/HunyuanVideo-1.5\
@@ -5,6 +6,7 @@ python scripts/convert_hunyuan_video1_5_to_diffusers.py \
--transformer_type 480p_t2v
"""
# to convert full pipeline
"""
python scripts/convert_hunyuan_video1_5_to_diffusers.py \
--original_state_dict_repo_id tencent/HunyuanVideo-1.5\
@@ -23,8 +25,8 @@ from safetensors.torch import load_file
from huggingface_hub import snapshot_download, hf_hub_download
import pathlib
from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15, FlowMatchEulerDiscreteScheduler, ClassifierFreeGuidance, HunyuanVideo15Pipeline
from transformers import AutoModel, AutoTokenizer, T5EncoderModel, ByT5Tokenizer
from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15, FlowMatchEulerDiscreteScheduler, ClassifierFreeGuidance, HunyuanVideo15Pipeline, HunyuanVideo15ImageToVideoPipeline
from transformers import AutoModel, AutoTokenizer, T5EncoderModel, ByT5Tokenizer, SiglipVisionModel, SiglipImageProcessor
import json
import argparse
@@ -812,6 +814,16 @@ def load_byt5(args):
return encoder, tokenizer
def load_siglip():
image_encoder = SiglipVisionModel.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", subfolder="image_encoder", torch_dtype=torch.bfloat16
)
feature_extractor = SiglipImageProcessor.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", subfolder="feature_extractor"
)
return image_encoder, feature_extractor
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -852,8 +864,9 @@ if __name__ == "__main__":
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True)
else:
vae = convert_vae(args)
task_type = transformer.config.task_type
vae = convert_vae(args)
text_encoder, tokenizer = load_mllm()
text_encoder_2, tokenizer_2 = load_byt5(args)
@@ -864,17 +877,35 @@ if __name__ == "__main__":
guidance_scale = GUIDANCE_CONFIGS[args.transformer_type]["guidance_scale"]
guider = ClassifierFreeGuidance(guidance_scale=guidance_scale)
pipeline = HunyuanVideo15Pipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
guider=guider,
scheduler=scheduler,
)
pipeline.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if task_type == "i2v":
image_encoder, feature_extractor = load_siglip()
pipeline = HunyuanVideo15ImageToVideoPipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
guider=guider,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
elif task_type == "t2v":
pipeline = HunyuanVideo15Pipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
guider=guider,
scheduler=scheduler,
)
else:
raise ValueError(f"Task type {task_type} is not supported")
pipeline.save_pretrained(args.output_path, safe_serialization=True)
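
A pipeline folder saved by this script with --save_pipeline can then be loaded back for inference; a minimal sketch is shown below (the output path, dtype, and device placement are placeholders, not values taken from this commit):

```python
import torch
from diffusers import HunyuanVideo15ImageToVideoPipeline

# Load the i2v pipeline folder written by the conversion script (path is a placeholder).
pipe = HunyuanVideo15ImageToVideoPipeline.from_pretrained(
    "path/to/converted/output", torch_dtype=torch.bfloat16
)
pipe.to("cuda")
```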

View File

@@ -483,6 +483,7 @@ else:
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoPipeline",
"HunyuanVideo15Pipeline",
"HunyuanVideo15ImageToVideoPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
@@ -1170,6 +1171,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
HunyuanVideo15Pipeline,
HunyuanVideo15ImageToVideoPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,

View File

@@ -242,7 +242,7 @@ else:
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
]
_import_structure["hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"]
_import_structure["hunyuan_video1_5"] = ["HunyuanVideo15Pipeline", "HunyuanVideo15ImageToVideoPipeline"]
_import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
@@ -663,7 +663,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
)
from .hunyuan_video1_5 import HunyuanVideo15Pipeline
from .hunyuan_video1_5 import HunyuanVideo15Pipeline, HunyuanVideo15ImageToVideoPipeline
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (

View File

@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"]
_import_structure["pipeline_hunyuan_video1_5_image2video"] = ["HunyuanVideo15ImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline
from .pipeline_hunyuan_video1_5_image2video import HunyuanVideo15ImageToVideoPipeline
else:
import sys

View File

@@ -80,11 +80,13 @@ class HunyuanVideo15ImageProcessor(VideoProcessor):
do_resize: bool = True,
vae_scale_factor: int = 16,
vae_latent_channels: int = 32,
do_convert_rgb: bool = True,
):
super().__init__(
do_resize=do_resize,
vae_scale_factor=vae_scale_factor,
vae_latent_channels=vae_latent_channels
vae_latent_channels=vae_latent_channels,
do_convert_rgb=do_convert_rgb,
)

View File

@@ -759,26 +759,20 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
height,
width,
num_frames,
torch.float32,
self.transformer.dtype,
device,
generator,
latents,
)
cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, torch.float32, device)
cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, self.transformer.dtype, device)
image_embeds = torch.zeros(
batch_size,
self.vision_num_semantic_tokens,
self.vision_states_dim,
dtype=torch.float32,
dtype=self.transformer.dtype,
device=device
)
image_embeds = image_embeds.to(self.transformer.dtype)
latents=latents.to(self.transformer.dtype)
cond_latents_concat=cond_latents_concat.to(self.transformer.dtype)
mask_concat=mask_concat.to(self.transformer.dtype)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)

View File

@@ -0,0 +1,975 @@
# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import re
import numpy as np
import PIL.Image
import torch
from transformers import Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel, ByT5Tokenizer, SiglipVisionModel, SiglipImageProcessor
from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from .image_processor import HunyuanVideo15ImageProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HunyuanVideo15PipelineOutput
from ...guiders import ClassifierFreeGuidance
from ...utils.torch_utils import randn_tensor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import HunyuanVideo15ImageToVideoPipeline
>>> from diffusers.utils import export_to_video, load_image
>>> model_id = "hunyuanvideo-community/HunyuanVideo15"
>>> pipe = HunyuanVideo15ImageToVideoPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
>>> pipe.vae.enable_tiling()
>>> pipe.to("cuda")
>>> image = load_image(
...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
... )
>>> output = pipe(
...     image=image,
...     prompt="The astronaut slowly waves at the camera, realistic",
...     num_inference_steps=50,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=15)
```
"""
def format_text_input(prompt: List[str], system_message: str) -> List[Dict[str, Any]]:
"""
Apply text to template.
Args:
prompt (List[str]): Input text.
system_message (str): System message.
Returns:
List[Dict[str, Any]]: List of chat conversation.
"""
template = [
[
{'role': 'system', 'content': system_message},
{'role': 'user', 'content': p if p else " "}
]
for p in prompt]
return template
def extract_glyph_texts(prompt: str) -> List[str]:
"""
Extract glyph texts from prompt using regex pattern.
Args:
prompt: Input prompt string
Returns:
List of extracted glyph texts
"""
pattern = r'\"(.*?)\"|“(.*?)”'
matches = re.findall(pattern, prompt)
result = [match[0] or match[1] for match in matches]
result = list(dict.fromkeys(result)) if len(result) > 1 else result
if result:
formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". "
else:
formatted_result = None
return formatted_result
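# Example (illustrative): extract_glyph_texts('A sign that says "OPEN"') -> 'Text "OPEN". '
# Prompts without quoted text return None, which later maps to zero ByT5 embeddings.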
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline):
r"""
Pipeline for image-to-video generation using HunyuanVideo1.5.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
transformer ([`HunyuanVideo15Transformer3DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded video latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
vae ([`AutoencoderKLHunyuanVideo15`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder ([`Qwen2_5_VLTextModel`]):
The text encoder of [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct).
tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
text_encoder_2 ([`T5EncoderModel`]):
[T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel),
used as the ByT5 glyph text encoder.
tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer].
guider ([`ClassifierFreeGuidance`]):
[`ClassifierFreeGuidance`] component used for classifier-free guidance.
image_encoder ([`SiglipVisionModel`]):
SigLIP vision encoder used to extract semantic image embeddings from the input image.
feature_extractor ([`SiglipImageProcessor`]):
Image processor that prepares the input image for the SigLIP image encoder.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
text_encoder: Qwen2_5_VLTextModel,
tokenizer: Qwen2Tokenizer,
transformer: HunyuanVideo15Transformer3DModel,
vae: AutoencoderKLHunyuanVideo15,
scheduler: FlowMatchEulerDiscreteScheduler,
text_encoder_2: T5EncoderModel,
tokenizer_2: ByT5Tokenizer,
guider: ClassifierFreeGuidance,
image_encoder: SiglipVisionModel,
feature_extractor: SiglipImageProcessor,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
guider=guider,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16
self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial, do_resize=False, do_convert_rgb=True)
self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640
self.vision_states_dim = self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152
self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32
# fmt: off
self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \
1. The main content and theme of the video. \
2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
4. background environment, light, style and atmosphere. \
5. camera angles, movements, and transitions used in the video."
# fmt: on
self.prompt_template_encode_start_idx = 108
self.tokenizer_max_length = 1000
self.tokenizer_2_max_length = 256
self.vision_num_semantic_tokens = 729
@staticmethod
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_mllm_prompt_embeds
def _get_mllm_prompt_embeds(
text_encoder: Qwen2_5_VLTextModel,
tokenizer: Qwen2Tokenizer,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
tokenizer_max_length: int = 1000,
num_hidden_layers_to_skip: int = 2,
# fmt: off
system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \
1. The main content and theme of the video. \
2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
4. background environment, light, style and atmosphere. \
5. camera angles, movements, and transitions used in the video.",
# fmt: on
crop_start: int = 108,
) -> Tuple[torch.Tensor, torch.Tensor]:
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = format_text_input(prompt, system_message)
text_inputs = tokenizer.apply_chat_template(
prompt,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
padding="max_length",
max_length=tokenizer_max_length + crop_start,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device=device)
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
prompt_embeds = text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_attention_mask,
output_hidden_states=True,
).hidden_states[-(num_hidden_layers_to_skip + 1)]
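# crop_start (prompt_template_encode_start_idx) trims the tokenized system-message/template prefix so that
# only the user prompt tokens and their attention mask remain.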
if crop_start is not None and crop_start > 0:
prompt_embeds = prompt_embeds[:, crop_start:]
prompt_attention_mask = prompt_attention_mask[:, crop_start:]
return prompt_embeds, prompt_attention_mask
@staticmethod
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_byt5_prompt_embeds
def _get_byt5_prompt_embeds(
tokenizer: ByT5Tokenizer,
text_encoder: T5EncoderModel,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
tokenizer_max_length: int = 256,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
glyph_texts = [extract_glyph_texts(p) for p in prompt]
prompt_embeds_list = []
prompt_embeds_mask_list = []
for glyph_text in glyph_texts:
if glyph_text is None:
glyph_text_embeds = torch.zeros(
(1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype
)
glyph_text_embeds_mask = torch.zeros(
(1, tokenizer_max_length), device=device, dtype=torch.int64
)
else:
txt_tokens = tokenizer(
glyph_text,
padding="max_length",
max_length=tokenizer_max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
).to(device)
glyph_text_embeds = text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask.float(),
)[0]
glyph_text_embeds = glyph_text_embeds.to(device=device)
glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device)
prompt_embeds_list.append(glyph_text_embeds)
prompt_embeds_mask_list.append(glyph_text_embeds_mask)
prompt_embeds = torch.cat(prompt_embeds_list, dim=0)
prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0)
return prompt_embeds, prompt_embeds_mask
@staticmethod
def _get_vae_image_latents(
vae: AutoencoderKLHunyuanVideo15,
image_processor: HunyuanVideo15ImageProcessor,
image: PIL.Image.Image,
height: int,
width: int,
device: torch.device,
) -> torch.Tensor:
vae_dtype = vae.dtype
image_tensor = image_processor.preprocess(image, height=height, width=width).to(device, dtype=vae_dtype)
image_latents = retrieve_latents(vae.encode(image_tensor), sample_mode="argmax")
image_latents = image_latents * vae.config.scaling_factor
return image_latents
@staticmethod
def _get_image_embeds(
image_encoder: SiglipVisionModel,
feature_extractor: SiglipImageProcessor,
image: PIL.Image.Image,
device: torch.device,
) -> torch.Tensor:
image_encoder_dtype = next(image_encoder.parameters()).dtype
image = feature_extractor.preprocess(
images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True
)
image = image.to(device=device, dtype=image_encoder_dtype)
image_enc_hidden_states = image_encoder(**image).last_hidden_state
return image_enc_hidden_states
def encode_image(
self,
image: PIL.Image.Image,
batch_size: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
image_embeds = self._get_image_embeds(
image_encoder=self.image_encoder,
feature_extractor=self.feature_extractor,
image=image,
device=device,
)
image_embeds = image_embeds.repeat(batch_size, 1, 1)
image_embeds = image_embeds.to(device=device, dtype=dtype)
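# Expected shape: (batch_size, vision_num_semantic_tokens, vision_states_dim), i.e. the image token count
# and width the transformer's image_embed_dim is configured for.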
return image_embeds
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
batch_size: int = 1,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_embeds_mask_2: Optional[torch.Tensor] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
batch_size (`int`):
batch size of prompts, defaults to 1
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
argument.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
argument using self.tokenizer_2 and self.text_encoder_2.
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
argument using self.tokenizer_2 and self.text_encoder_2.
"""
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
if prompt is None:
prompt = [""] * batch_size
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds(
tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
prompt=prompt,
device=device,
tokenizer_max_length=self.tokenizer_max_length,
system_message=self.system_message,
crop_start=self.prompt_template_encode_start_idx,
)
if prompt_embeds_2 is None:
prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds(
tokenizer=self.tokenizer_2,
text_encoder=self.text_encoder_2,
prompt=prompt,
device=device,
tokenizer_max_length=self.tokenizer_2_max_length,
)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len)
_, seq_len_2, _ = prompt_embeds_2.shape
prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1)
prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1)
prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1)
prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2)
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
prompt_embeds_mask = prompt_embeds_mask.to(device=device, dtype=dtype)
prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dtype)
prompt_embeds_mask_2 = prompt_embeds_mask_2.to(device=device, dtype=dtype)
return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2
def check_inputs(
self,
prompt,
image: PIL.Image.Image,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_embeds_mask=None,
negative_prompt_embeds_mask=None,
prompt_embeds_2=None,
prompt_embeds_mask_2=None,
negative_prompt_embeds_2=None,
negative_prompt_embeds_mask_2=None,
):
if not isinstance(image, PIL.Image.Image):
raise ValueError(f"`image` has to be of type `PIL.Image.Image` but is {type(image)}")
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_embeds_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if prompt is None and prompt_embeds_2 is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
)
if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
raise ValueError(
"If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
)
if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None:
raise ValueError(
"If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
)
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.prepare_latents
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int = 32,
height: int = 720,
width: int = 1280,
num_frames: int = 129,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
shape = (
batch_size,
num_channels_latents,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
def prepare_cond_latents_and_mask(
self,
latents: torch.Tensor,
image: PIL.Image.Image,
batch_size: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
):
"""
Prepare conditional latents and mask for t2v generation.
Args:
latents: Main latents tensor (B, C, F, H, W)
Returns:
tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v
"""
batch, channels, latent_frames, latent_height, latent_width = latents.shape
image_latents = self._get_vae_image_latents(
vae=self.vae,
image_processor=self.video_processor,
image=image,
height=height,
width=width,
device=device,
)
latent_condition = image_latents.repeat(batch_size, 1, latent_frames, 1, 1)
latent_condition[:, :, 1:, :, :] = 0
latent_condition = latent_condition.to(device=device, dtype=dtype)
latent_mask = torch.zeros(
batch, 1, latent_frames, latent_height, latent_width,
dtype=dtype,
device=device
)
latent_mask[:, :, 0, :, :] = 1.0
return latent_condition, latent_mask
@property
def guidance_scale(self):
return self._guidance_scale
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: PIL.Image.Image,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
num_frames: int = 121,
num_inference_steps: int = 50,
sigmas: List[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_embeds_mask_2: Optional[torch.Tensor] = None,
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
The call function to the pipeline for generation.
Args:
image (`PIL.Image.Image`):
The input image to condition the video generation on. The output height and width are derived from this
image and the transformer's configured `target_size`.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when classifier-free guidance is not enabled by the guider.
num_frames (`int`, defaults to `121`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for `prompt_embeds`. If not provided, it is generated from the `prompt`
input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
argument.
negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for `negative_prompt_embeds`. If not provided, it is generated from the
`negative_prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text embeddings from ByT5. If not provided, they will be generated from the `prompt`
input argument using `tokenizer_2` and `text_encoder_2`.
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text mask from ByT5. If not provided, it will be generated from the `prompt` input
argument.
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated negative glyph text embeddings from ByT5. If not provided, they will be generated from the
`negative_prompt` input argument.
negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated negative glyph text mask from ByT5.
output_type (`str`, *optional*, defaults to `"np"`):
The output format of the generated video. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
Examples:
Returns:
[`~HunyuanVideo15PipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated video frames.
"""
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
image=image,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
prompt_embeds_2=prompt_embeds_2,
prompt_embeds_mask_2=prompt_embeds_mask_2,
negative_prompt_embeds_2=negative_prompt_embeds_2,
negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
)
height, width = self.video_processor.calculate_default_height_width(height=image.size[1], width=image.size[0], target_size=self.target_size)
image = self.video_processor.resize(image, height=height, width=width, resize_mode="crop")
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
device = self._execution_device
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
batch_size=batch_size,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
prompt_embeds_2=prompt_embeds_2,
prompt_embeds_mask_2=prompt_embeds_mask_2,
)
if self.guider._enabled and self.guider.num_conditions > 1:
negative_prompt_embeds, negative_prompt_embeds_mask, negative_prompt_embeds_2, negative_prompt_embeds_mask_2 = self.encode_prompt(
prompt=negative_prompt,
device=device,
dtype=self.transformer.dtype,
batch_size=batch_size,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
prompt_embeds_2=negative_prompt_embeds_2,
prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
)
# 4. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
# 5. Prepare latent variables
latents = self.prepare_latents(
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=self.num_channels_latents,
height=height,
width=width,
num_frames=num_frames,
dtype=self.transformer.dtype,
device=device,
generator=generator,
latents=latents,
)
cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(
latents=latents,
image=image,
batch_size=batch_size * num_videos_per_prompt,
height=height,
width=width,
dtype=self.transformer.dtype,
device=device
)
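# 6. Encode the input image into semantic embeddings for conditioning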
image_embeds = self.encode_image(
image=image,
batch_size=batch_size * num_videos_per_prompt,
device=device,
dtype=self.transformer.dtype,
)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
# Step 1: Collect model inputs needed for the guidance method
# conditional inputs should always be first element in the tuple
guider_inputs = {
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
"encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
"encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
"encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
}
# Step 2: Update guider's internal state for this denoising step
self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
# Step 3: Prepare batched model inputs based on the guidance method
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = self.guider.prepare_inputs(guider_inputs)
# Step 4: Run the denoiser for each batch
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
for guider_state_batch in guider_state:
self.guider.prepare_models(self.transformer)
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
cond_kwargs = {
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
}
# e.g. "pred_cond"/"pred_uncond"
context_name = getattr(guider_state_batch, self.guider._identifier_key)
with self.transformer.cache_context(context_name):
# Run denoiser and store noise prediction in this batch
guider_state_batch.noise_pred = self.transformer(
hidden_states=latent_model_input,
image_embeds=image_embeds,
timestep=timestep,
attention_kwargs=self.attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
# Cleanup model (e.g., remove hooks)
self.guider.cleanup_models(self.transformer)
# Step 5: Combine predictions using the guidance method
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
# Continuing the CFG example, the guider receives:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
# ]
# And extracts predictions using the __guidance_identifier__:
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
# Then applies CFG formula:
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
noise_pred = self.guider(guider_state)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return HunyuanVideo15PipelineOutput(frames=video)